
Commit 4f6225c

WIP: Make SparseArrays an extension
We shouldn't need to load all of the sparse arrays infrastructure unless the user is actually using sparse arrays. A few things need to be done here:

1. The KLU.jl code needs to be copied in so we can default to KLU calls without requiring a user-level `using KLU`.
2. Some functions are added which only get a useful dispatch once SparseArrays is loaded, but that should be sufficiently safe (a sketch of the pattern is given below).

One big remaining issue is that Krylov.jl has a hard dependency on SparseArrays, so we'd either need to drop it as a dependency and move to a new Krylov library, or fix Krylov.jl so that its SparseArrays support becomes an extension.
1 parent e77fb72 · commit 4f6225c
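The stub-plus-extension dispatch pattern from point 2 can be summarized with a small self-contained sketch. The `ToyLinearSolve` / `ToyLinearSolveSparseArraysExt` names are purely illustrative stand-ins, not code from this commit; the real functions in the diff below are `LinearSolve.issparsematrixcsc` and `LinearSolve.handle_sparsematrixcsc_lu`.

# The parent package defines a conservative fallback; the extension, which only
# loads once SparseArrays is present, adds the useful method on top of it.
module ToyLinearSolve
issparsematrixcsc(A) = false   # fallback: nothing is treated as a CSC matrix
end

module ToyLinearSolveSparseArraysExt
using SparseArrays: AbstractSparseMatrixCSC
using ..ToyLinearSolve
# Overrides the fallback once sparse matrices actually exist in the session
ToyLinearSolve.issparsematrixcsc(::AbstractSparseMatrixCSC) = true
end

using SparseArrays
ToyLinearSolve.issparsematrixcsc(sprand(3, 3, 0.5))  # true
ToyLinearSolve.issparsematrixcsc(rand(3, 3))         # false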

File tree

9 files changed: 1852 additions & 258 deletions

Project.toml

Lines changed: 0 additions & 2 deletions
@@ -12,7 +12,6 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
 FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
-KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
 Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
 LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
 Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -85,7 +84,6 @@ HYPRE = "1.4.0"
 InteractiveUtils = "1.10"
 IterativeSolvers = "0.9.3"
 JET = "0.8.28, 0.9"
-KLU = "0.6"
 KernelAbstractions = "0.9.27"
 Krylov = "0.9"
 KrylovKit = "0.8, 0.9"
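Not shown in these hunks (it may live in one of the other changed files, or still be pending in this WIP): for the new ext/ modules to load, SparseArrays and Sparspak would also have to be declared under [weakdeps] with matching [extensions] entries. A sketch of what that declaration typically looks like, using the extension names from this commit (the Sparspak UUID is elided):

[weakdeps]
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Sparspak = "..."  # Sparspak's registered UUID

[extensions]
LinearSolveSparseArraysExt = "SparseArrays"
LinearSolveSparsepakExt = "Sparspak"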

ext/LinearSolveSparseArraysExt.jl

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
module LinearSolveSparseArraysExt

using LinearSolve, LinearAlgebra
using SparseArrays
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
# Non-exported bindings used below, reached through LinearSolve's own hard
# dependencies so this extension needs no extra direct dependencies
# (assumes SciMLBase, ArrayInterface, GPUArraysCore, and StaticArraysCore
# stay in LinearSolve's dependency list)
using LinearSolve: SciMLBase, ArrayInterface, GPUArraysCore
using LinearSolve.SciMLBase: ReturnCode
using LinearSolve.StaticArraysCore: SVector

# Can't `using KLU` because we cannot have KLU as a dependency in here without
# requiring the user to do `using KLU`.
# But there's no reason to require it, because SparseArrays already loads
# SuiteSparse and thus all of the underlying KLU code.
include("../src/KLU/klu.jl")

LinearSolve.issparsematrixcsc(A::AbstractSparseMatrixCSC) = true

function LinearSolve.handle_sparsematrixcsc_lu(A::AbstractSparseMatrixCSC)
    lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
        check = false)
end

function LinearSolve.init_cacheval(alg::GenericFactorization,
        A::Union{Hermitian{T, <:SparseMatrixCSC},
            Symmetric{T, <:SparseMatrixCSC}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions) where {T}
    newA = copy(convert(AbstractMatrix, A))
    LinearSolve.do_factorization(alg, newA, b, u)
end

const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0, [1],
    Int[], Float64[]))

function LinearSolve.init_cacheval(alg::UMFPACKFactorization,
        A::SparseMatrixCSC{Float64, Int}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol,
        verbose::Bool, assumptions::OperatorAssumptions)
    PREALLOCATED_UMFPACK
end

function LinearSolve.init_cacheval(alg::UMFPACKFactorization, A::AbstractSparseArray,
        b, u, Pl, Pr,
        maxiters::Int, abstol, reltol,
        verbose::Bool, assumptions::OperatorAssumptions)
    A = convert(AbstractMatrix, A)
    zerobased = SparseArrays.getcolptr(A)[1] == 0
    return SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
        rowvals(A), nonzeros(A)))
end

function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs...)
    A = cache.A
    A = convert(AbstractMatrix, A)
    if cache.isfresh
        cacheval = LinearSolve.@get_cacheval(cache, :UMFPACKFactorization)
        if alg.reuse_symbolic
            # Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
            if alg.check_pattern && pattern_changed(cacheval, A)
                fact = lu(
                    SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
                        nonzeros(A)),
                    check = false)
            else
                fact = lu!(cacheval,
                    SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
                        nonzeros(A)), check = false)
            end
        else
            fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
                check = false)
        end
        cache.cacheval = fact
        cache.isfresh = false
    end

    F = LinearSolve.@get_cacheval(cache, :UMFPACKFactorization)
    if F.status == SparseArrays.UMFPACK.UMFPACK_OK
        y = ldiv!(cache.u, F, cache.b)
        SciMLBase.build_linear_solution(alg, y, nothing, cache)
    else
        SciMLBase.build_linear_solution(
            alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
    end
end

const PREALLOCATED_KLU = KLU.KLUFactorization(SparseMatrixCSC(0, 0, [1], Int[],
    Float64[]))

function LinearSolve.init_cacheval(alg::KLUFactorization,
        A::SparseMatrixCSC{Float64, Int}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol,
        verbose::Bool, assumptions::OperatorAssumptions)
    PREALLOCATED_KLU
end

function LinearSolve.init_cacheval(alg::KLUFactorization, A::AbstractSparseArray,
        b, u, Pl, Pr,
        maxiters::Int, abstol, reltol,
        verbose::Bool, assumptions::OperatorAssumptions)
    A = convert(AbstractMatrix, A)
    return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
        nonzeros(A)))
end

# TODO: guard this against errors
function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
    A = cache.A
    A = convert(AbstractMatrix, A)
    if cache.isfresh
        cacheval = LinearSolve.@get_cacheval(cache, :KLUFactorization)
        if alg.reuse_symbolic
            if alg.check_pattern && pattern_changed(cacheval, A)
                fact = KLU.klu(
                    SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
                        nonzeros(A)),
                    check = false)
            else
                fact = KLU.klu!(cacheval, nonzeros(A), check = false)
            end
        else
            # New factorization each time since the sparsity pattern can change
            # and thus it needs to reallocate
            fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
                nonzeros(A)))
        end
        cache.cacheval = fact
        cache.isfresh = false
    end
    F = LinearSolve.@get_cacheval(cache, :KLUFactorization)
    if F.common.status == KLU.KLU_OK
        y = ldiv!(cache.u, F, cache.b)
        SciMLBase.build_linear_solution(alg, y, nothing, cache)
    else
        SciMLBase.build_linear_solution(
            alg, cache.u, nothing, cache; retcode = ReturnCode.Infeasible)
    end
end

const PREALLOCATED_CHOLMOD = cholesky(SparseMatrixCSC(0, 0, [1], Int[], Float64[]))

function LinearSolve.init_cacheval(alg::CHOLMODFactorization,
        A::Union{SparseMatrixCSC{T, Int}, Symmetric{T, SparseMatrixCSC{T, Int}}}, b, u,
        Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions) where {T <: Union{Float32, Float64}}
    PREALLOCATED_CHOLMOD
end

function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization,
        A::Union{AbstractSparseArray, GPUArraysCore.AnyGPUArray,
            Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
end

# Specialize QR for the non-square case
# Missing ldiv! definitions: https://github.com/JuliaSparse/SparseArrays.jl/issues/242
function LinearSolve._ldiv!(x::Vector,
        A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
            SparseArrays.SPQR.QRSparse,
            SparseArrays.CHOLMOD.Factor}, b::Vector)
    x .= A \ b
end

function LinearSolve._ldiv!(x::AbstractVector,
        A::Union{SparseArrays.QR, LinearAlgebra.QRCompactWY,
            SparseArrays.SPQR.QRSparse,
            SparseArrays.CHOLMOD.Factor}, b::AbstractVector)
    x .= A \ b
end

# Ambiguity removal
function LinearSolve._ldiv!(::SVector,
        A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
            LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
        b::AbstractVector)
    (A \ b)
end
function LinearSolve._ldiv!(::SVector,
        A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
            LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
        b::SVector)
    (A \ b)
end

# Returns true when the symbolic structure of A no longer matches the cached
# factorization, in which case the symbolic factorization must be redone
function pattern_changed(fact, A::SparseArrays.SparseMatrixCSC)
    !(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
      fact.colptr && SparseArrays.decrement(SparseArrays.getrowval(A)) ==
      fact.rowval)
end

function LinearSolve.defaultalg(
        A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
        assump::OperatorAssumptions{Bool}) where {Ti}
    if assump.issq
        if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
            LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KLUFactorization)
        else
            LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.UMFPACKFactorization)
        end
    else
        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.QRFactorization)
    end
end

LinearSolve.PrecompileTools.@compile_workload begin
    A = sprand(4, 4, 0.3) + I
    b = rand(4)
    prob = LinearProblem(A, b)
    sol = solve(prob, KLUFactorization())
    sol = solve(prob, UMFPACKFactorization())
end

end
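For reference, here is a short usage sketch of what this extension enables once it is in place; it is illustrative rather than a test taken from this commit, and assumes the WIP branch loads. Loading SparseArrays next to LinearSolve is what triggers the extension, after which sparse problems default to KLU or UMFPACK via the `defaultalg` heuristic above.

using LinearSolve, SparseArrays, LinearAlgebra

n = 100
A = sprand(n, n, 0.01) + I
b = rand(n)
prob = LinearProblem(A, b)

solve(prob)                          # defaultalg: KLU (small & very sparse) or UMFPACK
solve(prob, KLUFactorization())      # explicit KLU, no user-level `using KLU` needed
solve(prob, UMFPACKFactorization())  # explicit UMFPACK

# The caching interface reuses the symbolic factorization when the sparsity
# pattern is unchanged (the reuse_symbolic / pattern_changed path above).
cache = init(prob, KLUFactorization())
sol1 = solve!(cache)
cache.A = 2A            # same pattern, new values: hits the klu! refactorization branch
sol2 = solve!(cache)
sol2.u ≈ (2A) \ b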

ext/LinearSolveSparsepakExt.jl

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
module LinearSolveSparsepakExt

using LinearSolve, LinearAlgebra
using SparseArrays
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
using Sparspak
# SciMLBase is reached through LinearSolve's own hard dependencies
using LinearSolve: SciMLBase

const PREALLOCATED_SPARSEPAK = sparspaklu(SparseMatrixCSC(0, 0, [1], Int[], Float64[]),
    factorize = false)

function LinearSolve.init_cacheval(::SparspakFactorization,
        A::SparseMatrixCSC{Float64, Int}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol,
        verbose::Bool, assumptions::OperatorAssumptions)
    PREALLOCATED_SPARSEPAK
end

function LinearSolve.init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol,
        verbose::Bool, assumptions::OperatorAssumptions)
    A = convert(AbstractMatrix, A)
    if A isa SparseArrays.AbstractSparseArray
        return sparspaklu(
            SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
                nonzeros(A)),
            factorize = false)
    else
        return sparspaklu(SparseMatrixCSC(0, 0, [1], Int[], eltype(A)[]),
            factorize = false)
    end
end

function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs...)
    A = cache.A
    if cache.isfresh
        if cache.cacheval !== nothing && alg.reuse_symbolic
            fact = sparspaklu!(LinearSolve.@get_cacheval(cache, :SparspakFactorization),
                SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
                    nonzeros(A)))
        else
            fact = sparspaklu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
                nonzeros(A)))
        end
        cache.cacheval = fact
        cache.isfresh = false
    end
    y = ldiv!(cache.u, LinearSolve.@get_cacheval(cache, :SparspakFactorization), cache.b)
    SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

end
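And a corresponding usage sketch for the Sparspak path, again illustrative rather than taken from this commit; it assumes the LinearSolveSparsepakExt above loads once Sparspak is installed in the environment:

using LinearSolve, SparseArrays, LinearAlgebra, Sparspak

A = sprand(50, 50, 0.05) + I
b = rand(50)
prob = LinearProblem(A, b)

# With Sparspak loaded, SparspakFactorization dispatches to the sparspaklu path above
sol = solve(prob, SparspakFactorization())
sol.u ≈ A \ b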
