@@ -7,6 +7,9 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
7
7
A = A. A
8
8
end
9
9
10
+ # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
11
+ # it makes sense according to the benchmarks, which is dependent on
12
+ # whether MKL or OpenBLAS is being used
10
13
if A isa Matrix
11
14
if ArrayInterface. can_setindex (cache. b) && (size (A,1 ) <= 100 ||
12
15
(isopenblas () && size (A,1 ) <= 500 )
@@ -17,6 +20,9 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
17
20
alg = LUFactorization ()
18
21
SciMLBase. solve (cache, alg, args... ; kwargs... )
19
22
end
23
+
24
+ # These few cases ensure the choice is optimal without the
25
+ # dynamic dispatching of factorize
20
26
elseif A isa Tridiagonal
21
27
alg = GenericFactorization (;fact_alg= lu!)
22
28
SciMLBase. solve (cache, alg, args... ; kwargs... )
@@ -26,14 +32,26 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
26
32
elseif A isa SparseMatrixCSC
27
33
alg = LUFactorization ()
28
34
SciMLBase. solve (cache, alg, args... ; kwargs... )
35
+
36
+ # This catches the cases where a factorization overload could exist
37
+ # For example, BlockBandedMatrix
29
38
elseif ArrayInterface. isstructured (A)
30
39
alg = GenericFactorization ()
31
40
SciMLBase. solve (cache, alg, args... ; kwargs... )
41
+
42
+ # This catches the case where A is a CuMatrix
43
+ # Which does not have LU fully defined
32
44
elseif ! (A isa AbstractDiffEqOperator)
33
45
alg = QRFactorization ()
34
46
SciMLBase. solve (cache, alg, args... ; kwargs... )
35
- else
47
+
48
+ # Not factorizable operator, default to only using A*x
49
+ # IterativeSolvers is faster on CPU but not GPU-compatible
50
+ elseif cache. u isa Array
36
51
alg = IterativeSolversJL_GMRES ()
37
52
SciMLBase. solve (cache, alg, args... ; kwargs... )
53
+ else
54
+ alg = KrylovJL_GMRES ()
55
+ SciMLBase. solve (cache, alg, args... ; kwargs... )
38
56
end
39
57
end
0 commit comments