Skip to content

Commit ff024ae

Browse files
WIP: Make more libraries into an extension?
I think we should be considering some kind of split to LinearSolveCore vs LinearSolve, a la OrdinaryDiffEq. The default algorithm requires RecursiveFactorization and MKL, but we could definitely make a version that doesn't have these deps. But at least sparse stuff can go into an extension since those always dispatch on the existance of sparse matrices. This PR isn't quite complete yet, more needs to be removed with sparse, but it shows the right direction.
1 parent 1c30db0 commit ff024ae

File tree

8 files changed

+377
-334
lines changed

8 files changed

+377
-334
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1313
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
1414
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1515
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
16-
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
1716
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1817
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1918
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -27,7 +26,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2726
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2827
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2928
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
30-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3129
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
3230
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3331
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
@@ -42,11 +40,13 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
4240
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
4341
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
4442
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
43+
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
4544
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
4645
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
4746
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4847
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4948
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
49+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5050

5151
[extensions]
5252
LinearSolveBandedMatricesExt = "BandedMatrices"
@@ -57,11 +57,13 @@ LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"]
5757
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]
5858
LinearSolveHYPREExt = "HYPRE"
5959
LinearSolveIterativeSolversExt = "IterativeSolvers"
60+
LinearSolveKLUExt = "KLU"
6061
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
6162
LinearSolveKrylovKitExt = "KrylovKit"
6263
LinearSolveMetalExt = "Metal"
6364
LinearSolvePardisoExt = "Pardiso"
6465
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
66+
LinearSolveSparseArraysExt = "SparseArrays"
6567

6668
[compat]
6769
AllocCheck = "0.1"

ext/LinearSolveKLUExt.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
module LinearSolveKLUExt
2+
3+
using LinearSolve, LinearSolve.LinearAlgebra
4+
using KLU, KLU.SparseArrays
5+
6+
const PREALLOCATED_KLU = KLU.KLUFactorization(SparseMatrixCSC(0, 0, [1], Int[],
7+
Float64[]))
8+
9+
function init_cacheval(alg::KLUFactorization,
10+
A, b, u, Pl, Pr,
11+
maxiters::Int, abstol, reltol,
12+
verbose::Bool, assumptions::OperatorAssumptions)
13+
nothing
14+
end
15+
16+
function init_cacheval(alg::KLUFactorization, A::SparseMatrixCSC{Float64, Int}, b, u, Pl,
17+
Pr,
18+
maxiters::Int, abstol, reltol,
19+
verbose::Bool, assumptions::OperatorAssumptions)
20+
PREALLOCATED_KLU
21+
end
22+
23+
function init_cacheval(alg::KLUFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
24+
maxiters::Int, abstol,
25+
reltol,
26+
verbose::Bool, assumptions::OperatorAssumptions)
27+
A = convert(AbstractMatrix, A)
28+
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
29+
nonzeros(A)))
30+
end
31+
32+
# TODO: guard this against errors
33+
function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
34+
A = cache.A
35+
A = convert(AbstractMatrix, A)
36+
if cache.isfresh
37+
cacheval = @get_cacheval(cache, :KLUFactorization)
38+
if alg.reuse_symbolic
39+
if alg.check_pattern && pattern_changed(cacheval, A)
40+
fact = KLU.klu(
41+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
42+
nonzeros(A)),
43+
check = false)
44+
else
45+
fact = KLU.klu!(cacheval, nonzeros(A), check = false)
46+
end
47+
else
48+
# New fact each time since the sparsity pattern can change
49+
# and thus it needs to reallocate
50+
fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
51+
nonzeros(A)))
52+
end
53+
cache.cacheval = fact
54+
cache.isfresh = false
55+
end
56+
F = @get_cacheval(cache, :KLUFactorization)
57+
if F.common.status == KLU.KLU_OK
58+
y = ldiv!(cache.u, F, cache.b)
59+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
60+
else
61+
SciMLBase.build_linear_solution(
62+
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
63+
end
64+
end
65+
66+
end

ext/LinearSolveSparseArraysExt.jl

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
module LinearSolveSparseArraysExt
2+
3+
using LinearSolve
4+
import LinearSolve: SciMLBase, LinearAlgebra, PrecompileTools, init_cacheval
5+
using LinearSolve: DefaultLinearSolver, DefaultAlgorithmChoice
6+
using SparseArrays
7+
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
8+
9+
# Specialize QR for the non-square case
10+
# Missing ldiv! definitions: https://github.com/JuliaSparse/SparseArrays.jl/issues/242
11+
function LinearSolve._ldiv!(x::Vector,
12+
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
13+
SparseArrays.SPQR.QRSparse,
14+
SparseArrays.CHOLMOD.Factor}, b::Vector)
15+
x .= A \ b
16+
end
17+
18+
function LinearSolve._ldiv!(x::AbstractVector,
19+
A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
20+
SparseArrays.SPQR.QRSparse,
21+
SparseArrays.CHOLMOD.Factor}, b::AbstractVector)
22+
x .= A \ b
23+
end
24+
25+
# Ambiguity removal
26+
function LinearSolve._ldiv!(::SVector,
27+
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
28+
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
29+
b::AbstractVector)
30+
(A \ b)
31+
end
32+
function LinearSolve._ldiv!(::SVector,
33+
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
34+
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
35+
b::SVector)
36+
(A \ b)
37+
end
38+
39+
function LinearSolve.pattern_changed(fact, A::SparseArrays.SparseMatrixCSC)
40+
!(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
41+
fact.colptr && SparseArrays.decrement(SparseArrays.getrowval(A)) ==
42+
fact.rowval)
43+
end
44+
45+
const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0, [1],
46+
Int[], Float64[]))
47+
48+
function init_cacheval(alg::UMFPACKFactorization,
49+
A, b, u, Pl, Pr,
50+
maxiters::Int, abstol, reltol,
51+
verbose::Bool, assumptions::OperatorAssumptions)
52+
nothing
53+
end
54+
55+
function init_cacheval(alg::UMFPACKFactorization, A::SparseMatrixCSC{Float64, Int}, b, u,
56+
Pl, Pr,
57+
maxiters::Int, abstol, reltol,
58+
verbose::Bool, assumptions::OperatorAssumptions)
59+
PREALLOCATED_UMFPACK
60+
end
61+
62+
function init_cacheval(alg::UMFPACKFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
63+
maxiters::Int, abstol,
64+
reltol,
65+
verbose::Bool, assumptions::OperatorAssumptions)
66+
A = convert(AbstractMatrix, A)
67+
zerobased = SparseArrays.getcolptr(A)[1] == 0
68+
return SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
69+
rowvals(A), nonzeros(A)))
70+
end
71+
72+
function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs...)
73+
A = cache.A
74+
A = convert(AbstractMatrix, A)
75+
if cache.isfresh
76+
cacheval = @get_cacheval(cache, :UMFPACKFactorization)
77+
if alg.reuse_symbolic
78+
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
79+
if alg.check_pattern && pattern_changed(cacheval, A)
80+
fact = lu(
81+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
82+
nonzeros(A)),
83+
check = false)
84+
else
85+
fact = lu!(cacheval,
86+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
87+
nonzeros(A)), check = false)
88+
end
89+
else
90+
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
91+
check = false)
92+
end
93+
cache.cacheval = fact
94+
cache.isfresh = false
95+
end
96+
97+
F = @get_cacheval(cache, :UMFPACKFactorization)
98+
if F.status == SparseArrays.UMFPACK.UMFPACK_OK
99+
y = ldiv!(cache.u, F, cache.b)
100+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
101+
else
102+
SciMLBase.build_linear_solution(
103+
alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
104+
end
105+
end
106+
107+
const PREALLOCATED_CHOLMOD = cholesky(SparseMatrixCSC(0, 0, [1], Int[], Float64[]))
108+
109+
function init_cacheval(alg::CHOLMODFactorization,
110+
A, b, u, Pl, Pr,
111+
maxiters::Int, abstol, reltol,
112+
verbose::Bool, assumptions::OperatorAssumptions)
113+
nothing
114+
end
115+
116+
function init_cacheval(alg::CHOLMODFactorization,
117+
A::Union{SparseMatrixCSC{T, Int}, Symmetric{T, SparseMatrixCSC{T, Int}}}, b, u,
118+
Pl, Pr,
119+
maxiters::Int, abstol, reltol,
120+
verbose::Bool, assumptions::OperatorAssumptions) where {T <:
121+
Union{Float32, Float64}}
122+
PREALLOCATED_CHOLMOD
123+
end
124+
125+
function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs...)
126+
A = cache.A
127+
A = convert(AbstractMatrix, A)
128+
129+
if cache.isfresh
130+
cacheval = @get_cacheval(cache, :CHOLMODFactorization)
131+
fact = cholesky(A; check = false)
132+
if !LinearAlgebra.issuccess(fact)
133+
ldlt!(fact, A; check = false)
134+
end
135+
cache.cacheval = fact
136+
cache.isfresh = false
137+
end
138+
139+
cache.u .= @get_cacheval(cache, :CHOLMODFactorization) \ cache.b
140+
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
141+
end
142+
143+
function LinearSolve.defaultalg(
144+
A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
145+
DefaultLinearSolver(DefaultAlgorithmChoice.CHOLMODFactorization)
146+
end
147+
148+
function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
149+
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
150+
if assump.issq
151+
DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
152+
else
153+
error("Generic number sparse factorization for non-square is not currently handled")
154+
end
155+
end
156+
157+
function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
158+
assump::OperatorAssumptions{Bool}) where {Ti}
159+
if assump.issq
160+
if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
161+
DefaultLinearSolver(DefaultAlgorithmChoice.KLUFactorization)
162+
else
163+
DefaultLinearSolver(DefaultAlgorithmChoice.UMFPACKFactorization)
164+
end
165+
else
166+
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
167+
end
168+
end
169+
170+
PrecompileTools.@compile_workload begin
171+
A = sprand(4, 4, 0.3) + I
172+
b = rand(4)
173+
prob = LinearProblem(A, b)
174+
sol = solve(prob, KLUFactorization())
175+
sol = solve(prob, UMFPACKFactorization())
176+
end
177+
178+
end

0 commit comments

Comments
 (0)