Skip to content

Commit 536611a

Browse files
finishing touches
1 parent 4f6225c commit 536611a

10 files changed

+75
-64
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2525
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2626
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2727
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
28-
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
2928
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3029
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3130

@@ -44,6 +43,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4443
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4544
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4645
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
46+
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
4747

4848
[extensions]
4949
LinearSolveBandedMatricesExt = "BandedMatrices"
@@ -60,6 +60,8 @@ LinearSolveMetalExt = "Metal"
6060
LinearSolvePardisoExt = "Pardiso"
6161
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
6262
LinearSolveRecursiveFactorizationExt = "RecursiveFactorization"
63+
LinearSolveSparseArraysExt = "SparseArrays"
64+
LinearSolveSparspakExt = "Sparspak"
6365

6466
[compat]
6567
AllocCheck = "0.2"
@@ -141,10 +143,11 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
141143
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
142144
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
143145
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
146+
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
144147
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
145148
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
146149
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
147150
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
148151

149152
[targets]
150-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization"]
153+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak"]

ext/LinearSolveSparseArraysExt.jl

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module LinearSolveSparseArrays
1+
module LinearSolveSparseArraysExt
22

33
using LinearSolve, LinearAlgebra
44
using SparseArrays
@@ -11,12 +11,44 @@ using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
1111
include("../src/KLU/klu.jl")
1212

1313
LinearSolve.issparsematrixcsc(A::AbstractSparseMatrixCSC) = true
14+
LinearSolve.issparsematrix(A::AbstractSparseArray) = true
15+
LinearSolve.make_SparseMatrixCSC(A::AbstractSparseArray) = SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A))
16+
LinearSolve.makeempty_SparaseMatrixCSC(A::AbstractSparseArray) = SparseMatrixCSC(0, 0, [1], Int[], eltype(A)[])
17+
18+
function LinearSolve.init_cacheval(alg::RFLUFactorization,
19+
A::Union{AbstractSparseArray, LinearSolve.SciMLOperators.AbstractSciMLOperator}, b, u, Pl, Pr,
20+
maxiters::Int,
21+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
22+
nothing, nothing
23+
end
24+
25+
26+
function LinearSolve.init_cacheval(
27+
alg::QRFactorization, A::Symmetric{<:Number, <:SparseMatrixCSC}, b, u, Pl, Pr,
28+
maxiters::Int, abstol, reltol, verbose::Bool,
29+
assumptions::OperatorAssumptions)
30+
return nothing
31+
end
1432

1533
function LinearSolve.handle_sparsematrixcsc_lu(A::AbstractSparseMatrixCSC)
1634
lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
1735
check = false)
1836
end
1937

38+
function LinearSolve.defaultalg(
39+
A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
40+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.CHOLMODFactorization)
41+
end
42+
43+
function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
44+
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
45+
if assump.issq
46+
DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
47+
else
48+
error("Generic number sparse factorization for non-square is not currently handled")
49+
end
50+
end
51+
2052
function LinearSolve.init_cacheval(alg::GenericFactorization,
2153
A::Union{Hermitian{T, <:SparseMatrixCSC},
2254
Symmetric{T, <:SparseMatrixCSC}}, b, u, Pl, Pr,
@@ -46,7 +78,7 @@ function LinearSolve.init_cacheval(alg::UMFPACKFactorization, A::AbstractSparseA
4678
rowvals(A), nonzeros(A)))
4779
end
4880

49-
function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs...)
81+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::UMFPACKFactorization; kwargs...)
5082
A = cache.A
5183
A = convert(AbstractMatrix, A)
5284
if cache.isfresh
@@ -101,7 +133,7 @@ function LinearSolve.init_cacheval(alg::KLUFactorization, A::AbstractSparseArray
101133
end
102134

103135
# TODO: guard this against errors
104-
function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
136+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization; kwargs...)
105137
A = cache.A
106138
A = convert(AbstractMatrix, A)
107139
if cache.isfresh
@@ -146,11 +178,11 @@ function LinearSolve.init_cacheval(alg::CHOLMODFactorization,
146178
end
147179

148180
function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization,
149-
A::Union{AbstractSparseArray, GPUArraysCore.AnyGPUArray,
181+
A::Union{AbstractSparseArray, LinearSolve.GPUArraysCore.AnyGPUArray,
150182
Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
151183
maxiters::Int, abstol, reltol, verbose::Bool,
152184
assumptions::OperatorAssumptions)
153-
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
185+
LinearSolve.ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
154186
end
155187

156188
# Specialize QR for the non-square case
@@ -170,16 +202,16 @@ function LinearSolve._ldiv!(x::AbstractVector,
170202
end
171203

172204
# Ambiguity removal
173-
function LinearSolve._ldiv!(::SVector,
205+
function LinearSolve._ldiv!(::LinearSolve.SVector,
174206
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
175207
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
176208
b::AbstractVector)
177209
(A \ b)
178210
end
179-
function LinearSolve._ldiv!(::SVector,
211+
function LinearSolve._ldiv!(::LinearSolve.SVector,
180212
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
181213
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
182-
b::SVector)
214+
b::LinearSolve.SVector)
183215
(A \ b)
184216
end
185217

ext/LinearSolveSparsepakExt.jl renamed to ext/LinearSolveSparspakExt.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
module LinearSolveSparsepakExt
1+
module LinearSolveSparspakExt
22

33
using LinearSolve, LinearAlgebra
4-
using SparseArrays
5-
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
64
using Sparspak
5+
using Sparspak.SparseCSCInterface.SparseArrays
6+
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
77

88
const PREALLOCATED_SPARSEPAK = sparspaklu(SparseMatrixCSC(0, 0, [1], Int[], Float64[]),
99
factorize = false)
@@ -15,7 +15,7 @@ function LinearSolve.init_cacheval(::SparspakFactorization, A::SparseMatrixCSC{F
1515
PREALLOCATED_SPARSEPAK
1616
end
1717

18-
function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
18+
function LinearSolve.init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
1919
reltol,
2020
verbose::Bool, assumptions::OperatorAssumptions)
2121
A = convert(AbstractMatrix, A)
@@ -30,7 +30,7 @@ function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int,
3030
end
3131
end
3232

33-
function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs...)
33+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::SparspakFactorization; kwargs...)
3434
A = cache.A
3535
if cache.isfresh
3636
if cache.cacheval !== nothing && alg.reuse_symbolic
@@ -48,4 +48,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs
4848
SciMLBase.build_linear_solution(alg, y, nothing, cache)
4949
end
5050

51+
LinearSolve.PrecompileTools.@compile_workload begin
52+
A = sprand(4, 4, 0.3) + I
53+
b = rand(4)
54+
prob = LinearProblem(A * A', b)
55+
sol = solve(prob) # in case sparspak is used as default
56+
sol = solve(prob, SparspakFactorization())
57+
end
58+
5159
end

src/LinearSolve.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ function _fast_sym_givens! end
9191

9292
issparsematrixcsc(A) = false
9393
handle_sparsematrixcsc_lu(A) = lu(A)
94+
issparsematrix(A) = false
95+
make_SparseMatrixCSC(A) = nothing
96+
makeempty_SparaseMatrixCSC(A) = nothing
9497

9598
EnumX.@enumx DefaultAlgorithmChoice begin
9699
LUFactorization
@@ -207,14 +210,6 @@ PrecompileTools.@compile_workload begin
207210
sol = solve(prob, KrylovJL_GMRES())
208211
end
209212

210-
PrecompileTools.@compile_workload begin
211-
A = sprand(4, 4, 0.3) + I
212-
b = rand(4)
213-
prob = LinearProblem(A * A', b)
214-
sol = solve(prob) # in case sparspak is used as default
215-
sol = solve(prob, SparspakFactorization())
216-
end
217-
218213
ALREADY_WARNED_CUDSS = Ref{Bool}(false)
219214
error_no_cudss_lu(A) = nothing
220215
cudss_loaded(A) = false

src/common.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,19 +194,19 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
194194
A
195195
elseif A isa Array
196196
copy(A)
197-
elseif A isa AbstractSparseMatrixCSC
198-
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A))
197+
elseif issparsematrixcsc(A)
198+
make_SparseMatrixCSC(A)
199199
else
200200
deepcopy(A)
201201
end
202202

203-
b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
203+
b = if issparsematrix(b) && !(A isa Diagonal)
204204
Array(b) # the solution to a linear solve will always be dense!
205205
elseif alias_b || b isa SVector
206206
b
207207
elseif b isa Array
208208
copy(b)
209-
elseif b isa AbstractSparseMatrixCSC
209+
elseif issparsematrixcsc(b)
210210
SparseMatrixCSC(size(b)..., getcolptr(b), rowvals(b), nonzeros(b))
211211
else
212212
deepcopy(b)

src/default.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -92,20 +92,6 @@ function defaultalg(A::Symmetric{<:Number, <:Array}, b, ::OperatorAssumptions{Bo
9292
DefaultLinearSolver(DefaultAlgorithmChoice.BunchKaufmanFactorization)
9393
end
9494

95-
function defaultalg(
96-
A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
97-
DefaultLinearSolver(DefaultAlgorithmChoice.CHOLMODFactorization)
98-
end
99-
100-
function defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
101-
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
102-
if assump.issq
103-
DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
104-
else
105-
error("Generic number sparse factorization for non-square is not currently handled")
106-
end
107-
end
108-
10995
function defaultalg(A::GPUArraysCore.AnyGPUArray, b, assump::OperatorAssumptions{Bool})
11096
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
11197
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
@@ -308,7 +294,7 @@ cache.cacheval = NamedTuple(LUFactorization = cache of LUFactorization, ...)
308294
caches = map(first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))) do alg
309295
if alg === :KrylovJL_GMRES || alg === :KrylovJL_CRAIGMR || alg === :KrylovJL_LSMR
310296
quote
311-
if A isa Matrix || A isa SparseMatrixCSC
297+
if A isa Matrix || issparsematrixcsc(A)
312298
nothing
313299
else
314300
init_cacheval($(algchoice_to_alg(alg)), A, b, u, Pl, Pr, maxiters,

src/extension_algs.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,12 @@ struct RFLUFactorization{P, T} <: AbstractDenseFactorization
9999
throwerror &&
100100
error("RFLUFactorization requires that RecursiveFactorization.jl is loaded, i.e. `using RecursiveFactorization`")
101101
end
102+
new{P, T}()
102103
end
103104
end
104105

105-
function RFLUFactorization(; pivot = Val(true), thread = Val(true))
106-
RFLUFactorization(pivot, thread)
106+
function RFLUFactorization(; pivot = Val(true), thread = Val(true), throwerror = true)
107+
RFLUFactorization(pivot, thread; throwerror)
107108
end
108109

109110
"""

src/factorization.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,6 @@ function LinearSolve.init_cacheval(
4545
PREALLOCATED_LU, ipiv
4646
end
4747

48-
function LinearSolve.init_cacheval(alg::RFLUFactorization,
49-
A::Union{AbstractSparseArray, AbstractSciMLOperator}, b, u, Pl, Pr,
50-
maxiters::Int,
51-
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
52-
nothing, nothing
53-
end
54-
5548
function LinearSolve.init_cacheval(alg::RFLUFactorization,
5649
A::Union{Diagonal, SymTridiagonal, Tridiagonal}, b, u, Pl, Pr,
5750
maxiters::Int,
@@ -111,7 +104,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::LUFactorization; kwargs...)
111104
A = convert(AbstractMatrix, A)
112105
if cache.isfresh
113106
cacheval = @get_cacheval(cache, :LUFactorization)
114-
if A isa AbstractSparseMatrix && alg.reuse_symbolic
107+
if issparsematrix(A) && alg.reuse_symbolic
115108
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
116109
# If SparseMatrixCSC, check if the pattern has changed
117110
if alg.check_pattern && pattern_changed(cacheval, A)
@@ -248,13 +241,6 @@ function init_cacheval(alg::QRFactorization, A::Symmetric{<:Number, <:Array}, b,
248241
return qr(convert(AbstractMatrix, A), alg.pivot)
249242
end
250243

251-
function init_cacheval(
252-
alg::QRFactorization, A::Symmetric{<:Number, <:SparseMatrixCSC}, b, u, Pl, Pr,
253-
maxiters::Int, abstol, reltol, verbose::Bool,
254-
assumptions::OperatorAssumptions)
255-
return nothing
256-
end
257-
258244
const PREALLOCATED_QR_ColumnNorm = ArrayInterface.qr_instance(rand(1, 1), ColumnNorm())
259245

260246
function init_cacheval(alg::QRFactorization{ColumnNorm}, A::Matrix{Float64}, b, u, Pl, Pr,

src/iterative_wrappers.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,16 @@ function init_cacheval(alg::KrylovJL, A, b, u, Pl, Pr, maxiters::Int, abstol, re
195195
alg.KrylovAlg === Krylov.fgmres! ||
196196
alg.KrylovAlg === Krylov.gpmr! ||
197197
alg.KrylovAlg === Krylov.fom!)
198-
if A isa SparseMatrixCSC
199-
KS(SparseMatrixCSC(0, 0, [1], Int[], eltype(A)[]), eltype(b)[], 1)
198+
if issparsematrixcsc(A)
199+
KS(makeempty_SparaseMatrixCSC(A), eltype(b)[], 1)
200200
elseif A isa Matrix
201201
KS(Matrix{eltype(A)}(undef, 0, 0), eltype(b)[], 1)
202202
else
203203
KS(A, b, 1)
204204
end
205205
else
206-
if A isa SparseMatrixCSC
207-
KS(SparseMatrixCSC(0, 0, [1], Int[], eltype(A)[]), eltype(b)[])
206+
if issparsematrixcsc(A)
207+
KS(makeempty_SparaseMatrixCSC(A), eltype(b)[])
208208
elseif A isa Matrix
209209
KS(Matrix{eltype(A)}(undef, 0, 0), eltype(b)[])
210210
else

test/basictests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
2-
using SciMLOperators, RecursiveFactorization
2+
using SciMLOperators, RecursiveFactorization, Sparspak
33
using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners
44
using Test
55
import Random

0 commit comments

Comments
 (0)