@@ -9,7 +9,7 @@ function defaultalg(A,b)
9
9
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
10
10
# it makes sense according to the benchmarks, which is dependent on
11
11
# whether MKL or OpenBLAS is being used
12
- if A === nothing || A isa Matrix
12
+ if ( A === nothing && ! isgpu (b)) || A isa Matrix
13
13
if (A === nothing || eltype (A) <: Union{Float32,Float64,ComplexF32,ComplexF64} ) &&
14
14
ArrayInterface. can_setindex (b) && (length (b) <= 100 ||
15
15
(isopenblas () && length (b) <= 500 )
@@ -30,18 +30,15 @@ function defaultalg(A,b)
30
30
31
31
# This catches the cases where a factorization overload could exist
32
32
# For example, BlockBandedMatrix
33
- elseif ArrayInterface. isstructured (A)
33
+ elseif A != = nothing && ArrayInterface. isstructured (A)
34
34
alg = GenericFactorization ()
35
35
36
36
# This catches the case where A is a CuMatrix
37
37
# Which does not have LU fully defined
38
- elseif ! (A isa AbstractDiffEqOperator )
38
+ elseif isgpu (A) || isgpu (b )
39
39
alg = QRFactorization (false )
40
40
41
41
# Not factorizable operator, default to only using A*x
42
- # IterativeSolvers is faster on CPU but not GPU-compatible
43
- elseif cache. u isa Array
44
- alg = IterativeSolversJL_GMRES ()
45
42
else
46
43
alg = KrylovJL_GMRES ()
47
44
end
@@ -92,15 +89,12 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
92
89
93
90
# This catches the case where A is a CuMatrix
94
91
# Which does not have LU fully defined
95
- elseif ! (A isa AbstractDiffEqOperator )
92
+ elseif isgpu (A )
96
93
alg = QRFactorization (false )
97
94
SciMLBase. solve (cache, alg, args... ; kwargs... )
98
95
99
96
# Not factorizable operator, default to only using A*x
100
97
# IterativeSolvers is faster on CPU but not GPU-compatible
101
- elseif cache. u isa Array
102
- alg = IterativeSolversJL_GMRES ()
103
- SciMLBase. solve (cache, alg, args... ; kwargs... )
104
98
else
105
99
alg = KrylovJL_GMRES ()
106
100
SciMLBase. solve (cache, alg, args... ; kwargs... )
@@ -147,15 +141,12 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
147
141
148
142
# This catches the case where A is a CuMatrix
149
143
# Which does not have LU fully defined
150
- elseif ! (A isa AbstractDiffEqOperator )
144
+ elseif isgpu (A )
151
145
alg = QRFactorization (false )
152
146
init_cacheval (alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
153
147
154
148
# Not factorizable operator, default to only using A*x
155
149
# IterativeSolvers is faster on CPU but not GPU-compatible
156
- elseif u isa Array
157
- alg = IterativeSolversJL_GMRES ()
158
- init_cacheval (alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
159
150
else
160
151
alg = KrylovJL_GMRES ()
161
152
init_cacheval (alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
0 commit comments