|
1 |
| -## Default algorithm |
| 1 | +defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions) = defaultalg(A.A, b, assumptions) |
2 | 2 |
|
3 |
| -# Allows A === nothing as a stand-in for dense matrix |
4 |
| -function defaultalg(A, b) |
5 |
| - if A isa DiffEqArrayOperator |
6 |
| - A = A.A |
7 |
| - end |
8 |
| - |
9 |
| - # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when |
10 |
| - # it makes sense according to the benchmarks, which is dependent on |
11 |
| - # whether MKL or OpenBLAS is being used |
12 |
| - if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix |
13 |
| - if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) && |
14 |
| - ArrayInterfaceCore.can_setindex(b) |
15 |
| - if length(b) <= 10 |
16 |
| - alg = GenericLUFactorization() |
17 |
| - elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) && |
18 |
| - eltype(A) <: Union{Float32, Float64} |
19 |
| - alg = RFLUFactorization() |
20 |
| - #elseif A === nothing || A isa Matrix |
21 |
| - # alg = FastLUFactorization() |
22 |
| - else |
23 |
| - alg = LUFactorization() |
24 |
| - end |
25 |
| - else |
26 |
| - alg = LUFactorization() |
27 |
| - end |
| 3 | +# Ambiguity handling |
| 4 | +defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{nothing}) = defaultalg(A.A, b, assumptions) |
28 | 5 |
|
29 |
| - # These few cases ensure the choice is optimal without the |
30 |
| - # dynamic dispatching of factorize |
31 |
| - elseif A isa Tridiagonal |
32 |
| - alg = GenericFactorization(; fact_alg = lu!) |
33 |
| - elseif A isa SymTridiagonal |
34 |
| - alg = GenericFactorization(; fact_alg = ldlt!) |
35 |
| - elseif A isa SparseMatrixCSC |
36 |
| - if length(b) <= 10_000 |
37 |
| - alg = KLUFactorization() |
38 |
| - else |
39 |
| - alg = UMFPACKFactorization() |
40 |
| - end |
| 6 | +function defaultalg(A, b, ::OperatorAssumptions{nothing}) |
| 7 | + issquare = size(A,1) == size(A,2) |
| 8 | + defaultalg(A, b, OperatorAssumptions(Val(issquare))) |
| 9 | +end |
41 | 10 |
|
42 |
| - # This catches the cases where a factorization overload could exist |
43 |
| - # For example, BlockBandedMatrix |
44 |
| - elseif A !== nothing && ArrayInterfaceCore.isstructured(A) |
45 |
| - alg = GenericFactorization() |
| 11 | +defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{true}) = GenericFactorization(; fact_alg = lu!) |
| 12 | +defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{false}) = GenericFactorization(; fact_alg = qr!) |
| 13 | +defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{true}) = GenericFactorization(; fact_alg = ldlt!) |
46 | 14 |
|
47 |
| - # This catches the case where A is a CuMatrix |
48 |
| - # Which does not have LU fully defined |
49 |
| - elseif A isa GPUArraysCore.AbstractGPUArray || b isa GPUArraysCore.AbstractGPUArray |
50 |
| - if VERSION >= v"1.8-" |
51 |
| - alg = LUFactorization() |
52 |
| - else |
53 |
| - alg = QRFactorization() |
54 |
| - end |
| 15 | +function defaultalg(A::SparseMatrixCSC, b, ::OperatorAssumptions{true}) |
| 16 | + if length(b) <= 10_000 |
| 17 | + KLUFactorization() |
| 18 | + else |
| 19 | + UMFPACKFactorization() |
| 20 | + end |
| 21 | +end |
55 | 22 |
|
56 |
| - # Not factorizable operator, default to only using A*x |
| 23 | +function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, ::OperatorAssumptions{true}) |
| 24 | + if VERSION >= v"1.8-" |
| 25 | + LUFactorization() |
57 | 26 | else
|
58 |
| - alg = KrylovJL_GMRES() |
| 27 | + QRFactorization() |
59 | 28 | end
|
60 |
| - alg |
61 | 29 | end
|
62 | 30 |
|
63 |
| -## Other dispatches are to decrease the dispatch cost |
| 31 | +function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{true}) |
| 32 | + if VERSION >= v"1.8-" |
| 33 | + LUFactorization() |
| 34 | + else |
| 35 | + QRFactorization() |
| 36 | + end |
| 37 | +end |
64 | 38 |
|
65 |
| -function SciMLBase.solve(cache::LinearCache, alg::Nothing, |
66 |
| - args...; kwargs...) |
67 |
| - @unpack A = cache |
68 |
| - if A isa DiffEqArrayOperator |
69 |
| - A = A.A |
| 39 | +# Handle ambiguity |
| 40 | +function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{true}) |
| 41 | + if VERSION >= v"1.8-" |
| 42 | + LUFactorization() |
| 43 | + else |
| 44 | + QRFactorization() |
70 | 45 | end
|
| 46 | +end |
| 47 | + |
| 48 | +function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, ::OperatorAssumptions{false}) |
| 49 | + QRFactorization() |
| 50 | +end |
| 51 | + |
| 52 | +function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{false}) |
| 53 | + QRFactorization() |
| 54 | +end |
71 | 55 |
|
| 56 | +# Handle ambiguity |
| 57 | +function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{false}) |
| 58 | + QRFactorization() |
| 59 | +end |
| 60 | + |
| 61 | +# Allows A === nothing as a stand-in for dense matrix |
| 62 | +function defaultalg(A, b, ::OperatorAssumptions{true}) |
72 | 63 | # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
|
73 | 64 | # it makes sense according to the benchmarks, which is dependent on
|
74 | 65 | # whether MKL or OpenBLAS is being used
|
75 |
| - if A isa Matrix |
76 |
| - b = cache.b |
| 66 | + if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix |
77 | 67 | if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
|
78 | 68 | ArrayInterfaceCore.can_setindex(b)
|
79 | 69 | if length(b) <= 10
|
80 | 70 | alg = GenericLUFactorization()
|
81 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
82 | 71 | elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
|
83 | 72 | eltype(A) <: Union{Float32, Float64}
|
84 | 73 | alg = RFLUFactorization()
|
85 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
86 |
| - #elseif A isa Matrix |
| 74 | + #elseif A === nothing || A isa Matrix |
87 | 75 | # alg = FastLUFactorization()
|
88 |
| - # SciMLBase.solve(cache, alg, args...; kwargs...) |
89 | 76 | else
|
90 | 77 | alg = LUFactorization()
|
91 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
92 | 78 | end
|
93 | 79 | else
|
94 | 80 | alg = LUFactorization()
|
95 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
96 |
| - end |
97 |
| - |
98 |
| - # These few cases ensure the choice is optimal without the |
99 |
| - # dynamic dispatching of factorize |
100 |
| - elseif A isa Tridiagonal |
101 |
| - alg = GenericFactorization(; fact_alg = lu!) |
102 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
103 |
| - elseif A isa SymTridiagonal |
104 |
| - alg = GenericFactorization(; fact_alg = ldlt!) |
105 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
106 |
| - elseif A isa SparseMatrixCSC |
107 |
| - b = cache.b |
108 |
| - if length(b) <= 10_000 |
109 |
| - alg = KLUFactorization() |
110 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
111 |
| - else |
112 |
| - alg = UMFPACKFactorization() |
113 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
114 | 81 | end
|
115 | 82 |
|
116 | 83 | # This catches the cases where a factorization overload could exist
|
117 | 84 | # For example, BlockBandedMatrix
|
118 |
| - elseif ArrayInterfaceCore.isstructured(A) |
| 85 | + elseif A !== nothing && ArrayInterfaceCore.isstructured(A) |
119 | 86 | alg = GenericFactorization()
|
120 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
121 | 87 |
|
122 |
| - # This catches the case where A is a CuMatrix |
123 |
| - # Which does not have LU fully defined |
124 |
| - elseif A isa GPUArraysCore.AbstractGPUArray |
125 |
| - if VERSION >= v"1.8-" |
126 |
| - alg = LUFactorization() |
127 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
128 |
| - else |
129 |
| - alg = QRFactorization() |
130 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
131 |
| - end |
132 | 88 | # Not factorizable operator, default to only using A*x
|
133 |
| - # IterativeSolvers is faster on CPU but not GPU-compatible |
134 | 89 | else
|
135 | 90 | alg = KrylovJL_GMRES()
|
136 |
| - SciMLBase.solve(cache, alg, args...; kwargs...) |
137 | 91 | end
|
| 92 | + alg |
138 | 93 | end
|
139 | 94 |
|
140 |
| -function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
141 |
| - if A isa DiffEqArrayOperator |
142 |
| - A = A.A |
143 |
| - end |
144 |
| - |
145 |
| - # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when |
146 |
| - # it makes sense according to the benchmarks, which is dependent on |
147 |
| - # whether MKL or OpenBLAS is being used |
148 |
| - if A isa Matrix |
149 |
| - if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) && |
150 |
| - ArrayInterfaceCore.can_setindex(b) |
151 |
| - if length(b) <= 10 |
152 |
| - alg = GenericLUFactorization() |
153 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
154 |
| - elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) && |
155 |
| - eltype(A) <: Union{Float32, Float64} |
156 |
| - alg = RFLUFactorization() |
157 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
158 |
| - #elseif A isa Matrix |
159 |
| - # alg = FastLUFactorization() |
160 |
| - # init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
161 |
| - else |
162 |
| - alg = LUFactorization() |
163 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
164 |
| - end |
165 |
| - else |
166 |
| - alg = LUFactorization() |
167 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
168 |
| - end |
| 95 | +function defaultalg(A, b, ::OperatorAssumptions{false}) |
| 96 | + QRFactorization() |
| 97 | +end |
169 | 98 |
|
170 |
| - # These few cases ensure the choice is optimal without the |
171 |
| - # dynamic dispatching of factorize |
172 |
| - elseif A isa Tridiagonal |
173 |
| - alg = GenericFactorization(; fact_alg = lu!) |
174 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
175 |
| - elseif A isa SymTridiagonal |
176 |
| - alg = GenericFactorization(; fact_alg = ldlt!) |
177 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
178 |
| - elseif A isa SparseMatrixCSC |
179 |
| - if length(b) <= 10_000 |
180 |
| - alg = KLUFactorization() |
181 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
182 |
| - else |
183 |
| - alg = UMFPACKFactorization() |
184 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
185 |
| - end |
| 99 | +## Catch high level interface |
186 | 100 |
|
187 |
| - # This catches the cases where a factorization overload could exist |
188 |
| - # For example, BlockBandedMatrix |
189 |
| - elseif ArrayInterfaceCore.isstructured(A) |
190 |
| - alg = GenericFactorization() |
191 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
| 101 | +function SciMLBase.solve(cache::LinearCache, alg::Nothing, |
| 102 | + args...; assumptions::OperatorAssumptions = OperatorAssumptions(), kwargs...) |
| 103 | + @unpack A, b = cache |
| 104 | + SciMLBase.solve(cache, defaultalg(A,b,assumptions), args...; kwargs...) |
| 105 | +end |
192 | 106 |
|
193 |
| - # This catches the case where A is a CuMatrix |
194 |
| - # Which does not have LU fully defined |
195 |
| - elseif A isa GPUArraysCore.AbstractGPUArray |
196 |
| - if VERSION >= v"1.8-" |
197 |
| - alg = LUFactorization() |
198 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
199 |
| - else |
200 |
| - alg = QRFactorization() |
201 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
202 |
| - end |
203 |
| - # Not factorizable operator, default to only using A*x |
204 |
| - # IterativeSolvers is faster on CPU but not GPU-compatible |
205 |
| - else |
206 |
| - alg = KrylovJL_GMRES() |
207 |
| - init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) |
208 |
| - end |
| 107 | +function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
| 108 | + init_cacheval(defaultalg(A,b,assumptions), A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose, assumptions) |
209 | 109 | end
|
0 commit comments