Skip to content

Commit 743b0d3

Browse files
Clean up sparse handling
- Fix case with sparse b, fixes #248 - Removes Sparspak from the section off code that's only for non-GPL cases. Sparspak.jl isn't GPL-licensed so we can always use it. - Restructured the default algorithm so that Sparspak is used when GPL is disabled, making it so it can always factorize properly.
1 parent f19ad41 commit 743b0d3

File tree

4 files changed

+53
-30
lines changed

4 files changed

+53
-30
lines changed

src/common.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,14 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
111111
Tc = typeof(cacheval)
112112

113113
A = alias_A ? A : deepcopy(A)
114-
b = alias_b ? b : deepcopy(b)
114+
b = if b isa SparseArray && !(A isa Diagonal)
115+
Array(b) # the solution to a linear solve will always be dense!
116+
elseif alias_b
117+
b
118+
else
119+
deepcopy(b)
120+
end
121+
end
115122

116123
cache = LinearCache{
117124
typeof(A),

src/default.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,20 @@ function defaultalg(A, b, ::OperatorAssumptions{true})
152152
alg
153153
end
154154

155+
function defaultalg(A::SparseMatrixCSC, b, ::OperatorAssumptions{true})
156+
# If GPL libraries are loaded, then use SuiteSparse. Otherwise Sparspak
157+
if INCLUDE_SPARSE
158+
if length(b) <= 10000
159+
alg = KLUFactorization()
160+
else
161+
alg = UMFPACKFactorization()
162+
end
163+
else
164+
alg = SparspakFactorization()
165+
end
166+
alg
167+
end
168+
155169
function defaultalg(A, b, ::OperatorAssumptions{false})
156170
QRFactorization()
157171
end

src/factorization.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,3 +500,33 @@ function SciMLBase.solve(cache::LinearCache, alg::FastQRFactorization{P};
500500
y = ldiv!(cache.u, cache.cacheval.factors, cache.b)
501501
SciMLBase.build_linear_solution(alg, y, nothing, cache)
502502
end
503+
504+
## SparspakFactorization is here since it's MIT licensed, not GPL
505+
506+
struct SparspakFactorization <: AbstractFactorization end
507+
508+
function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
509+
reltol,
510+
verbose::Bool, assumptions::OperatorAssumptions)
511+
A = convert(AbstractMatrix, A)
512+
p = Sparspak.Problem.Problem(size(A)...)
513+
Sparspak.Problem.insparse!(p, A)
514+
Sparspak.Problem.infullrhs!(p, b)
515+
s = Sparspak.SparseSolver.SparseSolver(p)
516+
return s
517+
end
518+
519+
function SciMLBase.solve(cache::LinearCache, alg::SparspakFactorization; kwargs...)
520+
A = cache.A
521+
A = convert(AbstractMatrix, A)
522+
if cache.isfresh
523+
p = Sparspak.Problem.Problem(size(A)...)
524+
Sparspak.Problem.insparse!(p, A)
525+
Sparspak.Problem.infullrhs!(p, cache.b)
526+
s = Sparspak.SparseSolver.SparseSolver(p)
527+
cache = set_cacheval(cache, s)
528+
end
529+
Sparspak.SparseSolver.solve!(cache.cacheval)
530+
copyto!(cache.u, cache.cacheval.p.x)
531+
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
532+
end

src/factorization_sparse.jl

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,4 @@ function _ldiv!(x::Vector,
55
SuiteSparse.SPQR.QRSparse,
66
SuiteSparse.CHOLMOD.Factor}, b::Vector)
77
x .= A \ b
8-
end
9-
10-
struct SparspakFactorization <: AbstractFactorization end
11-
12-
function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
13-
reltol,
14-
verbose::Bool, assumptions::OperatorAssumptions)
15-
A = convert(AbstractMatrix, A)
16-
p = Sparspak.Problem.Problem(size(A)...)
17-
Sparspak.Problem.insparse!(p, A)
18-
Sparspak.Problem.infullrhs!(p, b)
19-
s = Sparspak.SparseSolver.SparseSolver(p)
20-
return s
21-
end
22-
23-
function SciMLBase.solve(cache::LinearCache, alg::SparspakFactorization; kwargs...)
24-
A = cache.A
25-
A = convert(AbstractMatrix, A)
26-
if cache.isfresh
27-
p = Sparspak.Problem.Problem(size(A)...)
28-
Sparspak.Problem.insparse!(p, A)
29-
Sparspak.Problem.infullrhs!(p, cache.b)
30-
s = Sparspak.SparseSolver.SparseSolver(p)
31-
cache = set_cacheval(cache, s)
32-
end
33-
Sparspak.SparseSolver.solve!(cache.cacheval)
34-
copyto!(cache.u, cache.cacheval.p.x)
35-
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
36-
end
8+
end

0 commit comments

Comments
 (0)