Skip to content

Commit c4f8ed5

Browse files
Fix default solve algorithm handling
1 parent cdfe484 commit c4f8ed5

File tree

4 files changed

+37
-8
lines changed

4 files changed

+37
-8
lines changed

src/common.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ function set_cacheval(cache::LinearCache, alg_cache)
4040
return cache
4141
end
4242

43-
init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
43+
init_cacheval(alg::Union{SciMLLinearSolveAlgorithm,Nothing}, A, b, u) = nothing
4444

45-
function SciMLBase.init(prob::LinearProblem, alg, args...;
45+
SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)
46+
47+
function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm,Nothing}, args...;
4648
alias_A = false, alias_b = false,
4749
kwargs...,
4850
)
@@ -83,7 +85,9 @@ function SciMLBase.init(prob::LinearProblem, alg, args...;
8385
return cache
8486
end
8587

86-
SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
88+
SciMLBase.solve(prob::LinearProblem, args...; kwargs...) = solve(init(prob, nothing, args...; kwargs...))
89+
90+
SciMLBase.solve(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm,Nothing},
8791
args...; kwargs...) = solve(init(prob, alg, args...; kwargs...))
8892

8993
SciMLBase.solve(cache::LinearCache, args...; kwargs...) =

src/default.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
44
args...; kwargs...)
55
@unpack A = cache
6+
if A isa DiffEqArrayOperator
7+
A = A.A
8+
end
9+
610
if A isa Matrix
7-
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 ||
11+
if ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 ||
812
(isopenblas() && size(A,1) <= 500)
913
)
10-
alg = GenericFactorization(;fact_alg=:(RecursiveFactorization.lu!))
14+
alg = GenericFactorization(;fact_alg=RecursiveFactorization.lu!)
1115
SciMLBase.solve(cache, alg, args...; kwargs...)
1216
else
1317
alg = LUFactorization()

src/factorization.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ end
2525
function init_cacheval(alg::LUFactorization, A, b, u)
2626
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
2727
error("LU is not defined for $(typeof(A))")
28+
29+
if A isa AbstractDiffEqOperator
30+
A = A.A
31+
end
2832
fact = lu!(A, alg.pivot)
2933
return fact
3034
end
@@ -49,7 +53,10 @@ function init_cacheval(alg::QRFactorization, A, b, u)
4953
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
5054
error("QR is not defined for $(typeof(A))")
5155

52-
fact = qr!(A.A, alg.pivot; blocksize = alg.blocksize)
56+
if A isa AbstractDiffEqOperator
57+
A = A.A
58+
end
59+
fact = qr!(A, alg.pivot; blocksize = alg.blocksize)
5360
return fact
5461
end
5562

@@ -66,6 +73,10 @@ function init_cacheval(alg::SVDFactorization, A, b, u)
6673
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
6774
error("SVD is not defined for $(typeof(A))")
6875

76+
if A isa AbstractDiffEqOperator
77+
A = A.A
78+
end
79+
6980
fact = svd!(A; full = alg.full, alg = alg.alg)
7081
return fact
7182
end
@@ -83,6 +94,9 @@ function init_cacheval(alg::GenericFactorization, A, b, u)
8394
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
8495
error("GenericFactorization is not defined for $(typeof(A))")
8596

97+
if A isa AbstractDiffEqOperator
98+
A = A.A
99+
end
86100
fact = alg.fact_alg(A)
87101
return fact
88102
end

test/runtests.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,19 @@ function test_interface(alg, prob1, prob2)
3232
return
3333
end
3434

35+
@testset "Default Linear Solver" begin
36+
test_interface(nothing, prob1, prob2)
37+
38+
A1 = prob1.A; b1 = prob1.b; x1 = prob1.u0
39+
y = solve(prob1)
40+
@test A1 * y b1
41+
end
42+
3543
@testset "Concrete Factorizations" begin
3644
for alg in (
3745
LUFactorization(),
3846
QRFactorization(),
39-
SVDFactorization(),
40-
#nothing
47+
SVDFactorization()
4148
)
4249
@testset "$alg" begin
4350
test_interface(alg, prob1, prob2)

0 commit comments

Comments
 (0)