@@ -183,14 +183,15 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`.
183183* `issymmetric` - `true` if the operator is linear and symmetric. Defaults to `false`.
184184* `ishermitian` - `true` if the operator is linear and hermitian. Defaults to `false`.
185185* `isposdef` - `true` if the operator is linear and positive-definite. Defaults to `false`.
186+ * `kwargs` - Keyword arguments for cache initialization. If `accepted_kwargs` is provided, the corresponding keyword arguments must be passed.
186187"""
187188function FunctionOperator (op,
188189 input:: AbstractArray ,
189190 output:: AbstractArray = input; op_adjoint = nothing ,
190191 op_inverse = nothing ,
191192 op_adjoint_inverse = nothing , p = nothing ,
192193 t:: Union{Number, Nothing} = nothing ,
193- accepted_kwargs:: NTuple{N, Symbol} = () ,
194+ accepted_kwargs:: Union{Nothing, Val, NTuple{N, Symbol}} = nothing ,
194195
195196 # traits
196197 T:: Union{Type{<:Number}, Nothing} = nothing ,
@@ -207,7 +208,8 @@ function FunctionOperator(op,
207208 opnorm = nothing ,
208209 issymmetric:: Union{Bool, Val} = Val (false ),
209210 ishermitian:: Union{Bool, Val} = Val (false ),
210- isposdef:: Bool = false ) where {N}
211+ isposdef:: Bool = false ,
212+ kwargs... ) where {N}
211213
212214 # establish types
213215
@@ -316,12 +318,21 @@ function FunctionOperator(op,
316318 _op_adjoint_inverse = op_adjoint_inverse
317319 end
318320
321+ if accepted_kwargs === nothing
322+ accepted_kwargs = Val (())
323+ kwargs = NamedTuple ()
324+ else
325+ length (kwargs) != 0 ||
326+ throw (ArgumentError (" No keyword arguments provided. When `accepted_kwargs` is provided, the corresponding keyword arguments must be passed for cache initialization." ))
327+ kwargs = get_filtered_kwargs (kwargs, accepted_kwargs)
328+ end
329+
319330 traits = (; islinear, isconvertible, isconstant, opnorm,
320331 issymmetric = _unwrap_val (issymmetric), ishermitian = _unwrap_val (ishermitian),
321332 isposdef, isinplace = _unwrap_val (_isinplace),
322333 outofplace = _unwrap_val (_outofplace), has_mul5 = _unwrap_val (_has_mul5),
323334 ifcache = _unwrap_val (ifcache), T = _T, batch, size = _size, sizes,
324- accepted_kwargs, kwargs = Dict {Symbol, Any} () )
335+ accepted_kwargs, kwargs = kwargs )
325336
326337 L = FunctionOperator{_unwrap_val (_isinplace), _unwrap_val (_outofplace),
327338 _unwrap_val (_has_mul5), _T, typeof (op), typeof (_op_adjoint), typeof (op_inverse),
@@ -355,7 +366,7 @@ function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
355366 # filter and update kwargs
356367 filtered_kwargs = get_filtered_kwargs (kwargs, L. traits. accepted_kwargs)
357368
358- L = set_traits (L, merge (L. traits, (; kwargs = Dict {Symbol, Any} ( filtered_kwargs) )))
369+ L = set_traits (L, merge (L. traits, (; kwargs = filtered_kwargs)))
359370
360371 isconstant (L) && return L
361372
@@ -374,7 +385,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
374385
375386 # filter and update kwargs
376387 filtered_kwargs = get_filtered_kwargs (kwargs, L. traits. accepted_kwargs)
377- L. traits = (; L. traits... , kwargs = Dict {Symbol, Any} ( filtered_kwargs))
388+ L. traits = merge ( L. traits, (; kwargs = filtered_kwargs))
378389
379390 isconstant (L) && return
380391
0 commit comments