Skip to content

Commit e5c6b7c

Browse files
authored
Merge pull request #179 from vpuri3/fwdbwd
Combine a forward operator with a backward operator for fast evaluaiton
2 parents 11447bf + d11863a commit e5c6b7c

File tree

7 files changed

+37
-42
lines changed

7 files changed

+37
-42
lines changed

src/func.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ function iscached(L::FunctionOperator)
225225
end
226226

227227
function cache_self(L::FunctionOperator, u::AbstractVecOrMat, v::AbstractVecOrMat)
228-
!L.traits.ifcache && @warn "you are allocating cache for a FunctionOperator for which ifcache = false."
228+
!L.traits.ifcache && @warn """Cache is being allocated for a FunctionOperator
229+
created with kwarg ifcache = false."""
229230
@set! L.cache = zero.((u, v))
230231
L
231232
end

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ has_expmv!(L::AbstractSciMLOperator) = false # expmv!(v, L, t, u)
126126
has_expmv(L::AbstractSciMLOperator) = false # v = exp(L, t, u)
127127
has_exp(L::AbstractSciMLOperator) = islinear(L)
128128
has_mul(L::AbstractSciMLOperator) = true # du = L*u
129-
has_mul!(L::AbstractSciMLOperator) = false # mul!(du, L, u)
129+
has_mul!(L::AbstractSciMLOperator) = true # mul!(du, L, u)
130130
has_ldiv(L::AbstractSciMLOperator) = false # du = L\u
131131
has_ldiv!(L::AbstractSciMLOperator) = false # ldiv!(du, L, u)
132132

src/matrix.jl

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -158,25 +158,28 @@ const AdjointFact = isdefined(LinearAlgebra, :AdjointFactorization) ? LinearAlge
158158
const TransposeFact = isdefined(LinearAlgebra, :TransposeFactorization) ? LinearAlgebra.TransposeFactorization : Transpose
159159

160160
"""
161-
InvertibleOperator(F)
161+
InvertibleOperator(L, F)
162162
163-
Like MatrixOperator, but stores a Factorization instead.
164-
165-
Supports left division and `ldiv!` when applied to an array.
163+
Stores an operator and its factorization (or inverse operator).
164+
Supports left division and `ldiv!` via `F`, and operator application
165+
via `L`.
166166
"""
167-
struct InvertibleOperator{T,FType} <: AbstractSciMLOperator{T}
168-
F::FType
167+
struct InvertibleOperator{T,LT,FT} <: AbstractSciMLOperator{T}
168+
L::LT
169+
F::FT
169170

170-
function InvertibleOperator(F)
171+
function InvertibleOperator(L, F)
171172
@assert has_ldiv(F) | has_ldiv!(F) "$F is not invertible"
172-
new{eltype(F),typeof(F)}(F)
173+
T = promote_type(eltype(L), eltype(F))
174+
175+
new{T,typeof(L),typeof(F)}(L, F)
173176
end
174177
end
175178

176179
# constructor
177180
function LinearAlgebra.factorize(L::AbstractSciMLOperator)
178181
fact = factorize(convert(AbstractMatrix, L))
179-
InvertibleOperator(fact)
182+
InvertibleOperator(L, fact)
180183
end
181184

182185
for fact in (
@@ -190,40 +193,36 @@ for fact in (
190193
)
191194

192195
@eval LinearAlgebra.$fact(L::AbstractSciMLOperator, args...) =
193-
InvertibleOperator($fact(convert(AbstractMatrix, L), args...))
196+
InvertibleOperator(L, $fact(convert(AbstractMatrix, L), args...))
194197
@eval LinearAlgebra.$fact(L::AbstractSciMLOperator; kwargs...) =
195-
InvertibleOperator($fact(convert(AbstractMatrix, L); kwargs...))
198+
InvertibleOperator(L, $fact(convert(AbstractMatrix, L); kwargs...))
196199
end
197200

198201
function Base.convert(::Type{<:Factorization}, L::InvertibleOperator{T,<:Factorization}) where{T}
199202
L.F
200203
end
201204

202-
Base.convert(::Type{AbstractMatrix}, L::InvertibleOperator) =
203-
convert(AbstractMatrix, L.F)
204-
Base.convert(::Type{AbstractMatrix}, L::InvertibleOperator{<:Any,<:Union{Adjoint,AdjointFact}}) =
205-
adjoint(convert(AbstractMatrix, adjoint(L.F)))
206-
Base.convert(::Type{AbstractMatrix}, L::InvertibleOperator{<:Any,<:Union{Transpose,TransposeFact}}) =
207-
transpose(convert(AbstractMatrix, transpose(L.F)))
205+
Base.convert(::Type{AbstractMatrix}, L::InvertibleOperator) = convert(AbstractMatrix, L.L)
208206

209207
# traits
210-
Base.size(L::InvertibleOperator) = size(L.F)
211-
Base.transpose(L::InvertibleOperator) = InvertibleOperator(transpose(L.F))
212-
Base.adjoint(L::InvertibleOperator) = InvertibleOperator(L.F')
213-
Base.conj(L::InvertibleOperator) = InvertibleOperator(conj(L.F))
214-
Base.resize!(L::InvertibleOperator, n::Integer) = (resize!(L.F, n); L)
208+
Base.size(L::InvertibleOperator) = size(L.L)
209+
Base.transpose(L::InvertibleOperator) = InvertibleOperator(transpose(L.L), transpose(L.F))
210+
Base.adjoint(L::InvertibleOperator) = InvertibleOperator(L.L', L.F')
211+
Base.conj(L::InvertibleOperator) = InvertibleOperator(conj(L.L), conj(L.F))
212+
Base.resize!(L::InvertibleOperator, n::Integer) = (resize!(L.L, n); resize!(L.F, n); L)
215213
LinearAlgebra.opnorm(L::InvertibleOperator{T}, p=2) where{T} = one(T) / opnorm(L.F)
216214
LinearAlgebra.issuccess(L::InvertibleOperator) = issuccess(L.F)
217215

218216
function update_coefficients(L::InvertibleOperator, u, p, t)
217+
@set! L.L = update_coefficients(L.L, u, p, t)
219218
@set! L.F = update_coefficients(L.F, u, p, t)
220219
L
221220
end
222221

223-
getops(L::InvertibleOperator) = (L.F,)
224-
islinear(L::InvertibleOperator) = islinear(L.F)
222+
getops(L::InvertibleOperator) = (L.L, L.F,)
223+
islinear(L::InvertibleOperator) = islinear(L.L)
225224

226-
@forward InvertibleOperator.F (
225+
@forward InvertibleOperator.L (
227226
# LinearAlgebra
228227
LinearAlgebra.issymmetric,
229228
LinearAlgebra.ishermitian,
@@ -234,15 +233,16 @@ islinear(L::InvertibleOperator) = islinear(L.F)
234233
has_adjoint,
235234
has_mul,
236235
has_mul!,
237-
has_ldiv,
238-
has_ldiv!,
239236
)
240237

238+
has_ldiv(L::InvertibleOperator) = has_mul(L.F)
239+
has_ldiv!(L::InvertibleOperator) = has_ldiv!(L.F)
240+
241241
# operator application
242-
Base.:*(L::InvertibleOperator, x::AbstractVecOrMat) = L.F * x
242+
Base.:*(L::InvertibleOperator, x::AbstractVecOrMat) = L.L * x
243243
Base.:\(L::InvertibleOperator, x::AbstractVecOrMat) = L.F \ x
244-
LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOrMat) = mul!(v, L.F, u)
245-
LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOrMat,α, β) = mul!(v, L.F, u, α, β)
244+
LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOrMat) = mul!(v, L.L, u)
245+
LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOrMat,α, β) = mul!(v, L.L, u, α, β)
246246
LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(v, L.F, u)
247247
LinearAlgebra.ldiv!(L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(L.F, u)
248248

test/basic.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,8 @@ end
222222
v=copy(u); @test ldiv!(op, u) (A * B * C) \ v
223223

224224
# Test caching of composed operator when inner ops do not support Base.:*
225-
# See issue #129
225+
# ComposedOperator caching was modified in PR # 174
226226
inner_op = qr(MatrixOperator(rand(N, N)))
227-
# We use the QR factorization of a non-square matrix, which does
228-
# not support * as verified below.
229-
@test !has_mul(inner_op)
230-
@test has_ldiv(inner_op)
231-
@test_throws MethodError inner_op * u
232-
# We can now test that caching does not rely on matmul
233227
op = inner_op * factorize(MatrixOperator(rand(N, N)))
234228
@test !iscached(op)
235229
@test_nowarn op = cache_operator(op, rand(N))

test/func.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using SciMLOperators, LinearAlgebra
33
using Random
44

5-
using SciMLOperators: InvertibleOperator,
5+
using SciMLOperators:
66

77
Random.seed!(0)
88
N = 8

test/total.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ end
137137
(SciMLOperators.AdjointOperator(F), SciMLOperators.AdjointOperator),
138138
(SciMLOperators.TransposedOperator(F), SciMLOperators.TransposedOperator),
139139
(SciMLOperators.InvertedOperator(F), SciMLOperators.InvertedOperator),
140-
(SciMLOperators.InvertibleOperator(F), SciMLOperators.InvertibleOperator),
140+
(SciMLOperators.InvertibleOperator(F, F), SciMLOperators.InvertibleOperator),
141141
)
142142

143143
L = deepcopy(L)

test/zygote.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ for (op_type, A) in
2929
(AffineOperator, AffineOperator(rand(N,N), rand(N,N), rand(N,K))),
3030
(ScaledOperator, rand() * MatrixOperator(rand(N,N))),
3131
(InvertedOperator, InvertedOperator(rand(N,N) |> MatrixOperator)),
32-
(InvertibleOperator, InvertibleOperator(rand(N,N) |> MatrixOperator)),
32+
(InvertibleOperator, InvertibleOperator(MatrixOperator(M), MatrixOperator(inv(M)))),
3333
(BatchedDiagonalOperator, DiagonalOperator(rand(N,K))),
3434
(AddedOperator, MatrixOperator(rand(N,N)) + MatrixOperator(rand(N,N))),
3535
(ComposedOperator, MatrixOperator(rand(N,N)) * MatrixOperator(rand(N,N))),

0 commit comments

Comments
 (0)