Skip to content

Commit 5840ee7

Browse files
Merge pull request #273 from j-fu/abstractsparse
Use AbstractSparseMatrixCSC for sparse factorizations
2 parents a7b7d2b + 82da413 commit 5840ee7

File tree

6 files changed

+52
-25
lines changed

6 files changed

+52
-25
lines changed

lib/LinearSolvePardiso/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
99
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
1010
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
11+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1112
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1213

1314
[compat]
14-
SciMLBase = "1.25"
1515
LinearSolve = "1.24"
16-
Pardiso = "0.5"
16+
Pardiso = "0.5"
17+
SciMLBase = "1.25"
1718
UnPack = "1"
1819
julia = "1.6"
1920

lib/LinearSolvePardiso/src/LinearSolvePardiso.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
module LinearSolvePardiso
22

33
using Pardiso, LinearSolve, SciMLBase
4+
using SparseArrays
5+
using SparseArrays: nonzeros, rowvals, getcolptr
6+
47
using UnPack
58

69
Base.@kwdef struct PardisoJL <: LinearSolve.SciMLLinearSolveAlgorithm
@@ -17,8 +20,16 @@ LinearSolve.needs_concrete_A(alg::PardisoJL) = true
1720

1821
# TODO schur complement functionality
1922

20-
function LinearSolve.init_cacheval(alg::PardisoJL, A, b, u, Pl, Pr, maxiters::Int, abstol,
21-
reltol, verbose::Bool,
23+
function LinearSolve.init_cacheval(alg::PardisoJL,
24+
A,
25+
b,
26+
u,
27+
Pl,
28+
Pr,
29+
maxiters::Int,
30+
abstol,
31+
reltol,
32+
verbose::Bool,
2233
assumptions::LinearSolve.OperatorAssumptions)
2334
@unpack nprocs, solver_type, matrix_type, iparm, dparm = alg
2435
A = convert(AbstractMatrix, A)
@@ -90,7 +101,10 @@ function LinearSolve.init_cacheval(alg::PardisoJL, A, b, u, Pl, Pr, maxiters::In
90101
Pardiso.set_iparm!(solver, 3, round(Int, abs(log10(reltol)), RoundDown) * 10 + 1)
91102
end
92103

93-
Pardiso.pardiso(solver, u, A, b)
104+
Pardiso.pardiso(solver,
105+
u,
106+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
107+
b)
94108

95109
return solver
96110
end

lib/LinearSolvePardiso/test/runtests.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)
1717

1818
prob2 = LinearProblem(A2, b2)
1919

20-
for alg in (PardisoJL(),
21-
MKLPardisoFactorize(),
22-
MKLPardisoIterate())
20+
for alg in (PardisoJL(), MKLPardisoFactorize(), MKLPardisoIterate())
2321
u = solve(prob1, alg; cache_kwargs...).u
2422
@test A1 * u b1
2523

src/LinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Base: cache_dependencies, Bool
99
using LinearAlgebra
1010
using IterativeSolvers: Identity
1111
using SparseArrays
12+
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
1213
using SciMLBase: AbstractLinearAlgorithm
1314
using SciMLOperators
1415
using SciMLOperators: AbstractSciMLOperator, IdentityOperator

src/default.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ function defaultalg(A::Diagonal, b, ::OperatorAssumptions{Nothing})
5050
DiagonalFactorization()
5151
end
5252

53-
function defaultalg(A::SparseMatrixCSC{Tv, Ti}, b,
53+
function defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
5454
::OperatorAssumptions{true}) where {Tv, Ti}
5555
SparspakFactorization()
5656
end
5757

5858
@static if INCLUDE_SPARSE
59-
function defaultalg(A::SparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
59+
function defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
6060
::OperatorAssumptions{true}) where {Ti}
6161
if length(b) <= 10_000
6262
KLUFactorization()

src/factorization.jl

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ end
5252

5353
function do_factorization(alg::LUFactorization, A, b, u)
5454
A = convert(AbstractMatrix, A)
55-
if A isa SparseMatrixCSC
56-
return lu(A)
55+
if A isa AbstractSparseMatrixCSC
56+
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)))
5757
else
5858
fact = lu!(A, alg.pivot)
5959
end
@@ -277,7 +277,8 @@ function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters::Int
277277
copy(nonzeros(A)), 0)
278278
finalizer(SuiteSparse.UMFPACK.umfpack_free_symbolic, res)
279279
else
280-
return SuiteSparse.UMFPACK.UmfpackLU(A)
280+
return SuiteSparse.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
281+
rowvals(A), nonzeros(A)))
281282
end
282283
end
283284

@@ -290,12 +291,15 @@ function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization; kwargs..
290291
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
291292
cache.cacheval.colptr &&
292293
SuiteSparse.decrement(SparseArrays.getrowval(A)) == cache.cacheval.rowval)
293-
fact = lu(A)
294+
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
295+
nonzeros(A)))
294296
else
295-
fact = lu!(cache.cacheval, A)
297+
fact = lu!(cache.cacheval,
298+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
299+
nonzeros(A)))
296300
end
297301
else
298-
fact = lu(A)
302+
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)))
299303
end
300304
cache = set_cacheval(cache, fact)
301305
end
@@ -312,7 +316,9 @@ end
312316
function init_cacheval(alg::KLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
313317
reltol,
314318
verbose::Bool, assumptions::OperatorAssumptions)
315-
return KLU.KLUFactorization(convert(AbstractMatrix, A)) # this takes care of the copy internally.
319+
A = convert(AbstractMatrix, A)
320+
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
321+
nonzeros(A)))
316322
end
317323

318324
function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization; kwargs...)
@@ -323,21 +329,25 @@ function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization; kwargs...)
323329
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
324330
cache.cacheval.colptr &&
325331
SuiteSparse.decrement(SparseArrays.getrowval(A)) == cache.cacheval.rowval)
326-
fact = KLU.klu(A)
332+
fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
333+
nonzeros(A)))
327334
else
328335
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
329336
# This won't recompute if it does.
330337
KLU.klu_analyze!(cache.cacheval)
331-
copyto!(cache.cacheval.nzval, A.nzval)
338+
copyto!(cache.cacheval.nzval, nonzeros(A))
332339
if cache.cacheval._numeric === C_NULL # We MUST have a numeric factorization for reuse, unlike UMFPACK.
333340
KLU.klu_factor!(cache.cacheval)
334341
end
335-
fact = KLU.klu!(cache.cacheval, A)
342+
fact = KLU.klu!(cache.cacheval,
343+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
344+
nonzeros(A)))
336345
end
337346
else
338347
# New fact each time since the sparsity pattern can change
339348
# and thus it needs to reallocate
340-
fact = KLU.klu(A)
349+
fact = KLU.klu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
350+
nonzeros(A)))
341351
end
342352
cache = set_cacheval(cache, fact)
343353
end
@@ -511,17 +521,20 @@ function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int,
511521
reltol,
512522
verbose::Bool, assumptions::OperatorAssumptions)
513523
A = convert(AbstractMatrix, A)
514-
return sparspaklu(A, factorize = false)
524+
return sparspaklu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
525+
factorize = false)
515526
end
516527

517528
function SciMLBase.solve(cache::LinearCache, alg::SparspakFactorization; kwargs...)
518529
A = cache.A
519-
A = convert(AbstractMatrix, A)
520530
if cache.isfresh
521531
if cache.cacheval !== nothing && alg.reuse_symbolic
522-
fact = sparspaklu!(cache.cacheval, A)
532+
fact = sparspaklu!(cache.cacheval,
533+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
534+
nonzeros(A)))
523535
else
524-
fact = sparspaklu(A)
536+
fact = sparspaklu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
537+
nonzeros(A)))
525538
end
526539
cache = set_cacheval(cache, fact)
527540
end

0 commit comments

Comments
 (0)