Commit 4a6c243

fix sparse vector handling
Parent: 9da256d

3 files changed: +37 -9 lines

src/common.jl

Lines changed: 9 additions & 9 deletions
@@ -98,6 +98,15 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
         kwargs...)
     @unpack A, b, u0, p = prob
 
+    A = alias_A ? A : deepcopy(A)
+    b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
+        Array(b) # the solution to a linear solve will always be dense!
+    elseif alias_b
+        b
+    else
+        deepcopy(b)
+    end
+
     u0 = if u0 !== nothing
         u0
     else
@@ -110,15 +119,6 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
     isfresh = true
     Tc = typeof(cacheval)
 
-    A = alias_A ? A : deepcopy(A)
-    b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
-        Array(b) # the solution to a linear solve will always be dense!
-    elseif alias_b
-        b
-    else
-        deepcopy(b)
-    end
-
     cache = LinearCache{
         typeof(A),
         typeof(b),
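
In the new ordering, a sparse right-hand side b is converted to a dense array (and the A/b aliasing is resolved) before the u0 branch that follows it, presumably so that any default buffers derived from b come out dense as well. A minimal sketch of the relocated logic in isolation, with illustrative inputs (the sprand/sparsevec setup, the alias_A/alias_b flags, and the final u0 line are assumptions for the sketch, not part of the commit):

    using SparseArrays, LinearAlgebra

    # Illustrative inputs: a generic (non-Diagonal) sparse matrix and a sparse rhs.
    A = sprand(4, 4, 0.5) + I
    b = sparsevec([1, 3], [1.0, 2.0], 4)
    alias_A = false
    alias_b = false

    # The relocated handling from the hunk above.
    A = alias_A ? A : deepcopy(A)
    b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
        Array(b) # the solution to a linear solve will always be dense!
    elseif alias_b
        b
    else
        deepcopy(b)
    end

    # With b already dense, a zero initial guess built from it is dense too (illustrative).
    u0 = fill!(similar(b, size(A, 2)), false)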

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ if GROUP == "All" || GROUP == "Core"
     @time @safetestset "Basic Tests" begin include("basictests.jl") end
     @time @safetestset "Zero Initialization Tests" begin include("zeroinittests.jl") end
     @time @safetestset "Non-Square Tests" begin include("nonsquare.jl") end
+    @time @safetestset "SparseVector b Tests" begin include("sparse_vector.jl") end
 end
 
 if GROUP == "LinearSolveCUDA"

test/sparse_vector.jl

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+using SparseArrays
+using LinearSolve
+using LinearAlgebra
+
+# Constructing the sparse matrix
+function hess_sparse(x::Vector{T}) where T
+    return [-sin(x[1] + x[2]) + 1, -sin(x[1] + x[2]), -sin(x[1] + x[2]), -sin(x[1] + x[2]) + 1.0, 1.0, 1.0, 12.0*x[5]^2 + 1.0, 1.0]
+end
+rowval = [1, 1, 2, 2, 3, 4, 5, 6]
+colval = [1, 2, 1, 2, 3, 4, 5, 6]
+
+# Constructing the sparse vector
+function grad_sparse(x::Vector{T}) where T <: Number
+    return [cos(x[1] + x[2]), cos(x[1] + x[2]), 2*x[3], 1/2, 4*x[5]^3, 1/2]
+end
+gradinds = [1, 2, 3, 4, 5, 6]
+
+# Forming the matrix and vector
+x0 = [0.7853981648713337, 0.7853981693418342, 1.023999999999997e-7, -1.0, 0.33141395338218227, -1.0]
+n = length(x0)
+hess_mat = sparse(rowval, colval, hess_sparse(x0), n, n)
+grad_vec = sparsevec(gradinds, grad_sparse(x0), n)
+
+# Converting grad_vec to dense succeeds in solving
+prob = LinearProblem(hess_mat, grad_vec)
+linsolve = init(prob)
+@test solve(linsolve).u ≈ hess_mat \ Array(grad_vec)
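
The new file uses @test without loading Test itself; in the suite it runs under @safetestset via runtests.jl, where that resolves. A usage sketch for running it standalone (hypothetical session, not part of the commit):

    # From the package's test/ directory; the file loads SparseArrays,
    # LinearSolve, and LinearAlgebra itself, but @test needs Test in scope.
    using Test
    include("sparse_vector.jl")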
