Skip to content

Commit 75faf50

Browse files
Merge pull request #25 from SciML/default
Fix default solve algorithm handling
2 parents cdfe484 + 073a9c8 commit 75faf50

File tree

4 files changed

+59
-10
lines changed

4 files changed

+59
-10
lines changed

src/common.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ function set_cacheval(cache::LinearCache, alg_cache)
4040
return cache
4141
end
4242

43-
init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
43+
init_cacheval(alg::Union{SciMLLinearSolveAlgorithm,Nothing}, A, b, u) = nothing
4444

45-
function SciMLBase.init(prob::LinearProblem, alg, args...;
45+
SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)
46+
47+
function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm,Nothing}, args...;
4648
alias_A = false, alias_b = false,
4749
kwargs...,
4850
)
@@ -83,7 +85,9 @@ function SciMLBase.init(prob::LinearProblem, alg, args...;
8385
return cache
8486
end
8587

86-
SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
88+
SciMLBase.solve(prob::LinearProblem, args...; kwargs...) = solve(init(prob, nothing, args...; kwargs...))
89+
90+
SciMLBase.solve(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm,Nothing},
8791
args...; kwargs...) = solve(init(prob, alg, args...; kwargs...))
8892

8993
SciMLBase.solve(cache::LinearCache, args...; kwargs...) =

src/default.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
44
args...; kwargs...)
55
@unpack A = cache
6+
if A isa DiffEqArrayOperator
7+
A = A.A
8+
end
9+
610
if A isa Matrix
7-
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 ||
11+
if ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 ||
812
(isopenblas() && size(A,1) <= 500)
913
)
10-
alg = GenericFactorization(;fact_alg=:(RecursiveFactorization.lu!))
14+
alg = GenericFactorization(;fact_alg=RecursiveFactorization.lu!)
1115
SciMLBase.solve(cache, alg, args...; kwargs...)
1216
else
1317
alg = LUFactorization()

src/factorization.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function SciMLBase.solve(cache::LinearCache, alg::AbstractFactorization)
44
cache = set_cacheval(cache, fact)
55
end
66

7-
ldiv!(cache.u,cache.cacheval, cache.b)
7+
ldiv!(cache.u, cache.cacheval, cache.b)
88
end
99

1010
## LUFactorization
@@ -25,6 +25,10 @@ end
2525
function init_cacheval(alg::LUFactorization, A, b, u)
2626
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
2727
error("LU is not defined for $(typeof(A))")
28+
29+
if A isa AbstractDiffEqOperator
30+
A = A.A
31+
end
2832
fact = lu!(A, alg.pivot)
2933
return fact
3034
end
@@ -49,7 +53,10 @@ function init_cacheval(alg::QRFactorization, A, b, u)
4953
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
5054
error("QR is not defined for $(typeof(A))")
5155

52-
fact = qr!(A.A, alg.pivot; blocksize = alg.blocksize)
56+
if A isa AbstractDiffEqOperator
57+
A = A.A
58+
end
59+
fact = qr!(A, alg.pivot; blocksize = alg.blocksize)
5360
return fact
5461
end
5562

@@ -66,6 +73,10 @@ function init_cacheval(alg::SVDFactorization, A, b, u)
6673
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
6774
error("SVD is not defined for $(typeof(A))")
6875

76+
if A isa AbstractDiffEqOperator
77+
A = A.A
78+
end
79+
6980
fact = svd!(A; full = alg.full, alg = alg.alg)
7081
return fact
7182
end
@@ -83,6 +94,9 @@ function init_cacheval(alg::GenericFactorization, A, b, u)
8394
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
8495
error("GenericFactorization is not defined for $(typeof(A))")
8596

97+
if A isa AbstractDiffEqOperator
98+
A = A.A
99+
end
86100
fact = alg.fact_alg(A)
87101
return fact
88102
end

test/runtests.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearSolve, LinearAlgebra
1+
using LinearSolve, LinearAlgebra, SparseArrays
22
using Test
33

44
n = 8
@@ -32,12 +32,39 @@ function test_interface(alg, prob1, prob2)
3232
return
3333
end
3434

35+
@testset "Default Linear Solver" begin
36+
test_interface(nothing, prob1, prob2)
37+
38+
A1 = prob1.A; b1 = prob1.b; x1 = prob1.u0
39+
y = solve(prob1)
40+
@test A1 * y b1
41+
42+
_prob = LinearProblem(SymTridiagonal(A1.A), b1; u0=x1)
43+
y = solve(prob1)
44+
@test A1 * y b1
45+
46+
_prob = LinearProblem(Tridiagonal(A1.A), b1; u0=x1)
47+
y = solve(prob1)
48+
@test A1 * y b1
49+
50+
_prob = LinearProblem(Symmetric(A1.A), b1; u0=x1)
51+
y = solve(prob1)
52+
@test A1 * y b1
53+
54+
_prob = LinearProblem(Hermitian(A1.A), b1; u0=x1)
55+
y = solve(prob1)
56+
@test A1 * y b1
57+
58+
_prob = LinearProblem(sparse(A1.A), b1; u0=x1)
59+
y = solve(prob1)
60+
@test A1 * y b1
61+
end
62+
3563
@testset "Concrete Factorizations" begin
3664
for alg in (
3765
LUFactorization(),
3866
QRFactorization(),
39-
SVDFactorization(),
40-
#nothing
67+
SVDFactorization()
4168
)
4269
@testset "$alg" begin
4370
test_interface(alg, prob1, prob2)

0 commit comments

Comments
 (0)