Skip to content

Commit aa4356e

Browse files
Change default to use KrylovJL_GMRES on non Array
Because this is GPU-compatible. IterativeSolvers.jl does some low rank Q updating which might make it more efficient on CPU but precludes it from being used on the GPU, so if that's the case then we use IterativeSolvers for the speed but more generally fall back to Krylov.jl
1 parent e91e4ac commit aa4356e

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

src/default.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
77
A = A.A
88
end
99

10+
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
11+
# it makes sense according to the benchmarks, which is dependent on
12+
# whether MKL or OpenBLAS is being used
1013
if A isa Matrix
1114
if ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 ||
1215
(isopenblas() && size(A,1) <= 500)
@@ -17,6 +20,9 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
1720
alg = LUFactorization()
1821
SciMLBase.solve(cache, alg, args...; kwargs...)
1922
end
23+
24+
# These few cases ensure the choice is optimal without the
25+
# dynamic dispatching of factorize
2026
elseif A isa Tridiagonal
2127
alg = GenericFactorization(;fact_alg=lu!)
2228
SciMLBase.solve(cache, alg, args...; kwargs...)
@@ -26,14 +32,26 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
2632
elseif A isa SparseMatrixCSC
2733
alg = LUFactorization()
2834
SciMLBase.solve(cache, alg, args...; kwargs...)
35+
36+
# This catches the cases where a factorization overload could exist
37+
# For example, BlockBandedMatrix
2938
elseif ArrayInterface.isstructured(A)
3039
alg = GenericFactorization()
3140
SciMLBase.solve(cache, alg, args...; kwargs...)
41+
42+
# This catches the case where A is a CuMatrix
43+
# Which does not have LU fully defined
3244
elseif !(A isa AbstractDiffEqOperator)
3345
alg = QRFactorization()
3446
SciMLBase.solve(cache, alg, args...; kwargs...)
35-
else
47+
48+
# Not factorizable operator, default to only using A*x
49+
# IterativeSolvers is faster on CPU but not GPU-compatible
50+
elseif cache.u isa Array
3651
alg = IterativeSolversJL_GMRES()
3752
SciMLBase.solve(cache, alg, args...; kwargs...)
53+
else
54+
alg = KrylovJL_GMRES()
55+
SciMLBase.solve(cache, alg, args...; kwargs...)
3856
end
3957
end

0 commit comments

Comments
 (0)