Skip to content

Commit 4085d31

Browse files
Merge pull request #327 from SciML/umfpack
Fix and better test UMFPACK default
2 parents b5ff8b1 + 03ae773 commit 4085d31

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

src/factorization.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ end
8484
function do_factorization(alg::LUFactorization, A, b, u)
8585
A = convert(AbstractMatrix, A)
8686
if A isa AbstractSparseMatrixCSC
87-
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), check=false)
87+
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
88+
check = false)
8889
else
8990
fact = lu!(A, alg.pivot, check = false)
9091
end
@@ -707,16 +708,17 @@ function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs.
707708
A = cache.A
708709
A = convert(AbstractMatrix, A)
709710
if cache.isfresh
711+
cacheval = @get_cacheval(cache, :UMFPACKFactorization)
710712
if alg.reuse_symbolic
711713
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
712714
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
713-
cache.cacheval.colptr &&
715+
cacheval.colptr &&
714716
SuiteSparse.decrement(SparseArrays.getrowval(A)) ==
715-
@get_cacheval(cache, :UMFPACKFactorization).rowval)
717+
cacheval.rowval)
716718
fact = lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
717719
nonzeros(A)))
718720
else
719-
fact = lu!(@get_cacheval(cache, :UMFPACKFactorization),
721+
fact = lu!(cacheval,
720722
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
721723
nonzeros(A)))
722724
end

test/default_algs.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,35 @@
11
using LinearSolve, LinearAlgebra, SparseArrays, Test, JET
22
@test LinearSolve.defaultalg(nothing, zeros(3)).alg ===
33
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
4+
prob = LinearProblem(rand(3, 3), rand(3))
5+
solve(prob)
6+
47
@test LinearSolve.defaultalg(nothing, zeros(50)).alg ===
58
LinearSolve.DefaultAlgorithmChoice.RFLUFactorization
9+
prob = LinearProblem(rand(50, 50), rand(50))
10+
solve(prob)
11+
612
@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
713
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
14+
prob = LinearProblem(rand(600, 600), rand(600))
15+
solve(prob)
16+
817
@test LinearSolve.defaultalg(LinearAlgebra.Diagonal(zeros(5)), zeros(5)).alg ===
918
LinearSolve.DefaultAlgorithmChoice.DiagonalFactorization
1019

1120
@test LinearSolve.defaultalg(nothing, zeros(5),
1221
LinearSolve.OperatorAssumptions(false)).alg ===
1322
LinearSolve.DefaultAlgorithmChoice.QRFactorization
1423

15-
@test LinearSolve.defaultalg(sprand(1000, 1000, 0.01), zeros(1000)).alg ===
24+
@test LinearSolve.defaultalg(sprand(1000, 1000, 0.5), zeros(1000)).alg ===
1625
LinearSolve.DefaultAlgorithmChoice.KLUFactorization
26+
prob = LinearProblem(sprand(1000, 1000, 0.5), zeros(1000))
27+
solve(prob)
28+
1729
@test LinearSolve.defaultalg(sprand(11000, 11000, 0.001), zeros(11000)).alg ===
1830
LinearSolve.DefaultAlgorithmChoice.UMFPACKFactorization
31+
prob = LinearProblem(sprand(11000, 11000, 0.5), zeros(11000))
32+
solve(prob)
1933

2034
@static if VERSION >= v"v1.7-"
2135
# Test inference

0 commit comments

Comments
 (0)