Skip to content

Commit 44f7121

Browse files
committed
Updated interface for operators, addition of new and modified tests
1 parent 41108ec commit 44f7121

27 files changed

+5635
-28
lines changed

src/SciMLOperators.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,16 @@ A lazy operator algebra is also defined for `AbstractSciMLOperator`s.
2828
2929
# Interface
3030
31-
An `AbstractSciMLOperator` can be called like a function. This behaves
32-
like multiplication by the linear operator represented by the
33-
`AbstractSciMLOperator`. Possible signatures are
31+
An `AbstractSciMLOperator` can be called like a function in the following ways:
3432
35-
- `L(v, u, p, t)` for in-place operator evaluation
36-
- `v = L(u, p, t)` for out-of-place operator evaluation
33+
- `L(v, u, p, t)` - Out-of-place application where `v` is the action vector and `u` is the update vector
34+
- `L(w, v, u, p, t)` - In-place application where `w` is the destination, `v` is the action vector, and `u` is the update vector
35+
- `L(w, v, u, p, t, α, β)` - In-place application with scaling: `w = α*(L*v) + β*w`
3736
38-
Operator evaluation methods update its coefficients with `(u, p, t)`
39-
information using the `update_coefficients(!)` method. The methods
40-
are exported and can be called as follows:
37+
Operator state can be updated separately from application:
4138
42-
- `update_coefficients!(L, u, p, t)` for out-of-place operator update
43-
- `L = update_coefficients(L, u, p, t)` for in-place operator update
39+
- `update_coefficients!(L, u, p, t)` for in-place operator update
40+
- `L = update_coefficients(L, u, p, t)` for out-of-place operator update
4441
4542
SciMLOperators also overloads `Base.*`, `LinearAlgebra.mul!`,
4643
`LinearAlgebra.ldiv!` for operator evaluation without updating operator state.

src/basic.jl

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,27 @@ function LinearAlgebra.ldiv!(ii::IdentityOperator, u::AbstractVecOrMat)
6969
u
7070
end
7171

72+
# Out-of-place: v is action vector, u is update vector
73+
function (ii::IdentityOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
74+
@assert size(v, 1) == ii.len
75+
update_coefficients(ii, u, p, t; kwargs...)
76+
copy(v)
77+
end
78+
79+
# In-place: w is destination, v is action vector, u is update vector
80+
function (ii::IdentityOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
81+
@assert size(v, 1) == ii.len
82+
update_coefficients!(ii, u, p, t; kwargs...)
83+
copy!(w, v)
84+
end
85+
86+
# In-place with scaling: w = α*(ii*v) + β*w
87+
function (ii::IdentityOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
88+
@assert size(v, 1) == ii.len
89+
update_coefficients!(ii, u, p, t; kwargs...)
90+
mul!(w, I, v, α, β)
91+
end
92+
7293
# operator fusion with identity returns operator itself
7394
for op in (:*, :)
7495
@eval function Base.$op(ii::IdentityOperator, A::AbstractSciMLOperator)
@@ -146,6 +167,29 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat,
146167
lmul!(β, v)
147168
end
148169

170+
# Out-of-place: v is action vector, u is update vector
171+
function (nn::NullOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
172+
@assert size(v, 1) == nn.len
173+
update_coefficients(nn, u, p, t; kwargs...)
174+
zero(v)
175+
end
176+
177+
# In-place: w is destination, v is action vector, u is update vector
178+
function (nn::NullOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
179+
@assert size(v, 1) == nn.len
180+
update_coefficients!(nn, u, p, t; kwargs...)
181+
lmul!(false, w)
182+
w
183+
end
184+
185+
# In-place with scaling: w = α*(nn*v) + β*w
186+
function (nn::NullOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
187+
@assert size(v, 1) == nn.len
188+
update_coefficients!(nn, u, p, t; kwargs...)
189+
lmul!(β, w)
190+
w
191+
end
192+
149193
# operator fusion, composition
150194
for op in (:*, :)
151195
@eval function Base.$op(nn::NullOperator, A::AbstractSciMLOperator)
@@ -336,6 +380,43 @@ function LinearAlgebra.ldiv!(L::ScaledOperator, u::AbstractVecOrMat)
336380
ldiv!(L.L, u)
337381
end
338382

383+
# Out-of-place: v is action vector, u is update vector
384+
function (L::ScaledOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
385+
L = update_coefficients(L, u, p, t; kwargs...)
386+
if iszero(L.λ)
387+
return zero(v)
388+
else
389+
return L.λ * (L.L * v)
390+
end
391+
end
392+
393+
# In-place: w is destination, v is action vector, u is update vector
394+
function (L::ScaledOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
395+
update_coefficients!(L, u, p, t; kwargs...)
396+
if iszero(L.λ)
397+
lmul!(false, w)
398+
return w
399+
else
400+
a = convert(Number, L.λ)
401+
mul!(w, L.L, v, a, false)
402+
return w
403+
end
404+
end
405+
406+
# In-place with scaling: w = α*(L*v) + β*w
407+
function (L::ScaledOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
408+
update_coefficients!(L, u, p, t; kwargs...)
409+
if iszero(L.λ)
410+
lmul!(β, w)
411+
return w
412+
else
413+
a = convert(Number, L.λ * α)
414+
mul!(w, L.L, v, a, β)
415+
return w
416+
end
417+
end
418+
419+
339420
"""
340421
Lazy operator addition
341422
@@ -538,6 +619,35 @@ end
538619
v
539620
end
540621
end
622+
# Out-of-place: v is action vector, u is update vector
623+
function (L::AddedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
624+
L = update_coefficients(L, u, p, t; kwargs...)
625+
sum(op -> iszero(op) ? zero(v) : op(v, u, p, t; kwargs...), L.ops)
626+
end
627+
628+
# In-place: w is destination, v is action vector, u is update vector
629+
function (L::AddedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
630+
update_coefficients!(L, u, p, t; kwargs...)
631+
L.ops[1](w, v, u, p, t; kwargs...)
632+
for i in 2:length(L.ops)
633+
if !iszero(L.ops[i])
634+
L.ops[i](w, v, u, p, t, 1.0, 1.0; kwargs...)
635+
end
636+
end
637+
w
638+
end
639+
640+
# In-place with scaling: w = α*(L*v) + β*w
641+
function (L::AddedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
642+
update_coefficients!(L, u, p, t; kwargs...)
643+
lmul!(β, w)
644+
for op in L.ops
645+
if !iszero(op)
646+
op(w, v, u, p, t, α, 1.0; kwargs...)
647+
end
648+
end
649+
w
650+
end
541651

542652
"""
543653
Lazy operator composition
@@ -792,6 +902,41 @@ function LinearAlgebra.ldiv!(L::ComposedOperator, u::AbstractVecOrMat)
792902
u
793903
end
794904

905+
# Out-of-place: v is action vector, u is update vector
906+
function (L::ComposedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
907+
L = update_coefficients(L, u, p, t; kwargs...)
908+
result = v
909+
for op in reverse(L.ops)
910+
result = op(result, u, p, t; kwargs...)
911+
end
912+
result
913+
end
914+
915+
# In-place: w is destination, v is action vector, u is update vector
916+
function (L::ComposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
917+
update_coefficients!(L, u, p, t; kwargs...)
918+
@assert iscached(L) "Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first."
919+
920+
vecs = (w, L.cache[1:(end-1)]..., v)
921+
for i in reverse(1:length(L.ops))
922+
L.ops[i](vecs[i], vecs[i+1], u, p, t; kwargs...)
923+
end
924+
w
925+
end
926+
927+
# In-place with scaling: w = α*(L*v) + β*w
928+
function (L::ComposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
929+
update_coefficients!(L, u, p, t; kwargs...)
930+
@assert iscached(L) "Cache needs to be set up for ComposedOperator. Call cache_operator(L, u) first."
931+
932+
cache = L.cache[end]
933+
copy!(cache, w)
934+
935+
L(w, v, u, p, t; kwargs...)
936+
lmul!(α, w)
937+
axpy!(β, cache, w)
938+
end
939+
795940
"""
796941
Lazy Operator Inverse
797942
"""
@@ -909,4 +1054,29 @@ function LinearAlgebra.ldiv!(L::InvertedOperator, u::AbstractVecOrMat)
9091054
copy!(L.cache, u)
9101055
mul!(u, L.L, L.cache)
9111056
end
1057+
1058+
# Out-of-place: v is action vector, u is update vector
1059+
function (L::InvertedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
1060+
L = update_coefficients(L, u, p, t; kwargs...)
1061+
L.L \ v
1062+
end
1063+
1064+
# In-place: w is destination, v is action vector, u is update vector
1065+
function (L::InvertedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
1066+
update_coefficients!(L, u, p, t; kwargs...)
1067+
ldiv!(w, L.L, v)
1068+
w
1069+
end
1070+
1071+
# In-place with scaling: w = α*(L*v) + β*w
1072+
function (L::InvertedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
1073+
update_coefficients!(L, u, p, t; kwargs...)
1074+
@assert iscached(L) "Cache needs to be set up for InvertedOperator. Call cache_operator(L, u) first."
1075+
1076+
copy!(L.cache, w)
1077+
ldiv!(w, L.L, v)
1078+
lmul!(α, w)
1079+
axpy!(β, L.cache, w)
1080+
w
1081+
end
9121082
#

src/batch.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,25 @@ function LinearAlgebra.ldiv!(L::BatchedDiagonalOperator, u::AbstractVecOrMat)
153153

154154
u
155155
end
156+
157+
function (L::BatchedDiagonalOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
158+
L = update_coefficients(L, u, p, t; kwargs...)
159+
L.diag .* v
160+
end
161+
162+
function (L::BatchedDiagonalOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
163+
update_coefficients!(L, u, p, t; kwargs...)
164+
w .= L.diag .* v
165+
return w
166+
end
167+
168+
function (L::BatchedDiagonalOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
169+
update_coefficients!(L, u, p, t; kwargs...)
170+
if β == 0
171+
w .= α .* (L.diag .* v)
172+
else
173+
w .= α .* (L.diag .* v) .+ β .* w
174+
end
175+
return w
176+
end
156177
#

src/interface.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ $(UPDATE_COEFFS_WARNING)
4343
# Example
4444
4545
```
46-
using SciMLOperator
46+
using SciMLOperators
4747
4848
mat_update_func = (A, u, p, t; scale = 1.0) -> p * p' * scale * t
4949
@@ -56,8 +56,12 @@ u = rand(4)
5656
p = rand(4)
5757
t = 1.0
5858
59+
# Update the operator and apply it to `u`
5960
L = update_coefficients(L, u, p, t; scale = 2.0)
60-
L * u
61+
result = L * u
62+
63+
# Or use the interface which separrates the update from the application
64+
result = L(u, u, p, t; scale = 2.0)
6165
```
6266
6367
"""
@@ -108,6 +112,7 @@ end
108112
###
109113
# operator evaluation interface
110114
###
115+
111116
# Out-of-place: v is action vector, u is update vector
112117
function (L::AbstractSciMLOperator)(v, u, p, t; kwargs...)
113118
update_coefficients(L, u, p, t; kwargs...) * v

src/left.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,81 @@ for (op, LType, VType) in ((:adjoint, :AdjointOperator, :AbstractAdjointVecOrMat
157157
u
158158
end
159159
end
160+
161+
162+
# For AdjointOperator
163+
# Out-of-place: v is action vector, u is update vector
164+
function (L::AdjointOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
165+
# Adjoint operator applied to v means L.L' * v
166+
# For matrices: (A')v = (v'A)'
167+
# This means we need to compute L.L(v', u, p, t)'
168+
# Reshape v to match the adjoint operator's expected size
169+
adjv = reshape(v', size(L.L, 1), :)
170+
result = L.L(adjv, u, p, t; kwargs...)
171+
return reshape(result', size(L, 1), :)
172+
end
173+
174+
# In-place: w is destination, v is action vector, u is update vector
175+
function (L::AdjointOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
176+
# Need temporary storage for adjoint operations
177+
temp_v = reshape(v', size(L.L, 1), :)
178+
temp_w = similar(temp_v)
179+
180+
# Apply the internal operator
181+
L.L(temp_w, temp_v, u, p, t; kwargs...)
182+
183+
# Copy back to w with adjoint
184+
w .= reshape(temp_w', size(L, 1), :)
185+
return w
186+
end
187+
188+
# In-place with scaling: w = α*(L*v) + β*w
189+
function (L::AdjointOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
190+
# Handle scaling of existing w
191+
if β != 1.0
192+
lmul!(β, w)
193+
end
194+
195+
# Need temporary storage for adjoint operations
196+
temp_v = reshape(v', size(L.L, 1), :)
197+
temp_w = similar(temp_v)
198+
199+
# Apply the internal operator
200+
L.L(temp_w, temp_v, u, p, t, 1.0, 0.0; kwargs...)
201+
202+
# Add α * result' to w
203+
w .+= α .* reshape(temp_w', size(L, 1), :)
204+
return w
205+
end
206+
207+
# For TransposedOperator
208+
function (L::TransposedOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
209+
transv = reshape(v', size(L.L, 1), :)
210+
result = L.L(transv, u, p, t; kwargs...)
211+
return reshape(result', size(L, 1), :)
212+
end
213+
214+
function (L::TransposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
215+
temp_v = reshape(v', size(L.L, 1), :)
216+
temp_w = similar(temp_v)
217+
218+
L.L(temp_w, temp_v, u, p, t; kwargs...)
219+
220+
w .= reshape(temp_w', size(L, 1), :)
221+
return w
222+
end
223+
224+
function (L::TransposedOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
225+
if β != 1.0
226+
lmul!(β, w)
227+
end
228+
229+
temp_v = reshape(v', size(L.L, 1), :)
230+
temp_w = similar(temp_v)
231+
232+
L.L(temp_w, temp_v, u, p, t, 1.0, 0.0; kwargs...)
233+
234+
w .+= α .* reshape(temp_w', size(L, 1), :)
235+
return w
236+
end
160237
#

0 commit comments

Comments
 (0)