Skip to content

Commit 93665b7

Browse files
committed
port from laptop, basic func, no tests
1 parent e4e450d commit 93665b7

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "1.20.0"
66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
9+
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
910
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1011
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1112
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"

src/LinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ using Setfield
1212
using UnPack
1313
using SuiteSparse
1414
using KLU
15+
using FastLapackInterface
1516
using DocStringExtensions
16-
1717
import GPUArraysCore
1818

1919
# wrap

src/factorization.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,83 @@ function init_cacheval(alg::GenericFactorization{<:RFWrapper},
346346
abstol, reltol, verbose)
347347
ArrayInterfaceCore.lu_instance(convert(AbstractMatrix, A))
348348
end
349+
350+
351+
## FastLAPACKFactorizations
352+
353+
"""
    WorkspaceAndFactors{W, F}

Bundle a preallocated workspace with the factorization object built through it, so
that both can be stored together as a single cache value.

# Fields
- `workspace`: workspace object passed to the in-place LAPACK routine.
- `factors`: factorization object constructed from that routine's output.
"""
struct WorkspaceAndFactors{W, F}
    workspace::W
    factors::F
end
357+
358+
# There are no options like pivot here.
# But I'm not sure it makes sense as a GenericFactorization
# since it just uses `LAPACK.getrf!`.
"""
    FastLUFactorization()

Algorithm tag selecting the LU factorization that calls `LAPACK.getrf!` with a
reusable `LUWs` workspace. It carries no options.
"""
struct FastLUFactorization <: AbstractFactorization end
362+
363+
# Build the initial cache value for `FastLUFactorization`: an `LUWs` workspace
# created from `A`, paired with an LU factorization of `A`.
#
# `LAPACK.getrf!` factorizes its matrix argument in place. `solve` factorizes
# `cache.A` again while `cache.isfresh` is set, so the matrix handed in here must
# not be overwritten; the initial factors are therefore computed on a copy.
# `b`, `u`, `Pl`, `Pr`, `maxiters`, `abstol`, `reltol` and `verbose` are unused.
function init_cacheval(::FastLUFactorization, A, b, u, Pl, Pr,
                       maxiters, abstol, reltol, verbose)
    ws = LUWs(A)
    factors = LinearAlgebra.LU(LAPACK.getrf!(ws, copy(A))...)
    return WorkspaceAndFactors(ws, factors)
end
368+
369+
# Solve the cached linear system with `FastLUFactorization`.
#
# When `cache.isfresh` is set, `cache.A` is factorized in place with `LAPACK.getrf!`
# using the stored workspace, and the refreshed factors are written back into the
# cache. The solution is then computed into `cache.u` with `ldiv!`.
function SciMLBase.solve(cache::LinearCache, alg::FastLUFactorization)
    A = cache.A
    A = convert(AbstractMatrix, A)
    ws_and_fact = cache.cacheval
    if cache.isfresh
        # we will fail here if A is a different *size* than in a previous version
        # of the same cache.
        # it may instead be desirable to resize the workspace.
        factors = LinearAlgebra.LU(LAPACK.getrf!(ws_and_fact.workspace, A)...)
        @set! ws_and_fact.factors = factors
        cache = set_cacheval(cache, ws_and_fact)
    end
    y = ldiv!(cache.u, cache.cacheval.factors, cache.b)
    return SciMLBase.build_linear_solution(alg, y, nothing, cache)
end
382+
383+
"""
    FastQRFactorization(pivot, blocksize)
    FastQRFactorization()

Algorithm tag selecting a QR factorization that runs through a reusable
FastLapackInterface workspace.

# Fields
- `pivot`: pivoting strategy object; its type `P` selects the LAPACK routine
  (`NoPivot` uses `LAPACK.geqrt!`, `ColumnNorm` uses `LAPACK.geqp3!`).
- `blocksize`: block size forwarded to the `QRWYWs` workspace in the `NoPivot` case.
"""
struct FastQRFactorization{P} <: AbstractFactorization
    pivot::P
    blocksize::Int
end
387+
388+
# Default constructor: unpivoted QR with a block size of 36.
function FastQRFactorization()
    # Julia versions before 1.7 get `Val(false)`; later versions get `NoPivot()`.
    pivot = @static if VERSION < v"1.7beta"
        Val(false)
    else
        NoPivot()
    end
    # is 36 or 16 better here? LinearAlgebra and FastLapackInterface use 36,
    # but QRFactorization uses 16.
    return FastQRFactorization(pivot, 36)
end
397+
398+
# Build the initial cache value for unpivoted `FastQRFactorization`: a `QRWYWs`
# workspace created from `A` with `alg.blocksize`, paired with a compact-WY QR
# factorization of `A`.
#
# `LAPACK.geqrt!` factorizes its matrix argument in place. `solve` factorizes
# `cache.A` again while `cache.isfresh` is set, so the matrix handed in here must
# not be overwritten; the initial factors are therefore computed on a copy.
#
# NOTE(review): this method dispatches on `NoPivot`, while the zero-argument
# constructor produces `Val(false)` on Julia < 1.7 — confirm whether those
# versions need to be supported here.
function init_cacheval(alg::FastQRFactorization{NoPivot}, A, b, u, Pl, Pr,
                       maxiters, abstol, reltol, verbose)
    ws = QRWYWs(A; blocksize = alg.blocksize)
    factors = LinearAlgebra.QRCompactWY(LAPACK.geqrt!(ws, copy(A))...)
    return WorkspaceAndFactors(ws, factors)
end
403+
404+
# Build the initial cache value for column-pivoted `FastQRFactorization`: a `QRpWs`
# workspace created from `A`, paired with a pivoted QR factorization of `A`.
#
# `LAPACK.geqp3!` factorizes its matrix argument in place. `solve` factorizes
# `cache.A` again while `cache.isfresh` is set, so the matrix handed in here must
# not be overwritten; the initial factors are therefore computed on a copy.
function init_cacheval(::FastQRFactorization{ColumnNorm}, A, b, u, Pl, Pr,
                       maxiters, abstol, reltol, verbose)
    ws = QRpWs(A)
    factors = LinearAlgebra.QRPivoted(LAPACK.geqp3!(ws, copy(A))...)
    return WorkspaceAndFactors(ws, factors)
end
409+
410+
# Solve the cached linear system with `FastQRFactorization`.
#
# When `cache.isfresh` is set, `cache.A` is factorized in place using the stored
# workspace: `LAPACK.geqrt!` for `NoPivot`, `LAPACK.geqp3!` for `ColumnNorm`; any
# other pivot type raises an error. The refreshed factors are written back into
# the cache, and the solution is then computed into `cache.u` with `ldiv!`.
function SciMLBase.solve(cache::LinearCache, alg::FastQRFactorization{P}) where {P}
    A = cache.A
    A = convert(AbstractMatrix, A)
    ws_and_fact = cache.cacheval
    if cache.isfresh
        # we will fail here if A is a different *size* than in a previous version
        # of the same cache.
        # it may instead be desirable to resize the workspace.
        ws = ws_and_fact.workspace
        factors = if P === NoPivot
            LinearAlgebra.QRCompactWY(LAPACK.geqrt!(ws, A)...)
        elseif P === ColumnNorm
            LinearAlgebra.QRPivoted(LAPACK.geqp3!(ws, A)...)
        else
            error("No FastLAPACK Factorization defined for $P")
        end
        @set! ws_and_fact.factors = factors
        cache = set_cacheval(cache, ws_and_fact)
    end
    y = ldiv!(cache.u, cache.cacheval.factors, cache.b)
    return SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

0 commit comments

Comments
 (0)