Skip to content

Commit 3beafc8

Browse files
Make FastLapackInterface.jl an extension as well
1 parent 3bff586 commit 3beafc8

File tree

7 files changed

+123
-111
lines changed

7 files changed

+123
-111
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1111
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
12-
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
1312
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1413
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1514
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
@@ -35,6 +34,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3534
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
3635
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3736
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
37+
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
3838
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
3939
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
4040
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -52,6 +52,7 @@ LinearSolveCUDAExt = "CUDA"
5252
LinearSolveCUDSSExt = "CUDSS"
5353
LinearSolveEnzymeExt = "EnzymeCore"
5454
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
55+
LinearSolveFastLapackInterfaceExt = "FastLapackInterface"
5556
LinearSolveHYPREExt = "HYPRE"
5657
LinearSolveIterativeSolversExt = "IterativeSolvers"
5758
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
@@ -126,6 +127,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
126127
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
127128
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
128129
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
130+
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
129131
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
130132
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
131133
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
@@ -150,4 +152,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
150152
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
151153

152154
[targets]
153-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak"]
155+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "FastLapackInterface"]

docs/src/solvers/solvers.md

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ use `Krylov_GMRES()`.
8484

8585
### RecursiveFactorization.jl
8686

87+
!!! note
88+
89+
Using this solver requires adding the package RecursiveFactorization.jl, i.e. `using RecursiveFactorization`
90+
8791
```@docs
8892
RFLUFactorization
8993
```
@@ -123,7 +127,13 @@ FastLapackInterface.jl is a package that allows for a lower-level interface to t
123127
calls to allow for preallocating workspaces to decrease the overhead of the wrappers.
124128
LinearSolve.jl provides a wrapper to these routines in a way where an initialized solver
125129
has a non-allocating LU factorization. In theory, this post-initialized solve should always
126-
be faster than the Base.LinearAlgebra version.
130+
be faster than the Base.LinearAlgebra version. In practice, with the way we wrap the solvers,
131+
we do not see a performance benefit and in fact benchmarks tend to show this inhibits
132+
performance.
133+
134+
!!! note
135+
136+
Using this solver requires adding the package FastLapackInterface.jl, i.e. `using FastLapackInterface`
127137

128138
```@docs
129139
FastLUFactorization
@@ -157,10 +167,6 @@ KrylovJL
157167

158168
### MKL.jl
159169

160-
!!! note
161-
162-
Using this solver requires adding the package MKL_jll.jl, i.e. `using MKL_jll`
163-
164170
```@docs
165171
MKLLUFactorization
166172
```
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
module LinearSolveFastLapackInterfaceExt
2+
3+
using LinearSolve, LinearAlgebra
4+
using FastLapackInterface
5+
6+
"""
    WorkspaceAndFactors{W, F}

Pairs a preallocated FastLapackInterface workspace with the factorization
object built from it; stored as the `cacheval` of the linear-solve cache so
repeated factorizations reuse the same LAPACK workspace.
"""
struct WorkspaceAndFactors{W, F}
    workspace::W
    factors::F
end
10+
11+
# Build the cacheval for `FastLUFactorization`: an LU workspace sized for `A`
# together with a placeholder LU factorization instance to hold the factors.
function LinearSolve.init_cacheval(::FastLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    lu_proto = LinearSolve.ArrayInterface.lu_instance(convert(AbstractMatrix, A))
    return WorkspaceAndFactors(LUWs(A), lu_proto)
end
17+
18+
"""
    SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::FastLUFactorization; kwargs...)

Solve the linear system held in `cache` via FastLapackInterface's LU routines,
refactorizing (into the preallocated workspace) only when the cache is fresh.
"""
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::FastLUFactorization; kwargs...)
    M = convert(AbstractMatrix, cache.A)
    ws_and_fact = LinearSolve.@get_cacheval(cache, :FastLUFactorization)
    if cache.isfresh
        # NOTE: this errors if `A` has a different *size* than in a previous
        # solve with the same cache; resizing the workspace could be supported
        # instead.
        refactorized = LinearAlgebra.LU(LAPACK.getrf!(ws_and_fact.workspace, M)...)
        LinearSolve.@set! ws_and_fact.factors = refactorized
        cache.cacheval = ws_and_fact
        cache.isfresh = false
    end
    sol = ldiv!(cache.u, cache.cacheval.factors, cache.b)
    return SciMLBase.build_linear_solution(alg, sol, nothing, cache)
end
33+
34+
# Cacheval for unpivoted fast QR: a blocked (compact-WY) QR workspace plus a
# placeholder QR factorization instance.
function LinearSolve.init_cacheval(alg::FastQRFactorization{NoPivot}, A::AbstractMatrix,
        b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    workspace = QRWYWs(A; blocksize = alg.blocksize)
    qr_proto = LinearSolve.ArrayInterface.qr_instance(convert(AbstractMatrix, A))
    return WorkspaceAndFactors(workspace, qr_proto)
end
41+
# Cacheval for column-pivoted fast QR: a pivoted-QR workspace plus a
# placeholder QR factorization instance.
function LinearSolve.init_cacheval(::FastQRFactorization{ColumnNorm}, A::AbstractMatrix,
        b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    workspace = QRpWs(A)
    qr_proto = LinearSolve.ArrayInterface.qr_instance(convert(AbstractMatrix, A))
    return WorkspaceAndFactors(workspace, qr_proto)
end
48+
49+
# Fallback for general operators: normalize `A` to an `AbstractMatrix` and
# re-dispatch to the pivot-specific methods above.
function LinearSolve.init_cacheval(alg::FastQRFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    # `init_cacheval` is internal to LinearSolve (not exported), so inside this
    # extension module the recursive call must be module-qualified; the bare
    # `init_cacheval(...)` carried over from src/factorization.jl would not
    # resolve here. The redundant call-site type assertions are also dropped —
    # the signature already constrains those arguments.
    return LinearSolve.init_cacheval(alg, convert(AbstractMatrix, A), b, u, Pl, Pr,
        maxiters, abstol, reltol, verbose, assumptions)
end
56+
57+
"""
    SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::FastQRFactorization{P}; kwargs...)

Solve the linear system held in `cache` via FastLapackInterface's QR routines,
choosing the unpivoted (`geqrt!`) or column-pivoted (`geqp3!`) path from the
pivot type parameter `P`, and refactorizing only when the cache is fresh.
"""
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::FastQRFactorization{P};
        kwargs...) where {P}
    M = convert(AbstractMatrix, cache.A)
    ws_and_fact = LinearSolve.@get_cacheval(cache, :FastQRFactorization)
    if cache.isfresh
        # NOTE: this errors if `A` has a different *size* than in a previous
        # solve with the same cache; resizing the workspace could be supported
        # instead.
        if P === NoPivot
            LinearSolve.@set! ws_and_fact.factors = LinearAlgebra.QRCompactWY(
                LAPACK.geqrt!(ws_and_fact.workspace, M)...)
        else
            LinearSolve.@set! ws_and_fact.factors = LinearAlgebra.QRPivoted(
                LAPACK.geqp3!(ws_and_fact.workspace, M)...)
        end
        cache.cacheval = ws_and_fact
        cache.isfresh = false
    end
    sol = ldiv!(cache.u, cache.cacheval.factors, cache.b)
    return SciMLBase.build_linear_solution(alg, sol, nothing, cache)
end
80+
81+
82+
end

src/LinearSolve.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ using SciMLOperators
1414
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
1515
using Setfield
1616
using UnPack
17-
using FastLapackInterface
1817
using DocStringExtensions
1918
using EnumX
2019
using Markdown

src/extension_algs.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,31 @@ function RFLUFactorization(; pivot = Val(true), thread = Val(true), throwerror =
107107
RFLUFactorization(pivot, thread; throwerror)
108108
end
109109

110+
# There are no options such as pivoting here, since the implementation just
# wraps `LAPACK.getrf!`; it is not clear it would make sense as a
# GenericFactorization either.
"""
`FastLUFactorization()`

The FastLapackInterface.jl version of the LU factorization. Notably,
this version does not allow for choice of pivoting method.
"""
struct FastLUFactorization <: AbstractDenseFactorization end

"""
`FastQRFactorization()`

The FastLapackInterface.jl version of the QR factorization.
"""
struct FastQRFactorization{P} <: AbstractDenseFactorization
    pivot::P
    blocksize::Int
end

# Default blocksize: LinearAlgebra and FastLapackInterface use 36, while
# QRFactorization uses 16 — it is unclear which is better here.
FastQRFactorization() = FastQRFactorization(NoPivot(), 36)
134+
110135
"""
111136
```julia
112137
MKLPardisoFactorize(; nprocs::Union{Int, Nothing} = nothing,

src/factorization.jl

Lines changed: 0 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,108 +1026,6 @@ function SciMLBase.solve!(cache::LinearCache, alg::DiagonalFactorization;
10261026
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
10271027
end
10281028

1029-
## FastLAPACKFactorizations
1030-
1031-
struct WorkspaceAndFactors{W, F}
1032-
workspace::W
1033-
factors::F
1034-
end
1035-
1036-
# There's no options like pivot here.
1037-
# But I'm not sure it makes sense as a GenericFactorization
1038-
# since it just uses `LAPACK.getrf!`.
1039-
"""
1040-
`FastLUFactorization()`
1041-
1042-
The FastLapackInterface.jl version of the LU factorization. Notably,
1043-
this version does not allow for choice of pivoting method.
1044-
"""
1045-
struct FastLUFactorization <: AbstractDenseFactorization end
1046-
1047-
function init_cacheval(::FastLUFactorization, A, b, u, Pl, Pr,
1048-
maxiters::Int, abstol, reltol, verbose::Bool,
1049-
assumptions::OperatorAssumptions)
1050-
ws = LUWs(A)
1051-
return WorkspaceAndFactors(ws, ArrayInterface.lu_instance(convert(AbstractMatrix, A)))
1052-
end
1053-
1054-
function SciMLBase.solve!(cache::LinearCache, alg::FastLUFactorization; kwargs...)
1055-
A = cache.A
1056-
A = convert(AbstractMatrix, A)
1057-
ws_and_fact = @get_cacheval(cache, :FastLUFactorization)
1058-
if cache.isfresh
1059-
# we will fail here if A is a different *size* than in a previous version of the same cache.
1060-
# it may instead be desirable to resize the workspace.
1061-
@set! ws_and_fact.factors = LinearAlgebra.LU(LAPACK.getrf!(ws_and_fact.workspace,
1062-
A)...)
1063-
cache.cacheval = ws_and_fact
1064-
cache.isfresh = false
1065-
end
1066-
y = ldiv!(cache.u, cache.cacheval.factors, cache.b)
1067-
SciMLBase.build_linear_solution(alg, y, nothing, cache)
1068-
end
1069-
1070-
"""
1071-
`FastQRFactorization()`
1072-
1073-
The FastLapackInterface.jl version of the QR factorization.
1074-
"""
1075-
struct FastQRFactorization{P} <: AbstractDenseFactorization
1076-
pivot::P
1077-
blocksize::Int
1078-
end
1079-
1080-
# is 36 or 16 better here? LinearAlgebra and FastLapackInterface use 36,
1081-
# but QRFactorization uses 16.
1082-
FastQRFactorization() = FastQRFactorization(NoPivot(), 36)
1083-
1084-
function init_cacheval(alg::FastQRFactorization{NoPivot}, A::AbstractMatrix, b, u, Pl, Pr,
1085-
maxiters::Int, abstol, reltol, verbose::Bool,
1086-
assumptions::OperatorAssumptions)
1087-
ws = QRWYWs(A; blocksize = alg.blocksize)
1088-
return WorkspaceAndFactors(ws,
1089-
ArrayInterface.qr_instance(convert(AbstractMatrix, A)))
1090-
end
1091-
function init_cacheval(::FastQRFactorization{ColumnNorm}, A::AbstractMatrix, b, u, Pl, Pr,
1092-
maxiters::Int, abstol, reltol, verbose::Bool,
1093-
assumptions::OperatorAssumptions)
1094-
ws = QRpWs(A)
1095-
return WorkspaceAndFactors(ws,
1096-
ArrayInterface.qr_instance(convert(AbstractMatrix, A)))
1097-
end
1098-
1099-
function init_cacheval(alg::FastQRFactorization, A, b, u, Pl, Pr,
1100-
maxiters::Int, abstol, reltol, verbose::Bool,
1101-
assumptions::OperatorAssumptions)
1102-
return init_cacheval(alg, convert(AbstractMatrix, A), b, u, Pl, Pr,
1103-
maxiters::Int, abstol, reltol, verbose::Bool,
1104-
assumptions::OperatorAssumptions)
1105-
end
1106-
1107-
function SciMLBase.solve!(cache::LinearCache, alg::FastQRFactorization{P};
1108-
kwargs...) where {P}
1109-
A = cache.A
1110-
A = convert(AbstractMatrix, A)
1111-
ws_and_fact = @get_cacheval(cache, :FastQRFactorization)
1112-
if cache.isfresh
1113-
# we will fail here if A is a different *size* than in a previous version of the same cache.
1114-
# it may instead be desirable to resize the workspace.
1115-
if P === NoPivot
1116-
@set! ws_and_fact.factors = LinearAlgebra.QRCompactWY(LAPACK.geqrt!(
1117-
ws_and_fact.workspace,
1118-
A)...)
1119-
else
1120-
@set! ws_and_fact.factors = LinearAlgebra.QRPivoted(LAPACK.geqp3!(
1121-
ws_and_fact.workspace,
1122-
A)...)
1123-
end
1124-
cache.cacheval = ws_and_fact
1125-
cache.isfresh = false
1126-
end
1127-
y = ldiv!(cache.u, cache.cacheval.factors, cache.b)
1128-
SciMLBase.build_linear_solution(alg, y, nothing, cache)
1129-
end
1130-
11311029
## SparspakFactorization is here since it's MIT licensed, not GPL
11321030

11331031
"""

test/basictests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
2-
using SciMLOperators, RecursiveFactorization, Sparspak
2+
using SciMLOperators, RecursiveFactorization, Sparspak, FastLapackInterface
33
using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners
44
using Test
55
import Random

0 commit comments

Comments
 (0)