Skip to content

Commit 3be8678

Browse files
Merge pull request #34 from SciML/ChrisRackauckas-patch-1
Change default to use KrylovJL_GMRES on non Array
2 parents e91e4ac + aa4356e commit 3be8678

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

src/default.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
77
A = A.A
88
end
99

10+
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
11+
# it makes sense according to the benchmarks, which is dependent on
12+
# whether MKL or OpenBLAS is being used
1013
if A isa Matrix
1114
if ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 ||
1215
(isopenblas() && size(A,1) <= 500)
@@ -17,6 +20,9 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
1720
alg = LUFactorization()
1821
SciMLBase.solve(cache, alg, args...; kwargs...)
1922
end
23+
24+
# These few cases ensure the choice is optimal without the
25+
# dynamic dispatching of factorize
2026
elseif A isa Tridiagonal
2127
alg = GenericFactorization(;fact_alg=lu!)
2228
SciMLBase.solve(cache, alg, args...; kwargs...)
@@ -26,14 +32,26 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
2632
elseif A isa SparseMatrixCSC
2733
alg = LUFactorization()
2834
SciMLBase.solve(cache, alg, args...; kwargs...)
35+
36+
# This catches the cases where a factorization overload could exist
37+
# For example, BlockBandedMatrix
2938
elseif ArrayInterface.isstructured(A)
3039
alg = GenericFactorization()
3140
SciMLBase.solve(cache, alg, args...; kwargs...)
41+
42+
# This catches the case where A is a CuMatrix
43+
# Which does not have LU fully defined
3244
elseif !(A isa AbstractDiffEqOperator)
3345
alg = QRFactorization()
3446
SciMLBase.solve(cache, alg, args...; kwargs...)
35-
else
47+
48+
# Not factorizable operator, default to only using A*x
49+
# IterativeSolvers is faster on CPU but not GPU-compatible
50+
elseif cache.u isa Array
3651
alg = IterativeSolversJL_GMRES()
3752
SciMLBase.solve(cache, alg, args...; kwargs...)
53+
else
54+
alg = KrylovJL_GMRES()
55+
SciMLBase.solve(cache, alg, args...; kwargs...)
3856
end
3957
end

0 commit comments

Comments
 (0)