Skip to content

Commit c7ad0b9

Browse files
authored
Make cholmod vector solve return vector (#320)
1 parent e530f2f commit c7ad0b9

File tree

3 files changed

+60
-35
lines changed

3 files changed

+60
-35
lines changed

src/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,10 +613,10 @@ end
613613
## triangular solvers
614614
function ldiv!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T
615615
require_one_based_indexing(A, B)
616-
nrowB, ncolB = size(B, 1), size(B, 2)
616+
nrowB = size(B, 1)
617617
ncol = LinearAlgebra.checksquare(A)
618618
if nrowB != ncol
619-
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
619+
throw(DimensionMismatch("A has $(ncol) columns and B has $(nrowB) rows"))
620620
end
621621
_ldiv!(A, B)
622622
end

src/solvers/cholmod.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ export
2828
Factor,
2929
Sparse
3030

31-
import SparseArrays: AbstractSparseMatrix, SparseMatrixCSC, indtype, sparse, spzeros, nnz
31+
import SparseArrays: AbstractSparseMatrix, SparseMatrixCSC, indtype, sparse, spzeros, nnz,
32+
sparsevec
3233

3334
import ..increment, ..increment!, ..AdjType, ..TransType
3435

@@ -864,7 +865,18 @@ function _trim_nz_builder!(m, n, colptr, rowval, nzval)
864865
l = colptr[end] - 1
865866
resize!(rowval, l)
866867
resize!(nzval, l)
867-
SparseMatrixCSC(m, n, colptr, rowval, nzval)
868+
return (m, n, colptr, rowval, nzval)
869+
end
870+
871+
function SparseVector{Tv,SuiteSparse_long}(A::Sparse{Tv}) where Tv
872+
s = unsafe_load(pointer(A))
873+
if s.stype != 0
874+
throw(ArgumentError("matrix has stype != 0. Convert to matrix " *
875+
"with stype == 0 before converting to SparseVector"))
876+
end
877+
args = _extract_args(s, Tv)
878+
s.sorted == 0 && _sort_buffers!(args...);
879+
return SparseVector(args[1], args[4], args[5])
868880
end
869881

870882
function SparseMatrixCSC{Tv,SuiteSparse_long}(A::Sparse{Tv}) where Tv
@@ -875,15 +887,15 @@ function SparseMatrixCSC{Tv,SuiteSparse_long}(A::Sparse{Tv}) where Tv
875887
end
876888
args = _extract_args(s, Tv)
877889
s.sorted == 0 && _sort_buffers!(args...);
878-
return _trim_nz_builder!(args...)
890+
return SparseMatrixCSC(_trim_nz_builder!(args...)...)
879891
end
880892

881893
function Symmetric{Float64,SparseMatrixCSC{Float64,SuiteSparse_long}}(A::Sparse{Float64})
882894
s = unsafe_load(pointer(A))
883895
issymmetric(A) || throw(ArgumentError("matrix is not symmetric"))
884896
args = _extract_args(s, Float64)
885897
s.sorted == 0 && _sort_buffers!(args...)
886-
Symmetric(_trim_nz_builder!(args...), s.stype > 0 ? :U : :L)
898+
Symmetric(SparseMatrixCSC(_trim_nz_builder!(args...)...), s.stype > 0 ? :U : :L)
887899
end
888900
convert(T::Type{Symmetric{Float64,SparseMatrixCSC{Float64,SuiteSparse_long}}}, A::Sparse{Float64}) = T(A)
889901

@@ -892,10 +904,16 @@ function Hermitian{Tv,SparseMatrixCSC{Tv,SuiteSparse_long}}(A::Sparse{Tv}) where
892904
ishermitian(A) || throw(ArgumentError("matrix is not Hermitian"))
893905
args = _extract_args(s, Tv)
894906
s.sorted == 0 && _sort_buffers!(args...)
895-
Hermitian(_trim_nz_builder!(args...), s.stype > 0 ? :U : :L)
907+
Hermitian(SparseMatrixCSC(_trim_nz_builder!(args...)...), s.stype > 0 ? :U : :L)
896908
end
897909
convert(T::Type{Hermitian{Tv,SparseMatrixCSC{Tv,SuiteSparse_long}}}, A::Sparse{Tv}) where {Tv<:VTypes} = T(A)
898910

911+
function sparsevec(A::Sparse{Tv}) where {Tv}
912+
s = unsafe_load(pointer(A))
913+
@assert s.stype == 0
914+
return SparseVector{Tv,SuiteSparse_long}(A)
915+
end
916+
899917
function sparse(A::Sparse{Float64}) # Notice! Cannot be type stable because of stype
900918
s = unsafe_load(pointer(A))
901919
if s.stype == 0
@@ -1527,7 +1545,10 @@ end
15271545
function (\)(L::FactorComponent, B::Matrix)
15281546
Matrix(L\Dense(B))
15291547
end
1530-
function (\)(L::FactorComponent, B::SparseVecOrMat)
1548+
function (\)(L::FactorComponent, B::SparseVector)
1549+
sparsevec(L\Sparse(B))
1550+
end
1551+
function (\)(L::FactorComponent, B::SparseMatrixCSC)
15311552
sparse(L\Sparse(B,0))
15321553
end
15331554
(\)(L::FactorComponent, B::Adjoint{<:Any,<:SparseMatrixCSC}) = L \ copy(B)
@@ -1553,7 +1574,7 @@ end
15531574
(\)(L::Factor, B::SparseMatrixCSC) = sparse(spsolve(CHOLMOD_A, L, Sparse(B, 0)))
15541575
(\)(L::Factor, B::Adjoint{<:Any,<:SparseMatrixCSC}) = L \ copy(B)
15551576
(\)(L::Factor, B::Transpose{<:Any,<:SparseMatrixCSC}) = L \ copy(B)
1556-
(\)(L::Factor, B::SparseVector) = sparse(spsolve(CHOLMOD_A, L, Sparse(B)))
1577+
(\)(L::Factor, B::SparseVector) = sparsevec(spsolve(CHOLMOD_A, L, Sparse(B)))
15571578

15581579
\(adjL::AdjType{<:Any,<:Factor}, B::Dense) = (L = adjL.parent; solve(CHOLMOD_A, L, B))
15591580
\(adjL::AdjType{<:Any,<:Factor}, B::Sparse) = (L = adjL.parent; spsolve(CHOLMOD_A, L, B))

test/cholmod.jl

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -515,16 +515,17 @@ end
515515
@test sparse(Fs.L) Lf
516516
@test sparse(Fs) As
517517
b = rand(3)
518-
@test Fs\b Af\b
518+
bs = sparse(b)
519+
@test Fs\b Af\b (Fs\bs)::SparseVector
519520
@test Fs.UP\(Fs.PtL\b) Af\b
520-
@test Fs.L\b Lf\b
521-
@test Fs.U\b Lf'\b
522-
@test Fs.L'\b Lf'\b
523-
@test Fs.U'\b Lf\b
524-
@test Fs.PtL\b Lf\b
525-
@test Fs.UP\b Lf'\b
526-
@test Fs.PtL'\b Lf'\b
527-
@test Fs.UP'\b Lf\b
521+
@test Fs.L\b Lf\b (Fs.L\bs)::SparseVector
522+
@test Fs.U\b Lf'\b (Fs.U\bs)::SparseVector
523+
@test Fs.L'\b Lf'\b (Fs.L'\bs)::SparseVector
524+
@test Fs.U'\b Lf\b (Fs.U'\bs)::SparseVector
525+
@test Fs.PtL\b Lf\b (Fs.PtL\bs)::SparseVector
526+
@test Fs.UP\b Lf'\b (Fs.UP\bs)::SparseVector
527+
@test Fs.PtL'\b Lf'\b (Fs.PtL'\bs)::SparseVector
528+
@test Fs.UP'\b Lf\b (Fs.UP'\bs)::SparseVector
528529
@test_throws CHOLMOD.CHOLMODException Fs.D
529530
@test_throws CHOLMOD.CHOLMODException Fs.LD
530531
@test_throws CHOLMOD.CHOLMODException Fs.DU
@@ -544,16 +545,17 @@ end
544545
@test P' * Ls * Ls' * P As
545546
@test sparse(Fs) As
546547
b = rand(3)
547-
@test Fs\b Af\b
548+
bs = sparse(b)
549+
@test Fs\b Af\b (Fs\bs)::SparseVector
548550
@test Fs.UP\(Fs.PtL\b) Af\b
549-
@test Fs.L\b Lfp\b
550-
@test Fs.U'\b Lfp\b
551-
@test Fs.U\b Lfp'\b
552-
@test Fs.L'\b Lfp'\b
553-
@test Fs.PtL\b Lfp\b[p]
554-
@test Fs.UP\b (Lfp'\b)[p_inv]
555-
@test Fs.PtL'\b (Lfp'\b)[p_inv]
556-
@test Fs.UP'\b Lfp\b[p]
551+
@test Fs.L\b Lfp\b (Fs.L\bs)::SparseVector
552+
@test Fs.U'\b Lfp\b (Fs.U'\bs)::SparseVector
553+
@test Fs.U\b Lfp'\b (Fs.U\bs)::SparseVector
554+
@test Fs.L'\b Lfp'\b (Fs.L'\bs)::SparseVector
555+
@test Fs.PtL\b Lfp\b[p] (Fs.PtL\bs)::SparseVector
556+
@test Fs.UP\b (Lfp'\b)[p_inv] (Fs.UP\bs)::SparseVector
557+
@test Fs.PtL'\b (Lfp'\b)[p_inv] (Fs.PtL'\bs)::SparseVector
558+
@test Fs.UP'\b Lfp\b[p] (Fs.UP'\bs)::SparseVector
557559
@test_throws CHOLMOD.CHOLMODException Fs.PL
558560
@test_throws CHOLMOD.CHOLMODException Fs.UPt
559561
@test_throws CHOLMOD.CHOLMODException Fs.D
@@ -569,14 +571,15 @@ end
569571
@test sparse(Fs.LD) LDf
570572
@test sparse(Fs) As
571573
b = rand(3)
572-
@test Fs\b Af\b
574+
bs = sparse(b)
575+
@test Fs\b Af\b (Fs\bs)::SparseVector
573576
@test Fs.UP\(Fs.PtLD\b) Af\b
574577
@test Fs.DUP\(Fs.PtL\b) Af\b
575-
@test Fs.L\b L_f\b
576-
@test Fs.U\b L_f'\b
578+
@test Fs.L\b L_f\b (Fs.L\bs)::SparseVector
579+
@test Fs.U\b L_f'\b (Fs.U\bs)::SparseVector
577580
@test Fs.L'\b L_f'\b
578581
@test Fs.U'\b L_f\b
579-
@test Fs.PtL\b L_f\b
582+
@test Fs.PtL\b L_f\b (Fs.PtL\bs)::SparseVector
580583
@test Fs.UP\b L_f'\b
581584
@test Fs.PtL'\b L_f'\b
582585
@test Fs.UP'\b L_f\b
@@ -597,19 +600,20 @@ end
597600
@test Fs.p == p
598601
@test sparse(Fs) As
599602
b = rand(3)
603+
bs = sparse(b)
600604
Asp = As[p,p]
601605
LDp = sparse(ldlt(Asp, perm=[1,2,3]).LD)
602606
# LDp = sparse(Fs.LD)
603607
Lp, dp = CHOLMOD.getLd!(copy(LDp))
604608
Dp = sparse(Diagonal(dp))
605-
@test Fs\b Af\b
609+
@test Fs\b Af\b (Fs\bs)::SparseVector
606610
@test Fs.UP\(Fs.PtLD\b) Af\b
607611
@test Fs.DUP\(Fs.PtL\b) Af\b
608-
@test Fs.L\b Lp\b
609-
@test Fs.U\b Lp'\b
612+
@test Fs.L\b Lp\b (Fs.L\bs)::SparseVector
613+
@test Fs.U\b Lp'\b (Fs.U\bs)::SparseVector
610614
@test Fs.L'\b Lp'\b
611615
@test Fs.U'\b Lp\b
612-
@test Fs.PtL\b Lp\b[p]
616+
@test Fs.PtL\b Lp\b[p] (Fs.PtL\bs)::SparseVector
613617
@test Fs.UP\b (Lp'\b)[p_inv]
614618
@test Fs.PtL'\b (Lp'\b)[p_inv]
615619
@test Fs.UP'\b Lp\b[p]

0 commit comments

Comments
 (0)