Skip to content

Commit aeedba4

Browse files
Merge pull request #153 from Wimmerer/fastlapack
FastLAPACK
2 parents e4e450d + 44a9a33 commit aeedba4

File tree

4 files changed

+118
-1
lines changed

4 files changed

+118
-1
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "1.20.0"
66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
9+
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
910
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1011
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1112
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"

src/LinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ using Setfield
1212
using UnPack
1313
using SuiteSparse
1414
using KLU
15+
using FastLapackInterface
1516
using DocStringExtensions
16-
1717
import GPUArraysCore
1818

1919
# wrap

src/factorization.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,101 @@ function init_cacheval(alg::GenericFactorization{<:RFWrapper},
346346
abstol, reltol, verbose)
347347
ArrayInterfaceCore.lu_instance(convert(AbstractMatrix, A))
348348
end
349+
350+
## FastLAPACKFactorizations
351+
352+
"""
    WorkspaceAndFactors{W, F}

Immutable pair that keeps a preallocated workspace together with the
factorization object that was computed through it.

# Fields
- `workspace::W`: the preallocated workspace object.
- `factors::F`: the factorization object built from the workspace-based LAPACK call.
"""
struct WorkspaceAndFactors{W, F}
    workspace::W
    factors::F
end
356+
357+
# Design note: there are no options such as a pivot choice here. It is not clear
# that this belongs under `GenericFactorization`, since it just uses
# `LAPACK.getrf!`.
"""
    FastLUFactorization()

Option-free algorithm marker selecting the workspace-based LU factorization
(`LAPACK.getrf!` called with a preallocated `LUWs` workspace).
"""
struct FastLUFactorization <: AbstractFactorization
end
361+
362+
"""
    init_cacheval(::FastLUFactorization, A, b, u, Pl, Pr,
                  maxiters, abstol, reltol, verbose)

Build the initial cache value for `FastLUFactorization`: allocate an `LUWs`
workspace for `A`, run `LAPACK.getrf!` through it, and return both wrapped in a
`WorkspaceAndFactors`. Only `A` is used; the remaining arguments are accepted
for signature compatibility and ignored.
"""
function init_cacheval(::FastLUFactorization, A, b, u, Pl, Pr,
                       maxiters, abstol, reltol, verbose)
    workspace = LUWs(A)
    # NOTE(review): `getrf!` presumably factorizes `A` in place; confirm the
    # caller hands in a matrix that may be overwritten.
    lu_parts = LAPACK.getrf!(workspace, A)
    return WorkspaceAndFactors(workspace, LinearAlgebra.LU(lu_parts...))
end
367+
368+
"""
    SciMLBase.solve(cache::LinearCache, alg::FastLUFactorization)

Solve the cached linear system with the workspace-based LU factorization.

When `cache.isfresh` is set, `cache.A` is refactorized with `LAPACK.getrf!`
using the workspace stored in `cache.cacheval`, and the cache is updated through
`set_cacheval`. The solution is then written into `cache.u` with `ldiv!` and
wrapped by `SciMLBase.build_linear_solution`.
"""
function SciMLBase.solve(cache::LinearCache, alg::FastLUFactorization)
    mat = convert(AbstractMatrix, cache.A)
    stored = cache.cacheval
    if cache.isfresh
        # This will fail if `A` has a different *size* than in a previous
        # version of the same cache; it may instead be desirable to resize
        # the workspace.
        fresh_factors = LinearAlgebra.LU(LAPACK.getrf!(stored.workspace, mat)...)
        @set! stored.factors = fresh_factors
        cache = set_cacheval(cache, stored)
    end
    solution = ldiv!(cache.u, cache.cacheval.factors, cache.b)
    return SciMLBase.build_linear_solution(alg, solution, nothing, cache)
end
382+
383+
"""
    FastQRFactorization{P}

Algorithm marker selecting the workspace-based QR factorization.

# Fields
- `pivot::P`: pivoting marker; its type `P` selects the unpivoted (`geqrt!`) or
  pivoted (`geqp3!`) code path.
- `blocksize::Int`: block size passed to the `QRWYWs` workspace on the unpivoted path.
"""
struct FastQRFactorization{P} <: AbstractFactorization
    pivot::P
    blocksize::Int
end
387+
388+
"""
    FastQRFactorization()

Construct a `FastQRFactorization` with no pivoting and a block size of 36.
The no-pivot marker is `Val(false)` on Julia versions before 1.7 and
`NoPivot()` otherwise.
"""
function FastQRFactorization()
    default_pivot = @static VERSION < v"1.7beta" ? Val(false) : NoPivot()
    # Is 36 or 16 better here? LinearAlgebra and FastLapackInterface use 36,
    # but QRFactorization uses 16.
    return FastQRFactorization(default_pivot, 36)
end
397+
398+
# Initial cache values for `FastQRFactorization`. The pivot marker types differ
# by Julia version (`Val{false}`/`Val{true}` before 1.7, `NoPivot`/`ColumnNorm`
# from 1.7 on), so the matching pair of methods is selected statically.
# In every method only `A` (and, where named, `alg.blocksize`) is used.
@static if VERSION < v"1.7beta"
    # Unpivoted path: `geqrt!` through a `QRWYWs` workspace, yielding `QRCompactWY`.
    function init_cacheval(alg::FastQRFactorization{Val{false}}, A, b, u, Pl, Pr,
                           maxiters, abstol, reltol, verbose)
        workspace = QRWYWs(A; blocksize = alg.blocksize)
        qr_parts = LAPACK.geqrt!(workspace, A)
        return WorkspaceAndFactors(workspace, LinearAlgebra.QRCompactWY(qr_parts...))
    end

    # Pivoted path: `geqp3!` through a `QRpWs` workspace, yielding `QRPivoted`.
    function init_cacheval(::FastQRFactorization{Val{true}}, A, b, u, Pl, Pr,
                           maxiters, abstol, reltol, verbose)
        workspace = QRpWs(A)
        qr_parts = LAPACK.geqp3!(workspace, A)
        return WorkspaceAndFactors(workspace, LinearAlgebra.QRPivoted(qr_parts...))
    end
else
    # Unpivoted path: `geqrt!` through a `QRWYWs` workspace, yielding `QRCompactWY`.
    function init_cacheval(alg::FastQRFactorization{NoPivot}, A, b, u, Pl, Pr,
                           maxiters, abstol, reltol, verbose)
        workspace = QRWYWs(A; blocksize = alg.blocksize)
        qr_parts = LAPACK.geqrt!(workspace, A)
        return WorkspaceAndFactors(workspace, LinearAlgebra.QRCompactWY(qr_parts...))
    end

    # Pivoted path: `geqp3!` through a `QRpWs` workspace, yielding `QRPivoted`.
    function init_cacheval(::FastQRFactorization{ColumnNorm}, A, b, u, Pl, Pr,
                           maxiters, abstol, reltol, verbose)
        workspace = QRpWs(A)
        qr_parts = LAPACK.geqp3!(workspace, A)
        return WorkspaceAndFactors(workspace, LinearAlgebra.QRPivoted(qr_parts...))
    end
end
422+
423+
"""
    SciMLBase.solve(cache::LinearCache, alg::FastQRFactorization{P}) where {P}

Solve the cached linear system with the workspace-based QR factorization.

When `cache.isfresh` is set, `cache.A` is refactorized using the workspace
stored in `cache.cacheval`: with `LAPACK.geqrt!` (giving a `QRCompactWY`) when
`P` is the no-pivot marker type, otherwise with `LAPACK.geqp3!` (giving a
`QRPivoted`). The solution is then written into `cache.u` with `ldiv!` and
wrapped by `SciMLBase.build_linear_solution`.
"""
function SciMLBase.solve(cache::LinearCache, alg::FastQRFactorization{P}) where {P}
    mat = convert(AbstractMatrix, cache.A)
    stored = cache.cacheval
    if cache.isfresh
        # This will fail if `A` has a different *size* than in a previous
        # version of the same cache; it may instead be desirable to resize
        # the workspace.
        # The no-pivot marker type depends on the Julia version.
        unpivoted = @static VERSION < v"1.7beta" ? Val{false} : NoPivot
        fresh_factors = if P === unpivoted
            LinearAlgebra.QRCompactWY(LAPACK.geqrt!(stored.workspace, mat)...)
        else
            LinearAlgebra.QRPivoted(LAPACK.geqp3!(stored.workspace, mat)...)
        end
        @set! stored.factors = fresh_factors
        cache = set_cacheval(cache, stored)
    end
    solution = ldiv!(cache.u, cache.cacheval.factors, cache.b)
    return SciMLBase.build_linear_solution(alg, solution, nothing, cache)
end

test/basictests.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,24 @@ end
119119
@test_throws ArgumentError solve(cache)
120120
end
121121

122+
@testset "FastLAPACK Factorizations" begin
    # Two systems derived from the enclosing scope's `A`, each with its own
    # random right-hand side and a zeroed initial guess.
    A1 = A / 1
    A2 = A / 2
    b1 = rand(n)
    b2 = rand(n)
    x1 = zero(b)
    x2 = zero(b)

    prob1 = LinearProblem(A1, b1; u0 = x1)
    prob2 = LinearProblem(A2, b2; u0 = x2)

    # Run the shared interface checks against both workspace-based algorithms.
    for alg in (LinearSolve.FastLUFactorization(), LinearSolve.FastQRFactorization())
        test_interface(alg, prob1, prob2)
    end

    # TODO: Resizing tests. Upstream doesn't currently support it.
    # Need to be absolutely certain we never segfault with incorrect
    # ws sizes.
end
139+
122140
@testset "Concrete Factorizations" begin for alg in (LUFactorization(),
123141
QRFactorization(),
124142
SVDFactorization(),

0 commit comments

Comments
 (0)