Skip to content

Commit 4818a66

Browse files
committed
moved kwargs to FunctionOp.traits
1 parent e89d5a1 commit 4818a66

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

src/func.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <:
1717
p::P
1818
""" Time """
1919
t::Tt
20-
""" kwargs """
21-
kwargs::Dict{Symbol,Any} # TODO move inside traits later
2220
""" Cache """
2321
cache::C
2422

@@ -30,7 +28,6 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <:
3028
traits,
3129
p,
3230
t,
33-
kwargs,
3431
cache
3532
)
3633

@@ -60,7 +57,6 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <:
6057
traits,
6158
p,
6259
t,
63-
kwargs,
6460
cache,
6561
)
6662
end
@@ -124,7 +120,6 @@ function FunctionOperator(op,
124120
sz = (size(output, 1), size(input, 1))
125121
T = isnothing(T) ? promote_type(eltypes...) : T
126122
t = isnothing(t) ? zero(real(T)) : t
127-
kwargs = Dict{Symbol, Any}()
128123

129124
isinplace = if isnothing(isinplace)
130125
static_hasmethod(op, typeof((output, input, p, t)))
@@ -189,6 +184,7 @@ function FunctionOperator(op,
189184
size = sz,
190185
eltypes = eltypes,
191186
accepted_kwargs = accepted_kwargs,
187+
kwargs = Dict{Symbol, Any}(),
192188
)
193189

194190
L = FunctionOperator(
@@ -199,7 +195,6 @@ function FunctionOperator(op,
199195
traits,
200196
p,
201197
t,
202-
kwargs,
203198
cache
204199
)
205200

@@ -212,11 +207,13 @@ end
212207

213208
function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
214209

210+
# update p, t
215211
@set! L.p = p
216212
@set! L.t = t
217213

214+
# filter and update kwargs
218215
filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs)
219-
@set! L.kwargs = Dict(filtered_kwargs)
216+
@set! L.traits.kwargs = Dict{Symbol, Any}(filtered_kwargs)
220217

221218
isconstant(L) && return L
222219

@@ -228,11 +225,13 @@ end
228225

229226
function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
230227

228+
# update p, t
231229
L.p = p
232230
L.t = t
233231

232+
# filter and update kwargs
234233
filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs)
235-
L.kwargs = Dict(filtered_kwargs)
234+
L.traits = (; L.traits..., kwargs = Dict{Symbol, Any}(filtered_kwargs))
236235

237236
isconstant(L) && return
238237

@@ -289,7 +288,6 @@ function Base.adjoint(L::FunctionOperator)
289288
traits,
290289
L.p,
291290
L.t,
292-
L.kwargs,
293291
cache,
294292
)
295293
end
@@ -330,7 +328,6 @@ function Base.inv(L::FunctionOperator)
330328
traits,
331329
L.p,
332330
L.t,
333-
L.kwargs,
334331
cache,
335332
)
336333
end
@@ -387,27 +384,27 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin
387384

388385
# operator application
389386
function Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip}
390-
L.op(u, L.p, L.t; L.kwargs...)
387+
L.op(u, L.p, L.t; L.traits.kwargs...)
391388
end
392389

393390
function Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip}
394-
L.op_inverse(u, L.p, L.t; L.kwargs...)
391+
L.op_inverse(u, L.p, L.t; L.traits.kwargs...)
395392
end
396393

397394
function Base.:*(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
398395
_, co = L.cache
399396
du = zero(co)
400-
L.op(du, u, L.p, L.t; L.kwargs...)
397+
L.op(du, u, L.p, L.t; L.traits.kwargs...)
401398
end
402399

403400
function Base.:\(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
404401
ci, _ = L.cache
405402
du = zero(ci)
406-
L.op_inverse(du, u, L.p, L.t; L.kwargs...)
403+
L.op_inverse(du, u, L.p, L.t; L.traits.kwargs...)
407404
end
408405

409406
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat)
410-
L.op(v, u, L.p, L.t; L.kwargs...)
407+
L.op(v, u, L.p, L.t; L.traits.kwargs...)
411408
end
412409

413410
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{false}, u::AbstractVecOrMat, args...)
@@ -424,11 +421,11 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop,
424421
end
425422

426423
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, true}, u::AbstractVecOrMat, α, β) where{oop}
427-
L.op(v, u, L.p, L.t, α, β; L.kwargs...)
424+
L.op(v, u, L.p, L.t, α, β; L.traits.kwargs...)
428425
end
429426

430427
function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat)
431-
L.op_inverse(v, u, L.p, L.t; L.kwargs...)
428+
L.op_inverse(v, u, L.p, L.t; L.traits.kwargs...)
432429
end
433430

434431
function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractVecOrMat)

0 commit comments

Comments
 (0)