Skip to content

Commit 2aae22a

Browse files
Merge pull request #207 from SciML/rk/fix_sparse_reuse
Simple quick fix for refactor issue
2 parents c976838 + b17f725 commit 2aae22a

File tree

3 files changed

+47
-25
lines changed

3 files changed

+47
-25
lines changed

src/factorization.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ end
267267

268268
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
269269
reuse_symbolic::Bool = true
270+
check_pattern::Bool = true # Check factorization re-use
270271
end
271272

272273
function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
@@ -290,7 +291,13 @@ function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization; kwargs..
290291
if cache.isfresh
291292
if cache.cacheval !== nothing && alg.reuse_symbolic
292293
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
293-
fact = lu!(cache.cacheval, A)
294+
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
295+
cache.cacheval.colptr &&
296+
SuiteSparse.decrement(SparseArrays.getrowval(A)) == cache.cacheval.rowval)
297+
fact = lu(A)
298+
else
299+
fact = lu!(cache.cacheval, A)
300+
end
294301
else
295302
fact = lu(A)
296303
end
@@ -303,6 +310,7 @@ end
303310

304311
Base.@kwdef struct KLUFactorization <: AbstractFactorization
305312
reuse_symbolic::Bool = true
313+
check_pattern::Bool = true
306314
end
307315

308316
function init_cacheval(alg::KLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
@@ -316,14 +324,20 @@ function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization; kwargs...)
316324
A = convert(AbstractMatrix, A)
317325
if cache.isfresh
318326
if cache.cacheval !== nothing && alg.reuse_symbolic
319-
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
320-
# This won't recompute if it does.
321-
KLU.klu_analyze!(cache.cacheval)
322-
copyto!(cache.cacheval.nzval, A.nzval)
323-
if cache.cacheval._numeric === C_NULL # We MUST have a numeric factorization for reuse, unlike UMFPACK.
324-
KLU.klu_factor!(cache.cacheval)
327+
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
328+
cache.cacheval.colptr &&
329+
SuiteSparse.decrement(SparseArrays.getrowval(A)) == cache.cacheval.rowval)
330+
fact = KLU.klu(A)
331+
else
332+
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
333+
# This won't recompute if it does.
334+
KLU.klu_analyze!(cache.cacheval)
335+
copyto!(cache.cacheval.nzval, A.nzval)
336+
if cache.cacheval._numeric === C_NULL # We MUST have a numeric factorization for reuse, unlike UMFPACK.
337+
KLU.klu_factor!(cache.cacheval)
338+
end
339+
fact = KLU.klu!(cache.cacheval, A)
325340
end
326-
fact = KLU.klu!(cache.cacheval, A)
327341
else
328342
# New fact each time since the sparsity pattern can change
329343
# and thus it needs to reallocate

test/basictests.jl

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,46 +79,49 @@ end
7979
end
8080

8181
@testset "UMFPACK Factorization" begin
82-
A1 = A / 1
82+
A1 = sparse(A / 1)
8383
b1 = rand(n)
8484
x1 = zero(b)
85-
A2 = A / 2
85+
A2 = sparse(A / 2)
8686
b2 = rand(n)
8787
x2 = zero(b)
8888

89-
prob1 = LinearProblem(sparse(A1), b1; u0 = x1)
90-
prob2 = LinearProblem(sparse(A2), b2; u0 = x2)
89+
prob1 = LinearProblem(A1, b1; u0 = x1)
90+
prob2 = LinearProblem(A2, b2; u0 = x2)
9191
test_interface(UMFPACKFactorization(), prob1, prob2)
9292
test_interface(UMFPACKFactorization(reuse_symbolic = false), prob1, prob2)
9393

94-
# Test that refactoring wrong throws.
94+
# Test that refactoring is checked and handled.
9595
cache = SciMLBase.init(prob1, UMFPACKFactorization(); cache_kwargs...) # initialize cache
9696
y = solve(cache)
97-
cache = LinearSolve.set_A(cache, sprand(n, n, 0.8))
98-
@test_throws ArgumentError solve(cache)
97+
cache = LinearSolve.set_A(cache, A2)
98+
@test A2 * solve(cache) b1
99+
X = sprand(n, n, 0.8)
100+
cache = LinearSolve.set_A(cache, X)
101+
@test X * solve(cache) b1
99102
end
100103

101104
@testset "KLU Factorization" begin
102-
A1 = A / 1
105+
A1 = sparse(A / 1)
103106
b1 = rand(n)
104107
x1 = zero(b)
105-
A2 = A / 2
108+
A2 = sparse(A / 2)
106109
b2 = rand(n)
107110
x2 = zero(b)
108111

109-
prob1 = LinearProblem(sparse(A1), b1; u0 = x1)
110-
prob2 = LinearProblem(sparse(A2), b2; u0 = x2)
112+
prob1 = LinearProblem(A1, b1; u0 = x1)
113+
prob2 = LinearProblem(A2, b2; u0 = x2)
111114
test_interface(KLUFactorization(), prob1, prob2)
112115
test_interface(KLUFactorization(reuse_symbolic = false), prob1, prob2)
113116

114-
# Test that refactoring wrong throws.
117+
# Test that refactoring wrong is checked and handled.
115118
cache = SciMLBase.init(prob1, KLUFactorization(); cache_kwargs...) # initialize cache
116119
y = solve(cache)
117-
X = copy(A1)
118-
X[8, 8] = 0.0
119-
X[7, 8] = 1.0
120-
cache = LinearSolve.set_A(cache, sparse(X))
121-
@test_throws ArgumentError solve(cache)
120+
cache = LinearSolve.set_A(cache, A2)
121+
@test A2 * solve(cache) b1
122+
X = sprand(n, n, 0.8)
123+
cache = LinearSolve.set_A(cache, X)
124+
@test X * solve(cache) b1
122125
end
123126

124127
@testset "FastLAPACK Factorizations" begin

test/zeroinittests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ A = Diagonal(ones(4))
44
b = rand(4)
55
A = sparse(A)
66
Anz = deepcopy(A)
7+
C = copy(A)
8+
C[begin, end] = 1e-8
79
A.nzval .= 0
810
cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)
911

@@ -14,6 +16,9 @@ function test_nonzero_init(alg = nothing)
1416
cache = LinearSolve.set_A(cache, Anz)
1517
sol = solve(cache; cache_kwargs...)
1618
@test sol.u == b
19+
cache = LinearSolve.set_A(cache, C)
20+
sol = solve(cache; cache_kwargs...)
21+
@test sol.u b
1722
end
1823

1924
test_nonzero_init()

0 commit comments

Comments
 (0)