Skip to content

Commit 369fc52

Browse files
Merge pull request #143 from gaurav-arya/ag-kwargs
Propagate kwargs through update_coefficients!
2 parents 29c31a4 + 3a665ca commit 369fc52

File tree

10 files changed

+338
-166
lines changed

10 files changed

+338
-166
lines changed

docs/src/interface.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,26 @@ the proof to affine operators, so then ``exp(A*t)*v`` operations via Krylov meth
5555
affine as well, and all sorts of things. Thus affine operators have no matrix representation but they
5656
are still compatible with essentially any Krylov method which would otherwise be compatible with
5757
matrix-free representations, hence their support in the SciMLOperators interface.
58+
59+
## Note about keyword arguments to `update_coefficients!`
60+
61+
In rare cases, an operator may be used in a context where additional state is expected to be provided
62+
to `update_coefficients!` beyond `u`, `p`, and `t`. In this case, the operator may accept this additional
63+
state through arbitrary keyword arguments to `update_coefficients!`. When the caller provides these, they will be recursively propagated downwards through composed operators just like `u`, `p`, and `t`, and provided to the operator.
64+
For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwargs` argument (by default, none are passed).
65+
66+
In the below example, we create an operator that gleefully ignores `u`, `p`, and `t` and uses its own special scaling.
67+
```@example
68+
using SciMLOperators
69+
70+
γ = ScalarOperator(0.0; update_func=(a, u, p, t; my_special_scaling) -> my_special_scaling,
71+
accepted_kwargs=(:my_special_scaling,))
72+
73+
# Update coefficients, then apply operator
74+
update_coefficients!(γ, nothing, nothing, nothing; my_special_scaling=7.0)
75+
@show γ * [2.0]
76+
77+
# Use operator application form
78+
@show γ([2.0], nothing, nothing; my_special_scaling = 5.0)
79+
nothing # hide
80+
```

src/batch.jl

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
#
22
"""
3-
BatchedDiagonalOperator(diag, [; update_func])
3+
BatchedDiagonalOperator(diag; update_func, update_func!, accepted_kwargs)
44
55
Represents a time-dependent elementwise scaling (diagonal-scaling) operation.
66
Acts on `AbstractArray`s of the same size as `diag`. The update function is called
77
by `update_coefficients!` and is assumed to have the following signature:
88
9-
update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
9+
update_func(diag::AbstractArray, u, p, t; <accepted kwarg fields>) -> [modifies diag]
1010
"""
1111
struct BatchedDiagonalOperator{T,D,F,F!} <: AbstractSciMLOperator{T}
1212
diag::D
1313
update_func::F
1414
update_func!::F!
1515

1616
function BatchedDiagonalOperator(diag::AbstractArray, update_func, update_func!)
17+
1718
new{
1819
eltype(diag),
1920
typeof(diag),
@@ -25,15 +26,16 @@ struct BatchedDiagonalOperator{T,D,F,F!} <: AbstractSciMLOperator{T}
2526
end
2627
end
2728

28-
function BatchedDiagonalOperator(diag::AbstractArray;
29-
update_func = DEFAULT_UPDATE_FUNC,
30-
update_func! = DEFAULT_UPDATE_FUNC)
31-
BatchedDiagonalOperator(diag, update_func, update_func!)
32-
end
29+
function DiagonalOperator(u::AbstractArray;
30+
update_func = DEFAULT_UPDATE_FUNC,
31+
update_func! = DEFAULT_UPDATE_FUNC,
32+
accepted_kwargs = nothing
33+
)
34+
35+
update_func = preprocess_update_func(update_func , accepted_kwargs)
36+
update_func! = preprocess_update_func(update_func!, accepted_kwargs)
3337

34-
function DiagonalOperator(u::AbstractArray; update_func = DEFAULT_UPDATE_FUNC,
35-
update_func! = DEFAULT_UPDATE_FUNC)
36-
BatchedDiagonalOperator(u; update_func = update_func, update_func! = update_func!)
38+
BatchedDiagonalOperator(u, update_func, update_func!)
3739
end
3840

3941
# traits
@@ -46,38 +48,39 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
4648
update_func = if isreal(L)
4749
L.update_func
4850
else
49-
(L,u,p,t) -> conj(L.update_func(conj(L.diag),u,p,t))
51+
(L,u,p,t; kwargs...) -> conj(L.update_func(conj(L.diag),u,p,t; kwargs...))
5052
end
5153
BatchedDiagonalOperator(diag; update_func=update_func)
5254
end
5355

54-
function update_coefficients(L::BatchedDiagonalOperator,u,p,t)
55-
@set! L.diag = L.update_func(L.diag,u,p,t)
56+
LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true
57+
function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
58+
if isreal(L)
59+
true
60+
else
61+
vec(L.diag) |> Diagonal |> ishermitian
62+
end
63+
end
64+
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))
65+
66+
function update_coefficients(L::BatchedDiagonalOperator,u ,p, t; kwargs...)
67+
@set! L.diag = L.update_func(L.diag, u, p, t; kwargs...)
68+
end
69+
70+
function update_coefficients!(L::BatchedDiagonalOperator, u, p, t; kwargs...)
71+
L.update_func!(L.diag, u, p, t; kwargs...)
5672
end
57-
update_coefficients!(L::BatchedDiagonalOperator,u,p,t) = (L.update_func!(L.diag,u,p,t); L)
5873

5974
getops(L::BatchedDiagonalOperator) = (L.diag,)
6075

6176
function isconstant(L::BatchedDiagonalOperator)
62-
L.update_func == L.update_func! == DEFAULT_UPDATE_FUNC
77+
update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!)
6378
end
6479
islinear(::BatchedDiagonalOperator) = true
6580
has_adjoint(L::BatchedDiagonalOperator) = true
6681
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
6782
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)
6883

69-
LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true
70-
function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
71-
if isreal(L)
72-
true
73-
else
74-
d = vec(L.diag)
75-
D = Diagonal(d)
76-
ishermitian(d)
77-
end
78-
end
79-
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))
80-
8184
# operator application
8285
Base.:*(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .* u
8386
Base.:\(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .\ u

src/func.jl

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ function FunctionOperator(op,
8484
FunctionOperator(op, input, output; kwargs...)
8585
end
8686

87+
# TODO: document constructor and revisit design as needed (e.g. for "accepted_kwargs")
8788
function FunctionOperator(op,
8889
input::AbstractVecOrMat,
8990
output::AbstractVecOrMat = input;
@@ -101,6 +102,7 @@ function FunctionOperator(op,
101102

102103
p=nothing,
103104
t::Union{Number,Nothing}=nothing,
105+
accepted_kwargs::NTuple{N,Symbol} = (),
104106

105107
ifcache::Bool = true,
106108

@@ -111,7 +113,7 @@ function FunctionOperator(op,
111113
issymmetric::Bool = false,
112114
ishermitian::Bool = false,
113115
isposdef::Bool = false,
114-
)
116+
) where{N}
115117

116118
# store eltype of input/output for caching with ComposedOperator.
117119
eltypes = eltype.((input, output))
@@ -181,6 +183,8 @@ function FunctionOperator(op,
181183
T = T,
182184
size = sz,
183185
eltypes = eltypes,
186+
accepted_kwargs = accepted_kwargs,
187+
kwargs = Dict{Symbol, Any}(),
184188
)
185189

186190
L = FunctionOperator(
@@ -191,7 +195,7 @@ function FunctionOperator(op,
191195
traits,
192196
p,
193197
t,
194-
cache,
198+
cache
195199
)
196200

197201
if ifcache & isnothing(L.cache)
@@ -201,36 +205,40 @@ function FunctionOperator(op,
201205
L
202206
end
203207

204-
function update_coefficients(L::FunctionOperator, u, p, t)
205-
206-
if isconstant(L)
207-
return L
208-
end
209-
210-
@set! L.op = update_coefficients(L.op, u, p, t)
211-
@set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t)
212-
@set! L.op_inverse = update_coefficients(L.op_inverse, u, p, t)
213-
@set! L.op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t)
208+
function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
214209

210+
# update p, t
215211
@set! L.p = p
216212
@set! L.t = t
217213

218-
L
219-
end
214+
# filter and update kwargs
215+
filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs)
216+
@set! L.traits.kwargs = Dict{Symbol, Any}(filtered_kwargs)
220217

221-
function update_coefficients!(L::FunctionOperator, u, p, t)
218+
isconstant(L) && return L
222219

223-
if isconstant(L)
224-
return L
225-
end
220+
@set! L.op = update_coefficients(L.op, u, p, t; filtered_kwargs...)
221+
@set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t; filtered_kwargs...)
222+
@set! L.op_inverse = update_coefficients(L.op_inverse, u, p, t; filtered_kwargs...)
223+
@set! L.op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t; filtered_kwargs...)
224+
end
226225

227-
for op in getops(L)
228-
update_coefficients!(op, u, p, t)
229-
end
226+
function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
230227

228+
# update p, t
231229
L.p = p
232230
L.t = t
233231

232+
# filter and update kwargs
233+
filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs)
234+
L.traits = (; L.traits..., kwargs = Dict{Symbol, Any}(filtered_kwargs))
235+
236+
isconstant(L) && return
237+
238+
for op in getops(L)
239+
update_coefficients!(op, u, p, t; filtered_kwargs...)
240+
end
241+
234242
L
235243
end
236244

@@ -267,9 +275,6 @@ function Base.adjoint(L::FunctionOperator)
267275
@set! traits.size = reverse(size(L))
268276
@set! traits.eltypes = reverse(traits.eltypes)
269277

270-
p = L.p
271-
t = L.t
272-
273278
cache = if iscached(L)
274279
cache = reverse(L.cache)
275280
else
@@ -281,8 +286,8 @@ function Base.adjoint(L::FunctionOperator)
281286
op_inverse,
282287
op_adjoint_inverse,
283288
traits,
284-
p,
285-
t,
289+
L.p,
290+
L.t,
286291
cache,
287292
)
288293
end
@@ -310,9 +315,6 @@ function Base.inv(L::FunctionOperator)
310315
(p::Real) -> 1 / traits.opnorm(p)
311316
end
312317

313-
p = L.p
314-
t = L.t
315-
316318
cache = if iscached(L)
317319
cache = reverse(L.cache)
318320
else
@@ -324,8 +326,8 @@ function Base.inv(L::FunctionOperator)
324326
op_inverse,
325327
op_adjoint_inverse,
326328
traits,
327-
p,
328-
t,
329+
L.p,
330+
L.t,
329331
cache,
330332
)
331333
end
@@ -353,8 +355,8 @@ function LinearAlgebra.opnorm(L::FunctionOperator, p)
353355
argument. E.g., `(p::Real) -> p == Inf ? 100 : error("only Inf norm is
354356
defined")`
355357
""")
356-
opn = L.opnorm
357-
return opn isa Number ? opn : L.opnorm(p)
358+
opn = L.traits.opnorm
359+
return opn isa Number ? opn : L.traits.opnorm(p)
358360
end
359361
LinearAlgebra.issymmetric(L::FunctionOperator) = L.traits.issymmetric
360362
LinearAlgebra.ishermitian(L::FunctionOperator) = L.traits.ishermitian
@@ -373,31 +375,36 @@ end
373375
islinear(L::FunctionOperator) = L.traits.islinear
374376
isconstant(L::FunctionOperator) = L.traits.isconstant
375377
has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing)
376-
has_mul(L::FunctionOperator{iip}) where{iip} = true
377-
has_mul!(L::FunctionOperator{iip}) where{iip} = iip
378+
has_mul(::FunctionOperator{iip}) where{iip} = true
379+
has_mul!(::FunctionOperator{iip}) where{iip} = iip
378380
has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing)
379381
has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothing)
380382

381383
# TODO - FunctionOperator, Base.conj, transpose
382384

383385
# operator application
384-
Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op(u, L.p, L.t)
385-
Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t)
386+
function Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip}
387+
L.op(u, L.p, L.t; L.traits.kwargs...)
388+
end
389+
390+
function Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip}
391+
L.op_inverse(u, L.p, L.t; L.traits.kwargs...)
392+
end
386393

387394
function Base.:*(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
388395
_, co = L.cache
389396
du = zero(co)
390-
L.op(du, u, L.p, L.t)
397+
L.op(du, u, L.p, L.t; L.traits.kwargs...)
391398
end
392399

393400
function Base.:\(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
394401
ci, _ = L.cache
395402
du = zero(ci)
396-
L.op_inverse(du, u, L.p, L.t)
403+
L.op_inverse(du, u, L.p, L.t; L.traits.kwargs...)
397404
end
398405

399406
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat)
400-
L.op(v, u, L.p, L.t)
407+
L.op(v, u, L.p, L.t; L.traits.kwargs...)
401408
end
402409

403410
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{false}, u::AbstractVecOrMat, args...)
@@ -414,11 +421,11 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop,
414421
end
415422

416423
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, true}, u::AbstractVecOrMat, α, β) where{oop}
417-
L.op(v, u, L.p, L.t, α, β)
424+
L.op(v, u, L.p, L.t, α, β; L.traits.kwargs...)
418425
end
419426

420427
function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat)
421-
L.op_inverse(v, u, L.p, L.t)
428+
L.op_inverse(v, u, L.p, L.t; L.traits.kwargs...)
422429
end
423430

424431
function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractVecOrMat)

0 commit comments

Comments
 (0)