Skip to content

Commit 881cdfb

Browse files
Add an assumptions mechanism for a type-stable default algorithm choice, and nonsquare operator support
Fixes #177
1 parent ec2dfad commit 881cdfb

File tree

5 files changed

+119
-204
lines changed

5 files changed

+119
-204
lines changed

src/common.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol}
1+
# Trait type encoding structural assumptions about the operator `A`.
# The `issquare` type parameter is `true`/`false` when squareness is known,
# or `nothing` when it has not been determined yet.
struct OperatorAssumptions{issquare} end

"""
    OperatorAssumptions(issquare = nothing)

Construct an `OperatorAssumptions{issquare}` trait. `issquare` may be a
`Bool`, a `Val{Bool}`, or `nothing` (unknown squareness, resolved later
from `size(A)` by the default-algorithm machinery).
"""
function OperatorAssumptions(issquare = nothing)
    OperatorAssumptions{_unwrap_val(issquare)}()
end

# Lift `Val`-wrapped and plain values into the type domain.
_unwrap_val(::Val{B}) where {B} = B
# Fix: previously returned the *type* `Nothing`, producing
# `OperatorAssumptions{Nothing}`, which never matches the
# `OperatorAssumptions{nothing}` methods used for dispatch.
_unwrap_val(::Nothing) = nothing
_unwrap_val(B::Bool) = B
9+
10+
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
211
A::TA
312
b::Tb
413
u::Tu
@@ -12,6 +21,7 @@ struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol}
1221
reltol::Ttol
1322
maxiters::Int
1423
verbose::Bool
24+
assumptions::OperatorAssumptions{issquare}
1525
end
1626

1727
"""
@@ -86,6 +96,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
8696
verbose = false,
8797
Pl = Identity(),
8898
Pr = Identity(),
99+
assumptions = OperatorAssumptions(),
89100
kwargs...)
90101
@unpack A, b, u0, p = prob
91102

@@ -96,7 +107,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
96107
fill!(u0, false)
97108
end
98109

99-
cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose)
110+
cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose, assumptions)
100111
isfresh = true
101112
Tc = typeof(cacheval)
102113

@@ -112,7 +123,8 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
112123
Tc,
113124
typeof(Pl),
114125
typeof(Pr),
115-
typeof(reltol)
126+
typeof(reltol),
127+
typeof(assumptions)
116128
}(A,
117129
b,
118130
u0,
@@ -125,7 +137,8 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
125137
abstol,
126138
reltol,
127139
maxiters,
128-
verbose)
140+
verbose,
141+
assumptions)
129142
return cache
130143
end
131144

src/default.jl

Lines changed: 65 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -1,209 +1,109 @@
1-
## Default algorithm
1+
# Unwrap a DiffEqArrayOperator and choose the default based on the wrapped matrix.
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions)
    return defaultalg(A.A, b, assumptions)
end
22

3-
# Allows A === nothing as a stand-in for dense matrix
4-
function defaultalg(A, b)
5-
if A isa DiffEqArrayOperator
6-
A = A.A
7-
end
8-
9-
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
10-
# it makes sense according to the benchmarks, which is dependent on
11-
# whether MKL or OpenBLAS is being used
12-
if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
13-
if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
14-
ArrayInterfaceCore.can_setindex(b)
15-
if length(b) <= 10
16-
alg = GenericLUFactorization()
17-
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
18-
eltype(A) <: Union{Float32, Float64}
19-
alg = RFLUFactorization()
20-
#elseif A === nothing || A isa Matrix
21-
# alg = FastLUFactorization()
22-
else
23-
alg = LUFactorization()
24-
end
25-
else
26-
alg = LUFactorization()
27-
end
3+
# Ambiguity handling
4+
# Ambiguity resolver: DiffEqArrayOperator with unknown squareness.
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{nothing})
    return defaultalg(A.A, b, assumptions)
end
285

29-
# These few cases ensure the choice is optimal without the
30-
# dynamic dispatching of factorize
31-
elseif A isa Tridiagonal
32-
alg = GenericFactorization(; fact_alg = lu!)
33-
elseif A isa SymTridiagonal
34-
alg = GenericFactorization(; fact_alg = ldlt!)
35-
elseif A isa SparseMatrixCSC
36-
if length(b) <= 10_000
37-
alg = KLUFactorization()
38-
else
39-
alg = UMFPACKFactorization()
40-
end
6+
# Resolve unknown squareness at runtime, then re-dispatch with a concrete
# assumption. `A === nothing` is the documented stand-in for a dense (square)
# matrix elsewhere in this file, so it must not be passed to `size`.
function defaultalg(A, b, ::OperatorAssumptions{nothing})
    issquare = A === nothing ? true : size(A, 1) == size(A, 2)
    defaultalg(A, b, OperatorAssumptions(Val(issquare)))
end
4110

42-
# This catches the cases where a factorization overload could exist
43-
# For example, BlockBandedMatrix
44-
elseif A !== nothing && ArrayInterfaceCore.isstructured(A)
45-
alg = GenericFactorization()
11+
# Structured tridiagonal operators: avoid the dynamic dispatch of `factorize`
# by selecting the factorization directly for each squareness case.
function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{true})
    return GenericFactorization(; fact_alg = lu!)
end
function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{false})
    return GenericFactorization(; fact_alg = qr!)
end
function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{true})
    return GenericFactorization(; fact_alg = ldlt!)
end
4614

47-
# This catches the case where A is a CuMatrix
48-
# Which does not have LU fully defined
49-
elseif A isa GPUArraysCore.AbstractGPUArray || b isa GPUArraysCore.AbstractGPUArray
50-
if VERSION >= v"1.8-"
51-
alg = LUFactorization()
52-
else
53-
alg = QRFactorization()
54-
end
15+
# Square sparse systems: KLU tends to win on modest problem sizes,
# UMFPACK on large ones (threshold from upstream benchmarks).
function defaultalg(A::SparseMatrixCSC, b, ::OperatorAssumptions{true})
    return length(b) > 10_000 ? UMFPACKFactorization() : KLUFactorization()
end
5522

56-
# Not factorizable operator, default to only using A*x
23+
# Square GPU matrix: LU is only fully defined for GPU arrays on Julia >= 1.8,
# so fall back to QR on older releases.
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, ::OperatorAssumptions{true})
    return VERSION >= v"1.8-" ? LUFactorization() : QRFactorization()
end
6230

63-
## Other dispatches are to decrease the dispatch cost
31+
# Square system with a GPU right-hand side: same LU/QR version split as the
# GPU-matrix case.
function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{true})
    return VERSION >= v"1.8-" ? LUFactorization() : QRFactorization()
end
6438

65-
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
66-
args...; kwargs...)
67-
@unpack A = cache
68-
if A isa DiffEqArrayOperator
69-
A = A.A
39+
# Handle ambiguity
40+
# Ambiguity resolver: both `A` and `b` on the GPU, square case.
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
                    ::OperatorAssumptions{true})
    return VERSION >= v"1.8-" ? LUFactorization() : QRFactorization()
end
47+
48+
# Nonsquare GPU problems: QR supports m != n (least-squares / minimum-norm).
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, ::OperatorAssumptions{false})
    return QRFactorization()
end

function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{false})
    return QRFactorization()
end

# Ambiguity resolver: both `A` and `b` on the GPU, nonsquare case.
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
                    ::OperatorAssumptions{false})
    return QRFactorization()
end
60+
61+
# Allows A === nothing as a stand-in for dense matrix
# Fix: the dispatch type was written as the undefined `Assumptions{true}`;
# the trait used everywhere else in this file is `OperatorAssumptions`.
function defaultalg(A, b, ::OperatorAssumptions{true})
    # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
    # it makes sense according to the benchmarks, which is dependent on
    # whether MKL or OpenBLAS is being used
    if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
        if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
           ArrayInterfaceCore.can_setindex(b)
            if length(b) <= 10
                alg = GenericLUFactorization()
            elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
                   eltype(A) <: Union{Float32, Float64}
                alg = RFLUFactorization()
            #elseif A === nothing || A isa Matrix
            #    alg = FastLUFactorization()
            else
                alg = LUFactorization()
            end
        else
            alg = LUFactorization()
        end

    # This catches the cases where a factorization overload could exist
    # For example, BlockBandedMatrix
    elseif A !== nothing && ArrayInterfaceCore.isstructured(A)
        alg = GenericFactorization()

    # Not factorizable operator, default to only using A*x
    else
        alg = KrylovJL_GMRES()
    end
    alg
end
13994

140-
function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
141-
if A isa DiffEqArrayOperator
142-
A = A.A
143-
end
144-
145-
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
146-
# it makes sense according to the benchmarks, which is dependent on
147-
# whether MKL or OpenBLAS is being used
148-
if A isa Matrix
149-
if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
150-
ArrayInterfaceCore.can_setindex(b)
151-
if length(b) <= 10
152-
alg = GenericLUFactorization()
153-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
154-
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
155-
eltype(A) <: Union{Float32, Float64}
156-
alg = RFLUFactorization()
157-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
158-
#elseif A isa Matrix
159-
# alg = FastLUFactorization()
160-
# init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
161-
else
162-
alg = LUFactorization()
163-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
164-
end
165-
else
166-
alg = LUFactorization()
167-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
168-
end
95+
# Generic nonsquare fallback on the CPU: QR handles m != n.
# Fix: this dispatched on `::Val{false}`, which is dead code — the default
# path builds `OperatorAssumptions(Val(issquare))`, unwrapping the `Val`
# into the trait, so the reachable signature is `OperatorAssumptions{false}`.
function defaultalg(A, b, ::OperatorAssumptions{false})
    QRFactorization()
end
16998

170-
# These few cases ensure the choice is optimal without the
171-
# dynamic dispatching of factorize
172-
elseif A isa Tridiagonal
173-
alg = GenericFactorization(; fact_alg = lu!)
174-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
175-
elseif A isa SymTridiagonal
176-
alg = GenericFactorization(; fact_alg = ldlt!)
177-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
178-
elseif A isa SparseMatrixCSC
179-
if length(b) <= 10_000
180-
alg = KLUFactorization()
181-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
182-
else
183-
alg = UMFPACKFactorization()
184-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
185-
end
99+
## Catch high level interface
186100

187-
# This catches the cases where a factorization overload could exist
188-
# For example, BlockBandedMatrix
189-
elseif ArrayInterfaceCore.isstructured(A)
190-
alg = GenericFactorization()
191-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
101+
# High-level entry: `alg === nothing` means "pick the default algorithm"
# from the operator, right-hand side, and the caller's assumptions.
# Fix: called the undefined `default_alg`; the function is `defaultalg`.
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
                         args...; assumptions::OperatorAssumptions = OperatorAssumptions(),
                         kwargs...)
    @unpack A, b = cache
    SciMLBase.solve(cache, defaultalg(A, b, assumptions), args...; kwargs...)
end
192106

193-
# This catches the case where A is a CuMatrix
194-
# Which does not have LU fully defined
195-
elseif A isa GPUArraysCore.AbstractGPUArray
196-
if VERSION >= v"1.8-"
197-
alg = LUFactorization()
198-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
199-
else
200-
alg = QRFactorization()
201-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
202-
end
203-
# Not factorizable operator, default to only using A*x
204-
# IterativeSolvers is faster on CPU but not GPU-compatible
205-
else
206-
alg = KrylovJL_GMRES()
207-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
208-
end
107+
# Cache initialization when no algorithm was chosen: resolve the default
# and delegate.
# Fixes: `default_alg` was undefined (typo for `defaultalg`), and the call
# dropped `assumptions` even though every `defaultalg` method requires it.
function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
                       verbose::Bool, assumptions::OperatorAssumptions)
    init_cacheval(defaultalg(A, b, assumptions), A, b, u, Pl, Pr, maxiters, abstol,
                  reltol, verbose, assumptions)
end

0 commit comments

Comments
 (0)