Skip to content

Commit b979527

Browse files
Merge pull request #249 from SciML/sparse
Clean up sparse handling
2 parents f19ad41 + cd31ac9 commit b979527

File tree

6 files changed

+84
-32
lines changed

6 files changed

+84
-32
lines changed

src/common.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
9898
kwargs...)
9999
@unpack A, b, u0, p = prob
100100

101+
A = alias_A ? A : deepcopy(A)
102+
b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
103+
Array(b) # the solution to a linear solve will always be dense!
104+
elseif alias_b
105+
b
106+
else
107+
deepcopy(b)
108+
end
109+
101110
u0 = if u0 !== nothing
102111
u0
103112
else
@@ -110,9 +119,6 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
110119
isfresh = true
111120
Tc = typeof(cacheval)
112121

113-
A = alias_A ? A : deepcopy(A)
114-
b = alias_b ? b : deepcopy(b)
115-
116122
cache = LinearCache{
117123
typeof(A),
118124
typeof(b),

src/default.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ end
5353
end
5454
else
5555
function defaultalg(A::SparseMatrixCSC, b, ::OperatorAssumptions{true})
56-
KrylovJL_GMRES()
56+
SparspakFactorization()
5757
end
5858
end
5959

src/factorization.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,3 +500,33 @@ function SciMLBase.solve(cache::LinearCache, alg::FastQRFactorization{P};
500500
y = ldiv!(cache.u, cache.cacheval.factors, cache.b)
501501
SciMLBase.build_linear_solution(alg, y, nothing, cache)
502502
end
503+
504+
## SparspakFactorization is here since it's MIT licensed, not GPL

"""
    SparspakFactorization()

Algorithm-selector type for solving a sparse `LinearProblem` with the
Sparspak.jl sparse LU solver (see `init_cacheval` / `solve` methods below,
which build a `Sparspak.SparseSolver` for the cache).  Lives in this file —
not with the other sparse factorizations — because Sparspak is MIT licensed
rather than GPL.
"""
struct SparspakFactorization <: AbstractFactorization end
507+
508+
"""
    init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters, abstol,
                  reltol, verbose, assumptions)

Build and return the `Sparspak.SparseSolver` that `solve` will reuse across
calls.  `A` is loaded as the sparse system matrix and `b` as the (dense)
right-hand side; all remaining arguments are accepted only to satisfy the
common `init_cacheval` interface and are unused here.
"""
function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
                       reltol,
                       verbose::Bool, assumptions::OperatorAssumptions)
    # Sparspak wants a concrete matrix, so lower any operator wrapper first.
    matrix = convert(AbstractMatrix, A)
    problem = Sparspak.Problem.Problem(size(matrix)...)
    Sparspak.Problem.insparse!(problem, matrix)
    Sparspak.Problem.infullrhs!(problem, b)
    return Sparspak.SparseSolver.SparseSolver(problem)
end
518+
519+
"""
    SciMLBase.solve(cache::LinearCache, alg::SparspakFactorization; kwargs...)

Solve the cached linear system with Sparspak.  When `cache.isfresh` signals
that `A`/`b` changed since the last factorization, a new Sparspak problem and
solver are constructed and stored back into the cache; otherwise the cached
solver is reused.  The solution is copied into `cache.u` and wrapped in a
`LinearSolution`.
"""
function SciMLBase.solve(cache::LinearCache, alg::SparspakFactorization; kwargs...)
    matrix = convert(AbstractMatrix, cache.A)
    if cache.isfresh
        # Matrix (or rhs) was updated: rebuild the Sparspak problem/solver pair.
        problem = Sparspak.Problem.Problem(size(matrix)...)
        Sparspak.Problem.insparse!(problem, matrix)
        Sparspak.Problem.infullrhs!(problem, cache.b)
        cache = set_cacheval(cache, Sparspak.SparseSolver.SparseSolver(problem))
    end
    Sparspak.SparseSolver.solve!(cache.cacheval)
    # Sparspak writes the solution into its internal problem; copy it out.
    copyto!(cache.u, cache.cacheval.p.x)
    SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

src/factorization_sparse.jl

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,3 @@ function _ldiv!(x::Vector,
66
SuiteSparse.CHOLMOD.Factor}, b::Vector)
77
x .= A \ b
88
end
9-
10-
struct SparspakFactorization <: AbstractFactorization end
11-
12-
function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
13-
reltol,
14-
verbose::Bool, assumptions::OperatorAssumptions)
15-
A = convert(AbstractMatrix, A)
16-
p = Sparspak.Problem.Problem(size(A)...)
17-
Sparspak.Problem.insparse!(p, A)
18-
Sparspak.Problem.infullrhs!(p, b)
19-
s = Sparspak.SparseSolver.SparseSolver(p)
20-
return s
21-
end
22-
23-
function SciMLBase.solve(cache::LinearCache, alg::SparspakFactorization; kwargs...)
24-
A = cache.A
25-
A = convert(AbstractMatrix, A)
26-
if cache.isfresh
27-
p = Sparspak.Problem.Problem(size(A)...)
28-
Sparspak.Problem.insparse!(p, A)
29-
Sparspak.Problem.infullrhs!(p, cache.b)
30-
s = Sparspak.SparseSolver.SparseSolver(p)
31-
cache = set_cacheval(cache, s)
32-
end
33-
Sparspak.SparseSolver.solve!(cache.cacheval)
34-
copyto!(cache.u, cache.cacheval.p.x)
35-
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
36-
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ if GROUP == "All" || GROUP == "Core"
2222
@time @safetestset "Basic Tests" begin include("basictests.jl") end
2323
@time @safetestset "Zero Initialization Tests" begin include("zeroinittests.jl") end
2424
@time @safetestset "Non-Square Tests" begin include("nonsquare.jl") end
25+
@time @safetestset "SparseVector b Tests" begin include("sparse_vector.jl") end
2526
end
2627

2728
if GROUP == "LinearSolveCUDA"

test/sparse_vector.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using SparseArrays
using LinearSolve
using LinearAlgebra
using Test  # provides @test; harmless when the runner already injects it

# Regression test: a LinearProblem with a SparseVector right-hand side must
# solve correctly (the rhs is densified internally, since solutions are dense).

# Nonzero values of a 6x6 Hessian at x (structure given by rowval/colval below).
function hess_sparse(x::Vector{T}) where {T}
    return [
        -sin(x[1] + x[2]) + 1,
        -sin(x[1] + x[2]),
        -sin(x[1] + x[2]),
        -sin(x[1] + x[2]) + 1.0,
        1.0,
        1.0,
        12.0 * x[5]^2 + 1.0,
        1.0,
    ]
end
rowval = [1, 1, 2, 2, 3, 4, 5, 6]
colval = [1, 2, 1, 2, 3, 4, 5, 6]

# Gradient values for the sparse right-hand-side vector.
function grad_sparse(x::Vector{T}) where {T <: Number}
    return [cos(x[1] + x[2]), cos(x[1] + x[2]), 2 * x[3], 1 / 2, 4 * x[5]^3, 1 / 2]
end
gradinds = [1, 2, 3, 4, 5, 6]

# Forming the matrix and vector at a fixed evaluation point
x0 = [
    0.7853981648713337,
    0.7853981693418342,
    1.023999999999997e-7,
    -1.0,
    0.33141395338218227,
    -1.0,
]
n = length(x0)
hess_mat = sparse(rowval, colval, hess_sparse(x0), n, n)
grad_vec = sparsevec(gradinds, grad_sparse(x0), n)

# Converting grad_vec to dense succeeds in solving; the SparseVector path must
# agree with the reference dense backslash solve.
prob = LinearProblem(hess_mat, grad_vec)
linsolve = init(prob)
@test solve(linsolve).u ≈ hess_mat \ Array(grad_vec)

0 commit comments

Comments
 (0)