Skip to content

Commit 03edc78

Browse files
WIP: Make more libraries into an extension?
I think we should be considering some kind of split to LinearSolveCore vs LinearSolve, a la OrdinaryDiffEq. The default algorithm requires RecursiveFactorization and MKL, but we could definitely make a version that doesn't have these deps. But at least sparse stuff can go into an extension since those always dispatch on the existance of sparse matrices. This PR isn't quite complete yet, more needs to be removed with sparse, but it shows the right direction.
1 parent 272808b commit 03edc78

File tree

8 files changed

+377
-334
lines changed

8 files changed

+377
-334
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1212
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
1313
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1414
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
15-
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
1615
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1716
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1817
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -26,7 +25,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2625
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2726
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2827
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
29-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3028
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
3129
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3230
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
@@ -40,11 +38,13 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
4038
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
4139
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
4240
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
41+
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
4342
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
4443
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
4544
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4645
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4746
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
47+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4848

4949
[extensions]
5050
LinearSolveBandedMatricesExt = "BandedMatrices"
@@ -55,11 +55,13 @@ LinearSolveEnzymeExt = "EnzymeCore"
5555
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
5656
LinearSolveHYPREExt = "HYPRE"
5757
LinearSolveIterativeSolversExt = "IterativeSolvers"
58+
LinearSolveKLUExt = "KLU"
5859
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
5960
LinearSolveKrylovKitExt = "KrylovKit"
6061
LinearSolveMetalExt = "Metal"
6162
LinearSolvePardisoExt = "Pardiso"
6263
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
64+
LinearSolveSparseArraysExt = "SparseArrays"
6365

6466
[compat]
6567
AllocCheck = "0.1"

ext/LinearSolveKLUExt.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
module LinearSolveKLUExt
2+
3+
using LinearSolve, LinearSolve.LinearAlgebra
4+
using KLU, KLU.SparseArrays
5+
6+
const PREALLOCATED_KLU = KLU.KLUFactorization(SparseMatrixCSC(0, 0, [1], Int[],
7+
Float64[]))
8+
9+
function init_cacheval(alg::KLUFactorization,
10+
A, b, u, Pl, Pr,
11+
maxiters::Int, abstol, reltol,
12+
verbose::Bool, assumptions::OperatorAssumptions)
13+
nothing
14+
end
15+
16+
function init_cacheval(alg::KLUFactorization, A::SparseMatrixCSC{Float64, Int}, b, u, Pl,
17+
Pr,
18+
maxiters::Int, abstol, reltol,
19+
verbose::Bool, assumptions::OperatorAssumptions)
20+
PREALLOCATED_KLU
21+
end
22+
23+
function init_cacheval(alg::KLUFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
24+
maxiters::Int, abstol,
25+
reltol,
26+
verbose::Bool, assumptions::OperatorAssumptions)
27+
A = convert(AbstractMatrix, A)
28+
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
29+
nonzeros(A)))
30+
end
31+
32+
# TODO: guard this against errors
33+
function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
34+
A = cache.A
35+
A = convert(AbstractMatrix, A)
36+
if cache.isfresh
37+
cacheval = @get_cacheval(cache, :KLUFactorization)
38+
if alg.reuse_symbolic
39+
if alg.check_pattern && pattern_changed(cacheval, A)
40+
fact = KLU.klu(
41+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
42+
nonzeros(A)),
43+
check = false)
44+
else
45+
fact = KLU.klu!(cacheval, nonzeros(A), check = false)
46+
end
47+
else
48+
# New fact each time since the sparsity pattern can change
49+
# and thus it needs to reallocate
50+
fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
51+
nonzeros(A)))
52+
end
53+
cache.cacheval = fact
54+
cache.isfresh = false
55+
end
56+
F = @get_cacheval(cache, :KLUFactorization)
57+
if F.common.status == KLU.KLU_OK
58+
y = ldiv!(cache.u, F, cache.b)
59+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
60+
else
61+
SciMLBase.build_linear_solution(
62+
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
63+
end
64+
end
65+
66+
end

ext/LinearSolveSparseArraysExt.jl

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
module LinearSolveSparseArraysExt
2+
3+
using LinearSolve
4+
import LinearSolve: SciMLBase, LinearAlgebra, PrecompileTools, init_cacheval
5+
using LinearSolve: DefaultLinearSolver, DefaultAlgorithmChoice
6+
using SparseArrays
7+
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
8+
9+
# Specialize QR for the non-square case
10+
# Missing ldiv! definitions: https://github.com/JuliaSparse/SparseArrays.jl/issues/242
11+
function LinearSolve._ldiv!(x::Vector,
12+
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
13+
SparseArrays.SPQR.QRSparse,
14+
SparseArrays.CHOLMOD.Factor}, b::Vector)
15+
x .= A \ b
16+
end
17+
18+
function LinearSolve._ldiv!(x::AbstractVector,
19+
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
20+
SparseArrays.SPQR.QRSparse,
21+
SparseArrays.CHOLMOD.Factor}, b::AbstractVector)
22+
x .= A \ b
23+
end
24+
25+
# Ambiguity removal
26+
function LinearSolve._ldiv!(::SVector,
27+
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
28+
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
29+
b::AbstractVector)
30+
(A \ b)
31+
end
32+
function LinearSolve._ldiv!(::SVector,
33+
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
34+
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
35+
b::SVector)
36+
(A \ b)
37+
end
38+
39+
function LinearSolve.pattern_changed(fact, A::SparseArrays.SparseMatrixCSC)
40+
!(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
41+
fact.colptr && SparseArrays.decrement(SparseArrays.getrowval(A)) ==
42+
fact.rowval)
43+
end
44+
45+
const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0, [1],
46+
Int[], Float64[]))
47+
48+
function init_cacheval(alg::UMFPACKFactorization,
49+
A, b, u, Pl, Pr,
50+
maxiters::Int, abstol, reltol,
51+
verbose::Bool, assumptions::OperatorAssumptions)
52+
nothing
53+
end
54+
55+
function init_cacheval(alg::UMFPACKFactorization, A::SparseMatrixCSC{Float64, Int}, b, u,
56+
Pl, Pr,
57+
maxiters::Int, abstol, reltol,
58+
verbose::Bool, assumptions::OperatorAssumptions)
59+
PREALLOCATED_UMFPACK
60+
end
61+
62+
function init_cacheval(alg::UMFPACKFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
63+
maxiters::Int, abstol,
64+
reltol,
65+
verbose::Bool, assumptions::OperatorAssumptions)
66+
A = convert(AbstractMatrix, A)
67+
zerobased = SparseArrays.getcolptr(A)[1] == 0
68+
return SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
69+
rowvals(A), nonzeros(A)))
70+
end
71+
72+
function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs...)
73+
A = cache.A
74+
A = convert(AbstractMatrix, A)
75+
if cache.isfresh
76+
cacheval = @get_cacheval(cache, :UMFPACKFactorization)
77+
if alg.reuse_symbolic
78+
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
79+
if alg.check_pattern && pattern_changed(cacheval, A)
80+
fact = lu(
81+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
82+
nonzeros(A)),
83+
check = false)
84+
else
85+
fact = lu!(cacheval,
86+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
87+
nonzeros(A)), check = false)
88+
end
89+
else
90+
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
91+
check = false)
92+
end
93+
cache.cacheval = fact
94+
cache.isfresh = false
95+
end
96+
97+
F = @get_cacheval(cache, :UMFPACKFactorization)
98+
if F.status == SparseArrays.UMFPACK.UMFPACK_OK
99+
y = ldiv!(cache.u, F, cache.b)
100+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
101+
else
102+
SciMLBase.build_linear_solution(
103+
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
104+
end
105+
end
106+
107+
const PREALLOCATED_CHOLMOD = cholesky(SparseMatrixCSC(0, 0, [1], Int[], Float64[]))
108+
109+
function init_cacheval(alg::CHOLMODFactorization,
110+
A, b, u, Pl, Pr,
111+
maxiters::Int, abstol, reltol,
112+
verbose::Bool, assumptions::OperatorAssumptions)
113+
nothing
114+
end
115+
116+
function init_cacheval(alg::CHOLMODFactorization,
117+
A::Union{SparseMatrixCSC{T, Int}, Symmetric{T, SparseMatrixCSC{T, Int}}}, b, u,
118+
Pl, Pr,
119+
maxiters::Int, abstol, reltol,
120+
verbose::Bool, assumptions::OperatorAssumptions) where {T <:
121+
Union{Float32, Float64}}
122+
PREALLOCATED_CHOLMOD
123+
end
124+
125+
function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs...)
126+
A = cache.A
127+
A = convert(AbstractMatrix, A)
128+
129+
if cache.isfresh
130+
cacheval = @get_cacheval(cache, :CHOLMODFactorization)
131+
fact = cholesky(A; check = false)
132+
if !LinearAlgebra.issuccess(fact)
133+
ldlt!(fact, A; check = false)
134+
end
135+
cache.cacheval = fact
136+
cache.isfresh = false
137+
end
138+
139+
cache.u .= @get_cacheval(cache, :CHOLMODFactorization) \ cache.b
140+
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
141+
end
142+
143+
function LinearSolve.defaultalg(
144+
A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
145+
DefaultLinearSolver(DefaultAlgorithmChoice.CHOLMODFactorization)
146+
end
147+
148+
function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
149+
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
150+
if assump.issq
151+
DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
152+
else
153+
error("Generic number sparse factorization for non-square is not currently handled")
154+
end
155+
end
156+
157+
function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
158+
assump::OperatorAssumptions{Bool}) where {Ti}
159+
if assump.issq
160+
if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
161+
DefaultLinearSolver(DefaultAlgorithmChoice.KLUFactorization)
162+
else
163+
DefaultLinearSolver(DefaultAlgorithmChoice.UMFPACKFactorization)
164+
end
165+
else
166+
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
167+
end
168+
end
169+
170+
PrecompileTools.@compile_workload begin
171+
A = sprand(4, 4, 0.3) + I
172+
b = rand(4)
173+
prob = LinearProblem(A, b)
174+
sol = solve(prob, KLUFactorization())
175+
sol = solve(prob, UMFPACKFactorization())
176+
end
177+
178+
end

0 commit comments

Comments
 (0)