22"""
33 Matrix free operators (given by a function)
44"""
5- mutable struct FunctionOperator{iip,oop,mul5,T<: Number ,F,Fa,Fi,Fai,Tr,P,Tt,K, C} <: AbstractSciMLOperator{T}
5+ mutable struct FunctionOperator{iip,oop,mul5,T<: Number ,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
66 """ Function with signature op(u, p, t) and (if isinplace) op(du, u, p, t) """
77 op:: F
88 """ Adjoint operator"""
@@ -17,8 +17,8 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C}
1717 p:: P
1818 """ Time """
1919 t:: Tt
20- """ Keyword arguments """
21- kwargs:: K
20+ """ kwargs """
21+ kwargs:: Dict{Symbol,Any} # TODO move inside traits later
2222 """ Cache """
2323 cache:: C
2424
@@ -30,7 +30,7 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C}
3030 traits,
3131 p,
3232 t,
33- accepted_kwargs ,
33+ kwargs ,
3434 cache
3535 )
3636
@@ -51,7 +51,6 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C}
5151 typeof (traits),
5252 typeof (p),
5353 typeof (t),
54- typeof (accepted_kwargs),
5554 typeof (cache),
5655 }(
5756 op,
@@ -61,7 +60,7 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C}
6160 traits,
6261 p,
6362 t,
64- accepted_kwargs ,
63+ kwargs ,
6564 cache,
6665 )
6766 end
@@ -107,7 +106,7 @@ function FunctionOperator(op,
107106
108107 p= nothing ,
109108 t:: Union{Number,Nothing} = nothing ,
110- accepted_kwargs = (),
109+ accepted_kwargs:: NTuple{N,Symbol} = (),
111110
112111 ifcache:: Bool = true ,
113112
@@ -118,13 +117,14 @@ function FunctionOperator(op,
118117 issymmetric:: Bool = false ,
119118 ishermitian:: Bool = false ,
120119 isposdef:: Bool = false ,
121- )
120+ ) where {N}
122121
123122 # store eltype of input/output for caching with ComposedOperator.
124123 eltypes = eltype .((input, output))
125124 sz = (size (output, 1 ), size (input, 1 ))
126125 T = isnothing (T) ? promote_type (eltypes... ) : T
127126 t = isnothing (t) ? zero (real (T)) : t
127+ kwargs = Dict {Symbol, Any} ()
128128
129129 isinplace = if isnothing (isinplace)
130130 static_hasmethod (op, typeof ((output, input, p, t)))
@@ -188,6 +188,7 @@ function FunctionOperator(op,
188188 T = T,
189189 size = sz,
190190 eltypes = eltypes,
191+ accepted_kwargs = accepted_kwargs,
191192 )
192193
193194 L = FunctionOperator (
@@ -198,7 +199,7 @@ function FunctionOperator(op,
198199 traits,
199200 p,
200201 t,
201- normalize_kwargs (accepted_kwargs) ,
202+ kwargs ,
202203 cache
203204 )
204205
@@ -214,9 +215,10 @@ function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
214215 @set! L. p = p
215216 @set! L. t = t
216217
217- isconstant (L) && return L
218+ filtered_kwargs = get_filtered_kwargs (kwargs, L. traits. accepted_kwargs)
219+ @set! L. kwargs = Dict (filtered_kwargs)
218220
219- filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in L . kwargs if haskey (kwargs, kwarg))
221+ isconstant (L) && return L
220222
221223 @set! L. op = update_coefficients (L. op, u, p, t; filtered_kwargs... )
222224 @set! L. op_adjoint = update_coefficients (L. op_adjoint, u, p, t; filtered_kwargs... )
@@ -226,17 +228,18 @@ end
226228
227229function update_coefficients! (L:: FunctionOperator , u, p, t; kwargs... )
228230
229- isconstant (L) && return
231+ L. p = p
232+ L. t = t
233+
234+ filtered_kwargs = get_filtered_kwargs (kwargs, L. traits. accepted_kwargs)
235+ L. kwargs = Dict (filtered_kwargs)
230236
231- filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in L . kwargs if haskey (kwargs, kwarg))
237+ isconstant (L) && return
232238
233239 for op in getops (L)
234240 update_coefficients! (op, u, p, t; filtered_kwargs... )
235241 end
236242
237- L. p = p
238- L. t = t
239-
240243 L
241244end
242245
@@ -383,8 +386,13 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin
383386# TODO - FunctionOperator, Base.conj, transpose
384387
385388# operator application
386- Base.:* (L:: FunctionOperator{iip,true} , u:: AbstractVecOrMat ) where {iip} = L. op (u, L. p, L. t)
387- Base.:\ (L:: FunctionOperator{iip,true} , u:: AbstractVecOrMat ) where {iip} = L. op_inverse (u, L. p, L. t; L. kwargs... )
389+ function Base.:* (L:: FunctionOperator{iip,true} , u:: AbstractVecOrMat ) where {iip}
390+ L. op (u, L. p, L. t; L. kwargs... )
391+ end
392+
393+ function Base.:\ (L:: FunctionOperator{iip,true} , u:: AbstractVecOrMat ) where {iip}
394+ L. op_inverse (u, L. p, L. t; L. kwargs... )
395+ end
388396
389397function Base.:* (L:: FunctionOperator{true,false} , u:: AbstractVecOrMat )
390398 _, co = L. cache
0 commit comments