Skip to content

Commit aa19103

Browse files
committed
Simple quick fix for refactor issue
1 parent 8550cf9 commit aa19103

File tree

3 files changed

+32
-16
lines changed

3 files changed

+32
-16
lines changed

src/factorization.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ end
267267

268268
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
269269
reuse_symbolic::Bool = true
270+
check::Bool = true # Check factorization re-use
270271
end
271272

272273
function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
@@ -290,7 +291,13 @@ function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization; kwargs..
290291
if cache.isfresh
291292
if cache.cacheval !== nothing && alg.reuse_symbolic
292293
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
293-
fact = lu!(cache.cacheval, A)
294+
if alg.check && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
295+
cache.cacheval.colptr &&
296+
SuiteSparse.decrement(SparseArrays.getrowval(A)) == cache.cacheval.rowval)
297+
fact = lu(A)
298+
else
299+
fact = lu!(cache.cacheval, A)
300+
end
294301
else
295302
fact = lu(A)
296303
end
@@ -303,6 +310,7 @@ end
303310

304311
Base.@kwdef struct KLUFactorization <: AbstractFactorization
305312
reuse_symbolic::Bool = true
313+
check::Bool = true
306314
end
307315

308316
function init_cacheval(alg::KLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
@@ -316,14 +324,20 @@ function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization; kwargs...)
316324
A = convert(AbstractMatrix, A)
317325
if cache.isfresh
318326
if cache.cacheval !== nothing && alg.reuse_symbolic
319-
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
320-
# This won't recompute if it does.
321-
KLU.klu_analyze!(cache.cacheval)
322-
copyto!(cache.cacheval.nzval, A.nzval)
323-
if cache.cacheval._numeric === C_NULL # We MUST have a numeric factorization for reuse, unlike UMFPACK.
324-
KLU.klu_factor!(cache.cacheval)
327+
if alg.check && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
328+
cache.cacheval.colptr &&
329+
SuiteSparse.decrement(SparseArrays.getrowval(A)) == cache.cacheval.rowval)
330+
fact = KLU.klu(A)
331+
else
332+
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
333+
# This won't recompute if it does.
334+
KLU.klu_analyze!(cache.cacheval)
335+
copyto!(cache.cacheval.nzval, A.nzval)
336+
if cache.cacheval._numeric === C_NULL # We MUST have a numeric factorization for reuse, unlike UMFPACK.
337+
KLU.klu_factor!(cache.cacheval)
338+
end
339+
fact = KLU.klu!(cache.cacheval, A)
325340
end
326-
fact = KLU.klu!(cache.cacheval, A)
327341
else
328342
# New fact each time since the sparsity pattern can change
329343
# and thus it needs to reallocate

test/basictests.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ end
9191
test_interface(UMFPACKFactorization(), prob1, prob2)
9292
test_interface(UMFPACKFactorization(reuse_symbolic = false), prob1, prob2)
9393

94-
# Test that refactoring wrong throws.
94+
# Test that refactoring is checked and handled.
9595
cache = SciMLBase.init(prob1, UMFPACKFactorization(); cache_kwargs...) # initialize cache
9696
y = solve(cache)
9797
cache = LinearSolve.set_A(cache, sprand(n, n, 0.8))
98-
@test_throws ArgumentError solve(cache)
98+
y2 = solve(cache) # we just need to know this doesn't fail.
9999
end
100100

101101
@testset "KLU Factorization" begin
@@ -111,14 +111,11 @@ end
111111
test_interface(KLUFactorization(), prob1, prob2)
112112
test_interface(KLUFactorization(reuse_symbolic = false), prob1, prob2)
113113

114-
# Test that refactoring wrong throws.
114+
# Test that refactoring wrong is checked and handled.
115115
cache = SciMLBase.init(prob1, KLUFactorization(); cache_kwargs...) # initialize cache
116116
y = solve(cache)
117-
X = copy(A1)
118-
X[8, 8] = 0.0
119-
X[7, 8] = 1.0
120-
cache = LinearSolve.set_A(cache, sparse(X))
121-
@test_throws ArgumentError solve(cache)
117+
cache = LinearSolve.set_A(cache, sprand(n, n, 0.8))
118+
y2 = solve(cache)
122119
end
123120

124121
@testset "FastLAPACK Factorizations" begin

test/zeroinittests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ A = Diagonal(ones(4))
44
b = rand(4)
55
A = sparse(A)
66
Anz = deepcopy(A)
7+
C = copy(A)
8+
C[begin, end] = 1e-8
79
A.nzval .= 0
810
cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)
911

@@ -14,6 +16,9 @@ function test_nonzero_init(alg = nothing)
1416
cache = LinearSolve.set_A(cache, Anz)
1517
sol = solve(cache; cache_kwargs...)
1618
@test sol.u == b
19+
cache = LinearSolve.set_A(cache, C)
20+
sol = solve(cache; cache_kwargs...)
21+
@test sol.u b
1722
end
1823

1924
test_nonzero_init()

0 commit comments

Comments
 (0)