Skip to content

Commit 6961555

Browse files
committed
Use AbstractSparseMatrixCSC for sparse factorizations
1 parent f218598 commit 6961555

File tree

4 files changed

+31
-19
lines changed

4 files changed

+31
-19
lines changed

lib/LinearSolvePardiso/src/LinearSolvePardiso.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function LinearSolve.init_cacheval(alg::PardisoJL, A, b, u, Pl, Pr, maxiters::In
9090
Pardiso.set_iparm!(solver, 3, round(Int, abs(log10(reltol)), RoundDown) * 10 + 1)
9191
end
9292

93-
Pardiso.pardiso(solver, u, A, b)
93+
Pardiso.pardiso(solver, u, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A), b)
9494

9595
return solver
9696
end

src/LinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Base: cache_dependencies, Bool
99
using LinearAlgebra
1010
using IterativeSolvers: Identity
1111
using SparseArrays
12+
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
1213
using SciMLBase: AbstractLinearAlgorithm
1314
using SciMLOperators
1415
using SciMLOperators: AbstractSciMLOperator, IdentityOperator

src/default.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ function defaultalg(A::Diagonal, b, ::OperatorAssumptions{Nothing})
5050
DiagonalFactorization()
5151
end
5252

53-
function defaultalg(A::SparseMatrixCSC{Tv, Ti}, b,
53+
function defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
5454
::OperatorAssumptions{true}) where {Tv, Ti}
5555
SparspakFactorization()
5656
end
5757

5858
@static if INCLUDE_SPARSE
59-
function defaultalg(A::SparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
59+
function defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
6060
::OperatorAssumptions{true}) where {Ti}
6161
if length(b) <= 10_000
6262
KLUFactorization()

src/factorization.jl

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ end
5252

5353
function do_factorization(alg::LUFactorization, A, b, u)
5454
A = convert(AbstractMatrix, A)
55-
if A isa SparseMatrixCSC
56-
return lu(A)
55+
if A isa AbstractSparseMatrixCSC
56+
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)))
5757
else
5858
fact = lu!(A, alg.pivot)
5959
end
@@ -277,7 +277,8 @@ function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters::Int
277277
copy(nonzeros(A)), 0)
278278
finalizer(SuiteSparse.UMFPACK.umfpack_free_symbolic, res)
279279
else
280-
return SuiteSparse.UMFPACK.UmfpackLU(A)
280+
return SuiteSparse.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
281+
rowvals(A), nonzeros(A)))
281282
end
282283
end
283284

@@ -290,12 +291,15 @@ function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization; kwargs..
290291
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
291292
cache.cacheval.colptr &&
292293
SuiteSparse.decrement(SparseArrays.getrowval(A)) == cache.cacheval.rowval)
293-
fact = lu(A)
294+
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
295+
nonzeros(A)))
294296
else
295-
fact = lu!(cache.cacheval, A)
297+
fact = lu!(cache.cacheval,
298+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
299+
nonzeros(A)))
296300
end
297301
else
298-
fact = lu(A)
302+
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)))
299303
end
300304
cache = set_cacheval(cache, fact)
301305
end
@@ -312,7 +316,8 @@ end
312316
function init_cacheval(alg::KLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
313317
reltol,
314318
verbose::Bool, assumptions::OperatorAssumptions)
315-
return KLU.KLUFactorization(convert(AbstractMatrix, A)) # this takes care of the copy internally.
319+
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
320+
nonzeros(A)))
316321
end
317322

318323
function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization; kwargs...)
@@ -323,21 +328,25 @@ function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization; kwargs...)
323328
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
324329
cache.cacheval.colptr &&
325330
SuiteSparse.decrement(SparseArrays.getrowval(A)) == cache.cacheval.rowval)
326-
fact = KLU.klu(A)
331+
fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
332+
nonzeros(A)))
327333
else
328334
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
329335
# This won't recompute if it does.
330336
KLU.klu_analyze!(cache.cacheval)
331-
copyto!(cache.cacheval.nzval, A.nzval)
337+
copyto!(cache.cacheval.nzval, nonzeros(A))
332338
if cache.cacheval._numeric === C_NULL # We MUST have a numeric factorization for reuse, unlike UMFPACK.
333339
KLU.klu_factor!(cache.cacheval)
334340
end
335-
fact = KLU.klu!(cache.cacheval, A)
341+
fact = KLU.klu!(cache.cacheval,
342+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
343+
nonzeros(A)))
336344
end
337345
else
338346
# New fact each time since the sparsity pattern can change
339347
# and thus it needs to reallocate
340-
fact = KLU.klu(A)
348+
fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
349+
nonzeros(A)))
341350
end
342351
cache = set_cacheval(cache, fact)
343352
end
@@ -510,18 +519,20 @@ end
510519
function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
511520
reltol,
512521
verbose::Bool, assumptions::OperatorAssumptions)
513-
A = convert(AbstractMatrix, A)
514-
return sparspaklu(A, factorize = false)
522+
return sparspaklu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
523+
factorize = false)
515524
end
516525

517526
function SciMLBase.solve(cache::LinearCache, alg::SparspakFactorization; kwargs...)
518527
A = cache.A
519-
A = convert(AbstractMatrix, A)
520528
if cache.isfresh
521529
if cache.cacheval !== nothing && alg.reuse_symbolic
522-
fact = sparspaklu!(cache.cacheval, A)
530+
fact = sparspaklu!(cache.cacheval,
531+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
532+
nonzeros(A)))
523533
else
524-
fact = sparspaklu(A)
534+
fact = sparspaklu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
535+
nonzeros(A)))
525536
end
526537
cache = set_cacheval(cache, fact)
527538
end

0 commit comments

Comments
 (0)