Skip to content

Commit 155e0ec

Browse files
Merge pull request #270 from vpuri3/scimloperators
Support SciMLOperators in LinearSolve
2 parents 8e70428 + c763a17 commit 155e0ec

File tree

11 files changed

+152
-62
lines changed

11 files changed

+152
-62
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1717
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
1818
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1919
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
20+
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2021
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2122
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
2223
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -43,7 +44,8 @@ KrylovKit = "0.5, 0.6"
4344
Preferences = "1"
4445
RecursiveFactorization = "0.2.8"
4546
Reexport = "1"
46-
SciMLBase = "1.68"
47+
SciMLBase = "1.82"
48+
SciMLOperators = "0.1.19"
4749
Setfield = "0.7, 0.8, 1"
4850
SnoopPrecompile = "1"
4951
Sparspak = "0.3.6"

ext/LinearSolveHYPRE.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using HYPRE.LibHYPRE: HYPRE_Complex
44
using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
55
using IterativeSolvers: Identity
66
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
7-
OperatorAssumptions, default_tol, init_cacheval, issquare, set_cacheval
7+
OperatorAssumptions, default_tol, init_cacheval, __issquare, set_cacheval
88
using SciMLBase: LinearProblem, SciMLBase
99
using UnPack: @unpack
1010
using Setfield: @set!
@@ -82,7 +82,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
8282

8383
cache = LinearCache{
8484
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
85-
typeof(Pl), typeof(Pr), typeof(reltol), issquare(assumptions)
85+
typeof(Pl), typeof(Pr), typeof(reltol), __issquare(assumptions)
8686
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
8787
maxiters,
8888
verbose, assumptions)

lib/LinearSolveCUDA/src/LinearSolveCUDA.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module LinearSolveCUDA
22

33
using CUDA, LinearAlgebra, LinearSolve, SciMLBase
4+
using SciMLBase: AbstractSciMLOperator
45

56
struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization end
67

@@ -17,12 +18,13 @@ function SciMLBase.solve(cache::LinearSolve.LinearCache, alg::CudaOffloadFactori
1718
end
1819

1920
function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u)
20-
A isa Union{AbstractMatrix, SciMLBase.AbstractDiffEqOperator} ||
21+
A isa Union{AbstractMatrix, AbstractSciMLOperator} ||
2122
error("LU is not defined for $(typeof(A))")
2223

23-
if A isa SciMLBase.AbstractDiffEqOperator
24+
if A isa Union{MatrixOperator, DiffEqArrayOperator}
2425
A = A.A
2526
end
27+
2628
fact = qr(CUDA.CuArray(A))
2729
return fact
2830
end

src/LinearSolve.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ end
66
using ArrayInterfaceCore
77
using RecursiveFactorization
88
using Base: cache_dependencies, Bool
9-
import Base: eltype, adjoint, inv
109
using LinearAlgebra
1110
using IterativeSolvers: Identity
1211
using SparseArrays
13-
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
12+
using SciMLBase: AbstractLinearAlgorithm
13+
using SciMLOperators
14+
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
1415
using Setfield
1516
using UnPack
1617
using SuiteSparse
@@ -41,6 +42,15 @@ needs_concrete_A(alg::AbstractFactorization) = true
4142
needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
4243
needs_concrete_A(alg::AbstractSolveFunction) = false
4344

45+
# Util
46+
47+
_isidentity_struct(A) = false
48+
_isidentity_struct::Number) = isone(λ)
49+
_isidentity_struct(A::UniformScaling) = isone(A.λ)
50+
_isidentity_struct(::IterativeSolvers.Identity) = true
51+
_isidentity_struct(::SciMLOperators.IdentityOperator) = true
52+
_isidentity_struct(::SciMLBase.DiffEqIdentity) = true
53+
4454
# Code
4555

4656
const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
@@ -97,7 +107,7 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
97107
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
98108
SparspakFactorization, DiagonalFactorization
99109

100-
export LinearSolveFunction
110+
export LinearSolveFunction, DirectLdiv!
101111

102112
export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
103113
KrylovJL_BICGSTAB, KrylovJL_LSMR, KrylovJL_CRAIGMR,

src/common.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
struct OperatorAssumptions{issquare} end
1+
struct OperatorAssumptions{issq} end
22
function OperatorAssumptions(issquare = nothing)
3-
issquare = something(_unwrap_val(issquare), Nothing)
4-
OperatorAssumptions{issquare}()
3+
issq = something(_unwrap_val(issquare), Nothing)
4+
OperatorAssumptions{issq}()
55
end
6-
issquare(::OperatorAssumptions{issq}) where {issq} = issq
6+
__issquare(::OperatorAssumptions{issq}) where {issq} = issq
77

8-
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
8+
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
99
A::TA
1010
b::Tb
1111
u::Tu
@@ -19,7 +19,7 @@ struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
1919
reltol::Ttol
2020
maxiters::Int
2121
verbose::Bool
22-
assumptions::OperatorAssumptions{issquare}
22+
assumptions::OperatorAssumptions{issq}
2323
end
2424

2525
"""
@@ -92,9 +92,9 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
9292
reltol = default_tol(eltype(prob.A)),
9393
maxiters::Int = length(prob.b),
9494
verbose::Bool = false,
95-
Pl = Identity(),
96-
Pr = Identity(),
97-
assumptions = OperatorAssumptions(),
95+
Pl = IdentityOperator{size(prob.A, 1)}(),
96+
Pr = IdentityOperator{size(prob.A, 2)}(),
97+
assumptions = OperatorAssumptions(Val(issquare(prob.A))),
9898
kwargs...)
9999
@unpack A, b, u0, p = prob
100100

@@ -129,7 +129,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
129129
typeof(Pl),
130130
typeof(Pr),
131131
typeof(reltol),
132-
issquare(assumptions)
132+
__issquare(assumptions)
133133
}(A,
134134
b,
135135
u0,

src/default.jl

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,30 @@
22
# For SciML algorithms already using `defaultalg`, all assume square matrix.
33
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(Val(true)))
44

5-
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions)
5+
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
6+
assumptions::OperatorAssumptions)
67
defaultalg(A.A, b, assumptions)
78
end
89

910
# Ambiguity handling
10-
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{nothing})
11+
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
12+
assumptions::OperatorAssumptions{nothing})
1113
defaultalg(A.A, b, assumptions)
1214
end
1315

14-
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{false})
16+
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
17+
assumptions::OperatorAssumptions{false})
1518
defaultalg(A.A, b, assumptions)
1619
end
1720

18-
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{true})
21+
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
22+
assumptions::OperatorAssumptions{true})
1923
defaultalg(A.A, b, assumptions)
2024
end
2125

2226
function defaultalg(A, b, ::OperatorAssumptions{Nothing})
23-
issquare = size(A, 1) == size(A, 2)
24-
defaultalg(A, b, OperatorAssumptions(Val(issquare)))
27+
issq = issquare(A)
28+
defaultalg(A, b, OperatorAssumptions(Val(issq)))
2529
end
2630

2731
function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{true})
@@ -33,10 +37,13 @@ end
3337
function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{true})
3438
GenericFactorization(; fact_alg = ldlt!)
3539
end
36-
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{true})
37-
DiagonalFactorization()
40+
function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions{true})
41+
DirectLdiv!()
42+
end
43+
function defaultalg(A::Factorization, b, ::OperatorAssumptions{true})
44+
DirectLdiv!()
3845
end
39-
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{false})
46+
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{true})
4047
DiagonalFactorization()
4148
end
4249
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{Nothing})
@@ -75,18 +82,26 @@ function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{
7582
end
7683
end
7784

78-
function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
79-
assumptions::OperatorAssumptions)
85+
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
86+
assumptions::OperatorAssumptions{true})
87+
if has_ldiv!(A)
88+
return DirectLdiv!()
89+
end
90+
8091
KrylovJL_GMRES()
8192
end
8293

8394
# Ambiguity handling
84-
function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
95+
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
8596
assumptions::OperatorAssumptions{Nothing})
97+
if has_ldiv!(A)
98+
return DirectLdiv!()
99+
end
100+
86101
KrylovJL_GMRES()
87102
end
88103

89-
function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
104+
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
90105
assumptions::OperatorAssumptions{false})
91106
m, n = size(A)
92107
if m < n

src/iterative_wrappers.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,9 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
150150
M = cache.Pl
151151
N = cache.Pr
152152

153-
M = (M === Identity()) ? I : InvPreconditioner(M)
154-
N = (N === Identity()) ? I : InvPreconditioner(N)
153+
# use no-op preconditioner for Krylov.jl (LinearAlgebra.I) when M/N is identity
154+
M = _isidentity_struct(M) ? I : M
155+
N = _isidentity_struct(M) ? I : N
155156

156157
atol = float(cache.abstol)
157158
rtol = float(cache.reltol)
@@ -160,7 +161,7 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
160161

161162
args = (cache.cacheval, cache.A, cache.b)
162163
kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
163-
history = true, alg.kwargs...)
164+
ldiv = true, history = true, alg.kwargs...)
164165

165166
if cache.cacheval isa Krylov.CgSolver
166167
N !== I &&
@@ -234,15 +235,15 @@ function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters::Int,
234235
alg.kwargs...)
235236

236237
iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
237-
Pr !== Identity() &&
238+
!_isidentity_struct(Pr) &&
238239
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
239240
alg.generate_iterator(u, A, b, Pl;
240241
kwargs...)
241242
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
242243
alg.generate_iterator(u, A, b; Pl = Pl, Pr = Pr, restart = restart,
243244
kwargs...)
244245
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
245-
Pr !== Identity() &&
246+
!_isidentity_struct(Pr) &&
246247
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
247248
alg.generate_iterator(u, A, b, alg.args...; Pl = Pl,
248249
abstol = abstol, reltol = reltol,

src/solve_function.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,12 @@ function SciMLBase.solve(cache::LinearCache, alg::LinearSolveFunction,
1313

1414
return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
1515
end
16+
17+
struct DirectLdiv! <: AbstractSolveFunction end
18+
19+
function SciMLBase.solve(cache::LinearCache, alg::DirectLdiv!, args...; kwargs...)
20+
@unpack A, b, u = cache
21+
ldiv!(u, A, b)
22+
23+
return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
24+
end

0 commit comments

Comments
 (0)