diff --git a/docs/src/interface.md b/docs/src/interface.md index 0f093549..81eeb012 100644 --- a/docs/src/interface.md +++ b/docs/src/interface.md @@ -67,7 +67,7 @@ using SciMLOperators γ = ScalarOperator(0.0; update_func = (a, u, p, t; my_special_scaling) -> my_special_scaling, - accepted_kwargs = (:my_special_scaling,)) + accepted_kwargs = Val((:my_special_scaling,))) # Update coefficients, then apply operator update_coefficients!(γ, nothing, nothing, nothing; my_special_scaling = 7.0) diff --git a/src/func.jl b/src/func.jl index 841ac652..0904efb7 100644 --- a/src/func.jl +++ b/src/func.jl @@ -204,7 +204,7 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`. - `u` - Prototype of the state struct passed to the operator during evaluation, i.e. `L(u, p, t)`. `u` is set to `nothing` if no value is provided. - `p` - Prototype of parameter struct passed to the operator during evaluation, i.e. `L(u, p, t)`. `p` is set to `nothing` if no value is provided. - `t` - Protype of scalar time variable passed to the operator during evaluation. `t` is set to `zero(T)` if no value is provided. - - `accepted_kwargs` - `Tuple` of `Symbol`s corresponding to the keyword arguments accepted by `op*`, and `update_coefficients[!]`. For example, if `op` accepts kwarg `scale`, as in `op(u, p, t; scale)`, then `accepted_kwargs = (:scale,)`. + - `accepted_kwargs` - `Val` of a `Tuple` of `Symbol`s for zero-allocation kwarg filtering. Corresponds to the keyword arguments accepted by `op*`, and `update_coefficients[!]`. For example, if `op` accepts kwarg `scale`, as in `op(u, p, t; scale)`, then `accepted_kwargs = Val((:scale,))`. Plain tuples like `(:scale,)` are deprecated but still supported. - `T` - `eltype` of the operator. If no value is provided, the constructor inferrs the value from types of `input`, and `output` - `isinplace` - `true` if the operator can be used is a mutating way with in-place allocations. This trait is inferred if no value is provided. - `outofplace` - `true` if the operator can be used is a non-mutating way with in-place allocations. This trait is inferred if no value is provided. @@ -451,7 +451,7 @@ function Base.copy(L::FunctionOperator) isdefined(L, :p) && L.p !== nothing ? deepcopy(L.p) : nothing, L.t, L.cache === nothing ? nothing : deepcopy(L.cache), - typeof(L).parameters[end-1], # iType + typeof(L).parameters[end - 1], # iType typeof(L).parameters[end] # oType ) end diff --git a/src/matrix.jl b/src/matrix.jl index d71eb3fc..3fe6a1c0 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -14,8 +14,10 @@ or update_func!(A::AbstractMatrix, u, p, t; ) -> [modifies A] -The set of keyword-arguments accepted by `update_func[!]` must be provided -to `MatrixOperator` via the kwarg `accepted_kwargs` as a tuple of `Symbol`s. +The set of keyword-arguments accepted by `update_func[!]` should be provided +to `MatrixOperator` via the kwarg `accepted_kwargs` as a `Val` of a tuple of `Symbol`s +for zero-allocation kwarg filtering. For example, `accepted_kwargs = Val((:dtgamma,))`. +Plain tuples like `(:dtgamma,)` are deprecated but still supported. `kwargs` cannot be passed down to `update_func[!]` if `accepted_kwargs` are not provided. @@ -38,7 +40,7 @@ p = rand(4, 4) t = rand() mat_update = (A, u, p, t; scale = 0.0) -> t * p -M = MatrixOperator(0.0; update_func = mat_update, accepted_kwargs = (:scale,)) +M = MatrixOperator(0.0; update_func = mat_update, accepted_kwargs = Val((:scale,))) L = M * M + 3I L = cache_operator(L, v) @@ -65,7 +67,7 @@ p = rand(4) # Must be non-nothing t = rand() mat_update! = (A, u, p, t; scale = 0.0) -> (A .= t * p * u' * scale) -M = MatrixOperator(zeros(4, 4); update_func! = mat_update!, accepted_kwargs = (:scale,)) +M = MatrixOperator(zeros(4, 4); update_func! = mat_update!, accepted_kwargs = Val((:scale,))) L = M * M + 3I L = cache_operator(L, v) @@ -291,8 +293,10 @@ or update_func!(diag::AbstractVecOrMat, u, p, t; ) -> [modifies diag] -The set of keyword-arguments accepted by `update_func[!]` must be provided -to `MatrixOperator` via the kwarg `accepted_kwargs` as a tuple of `Symbol`s. +The set of keyword-arguments accepted by `update_func[!]` should be provided +to `DiagonalOperator` via the kwarg `accepted_kwargs` as a `Val` of a tuple of `Symbol`s +for zero-allocation kwarg filtering. For example, `accepted_kwargs = Val((:dtgamma,))`. +Plain tuples like `(:dtgamma,)` are deprecated but still supported. `kwargs` cannot be passed down to `update_func[!]` if `accepted_kwargs` are not provided. @@ -501,8 +505,10 @@ and `B`, `b` are expected to have an appropriate size so that `A * v + B * b` makes sense. Specifically, `size(A, 1) == size(B, 1)`, and `size(v, 2) == size(b, 2)`. -The set of keyword-arguments accepted by `update_func[!]` must be provided -to `AffineOperator` via the kwarg `accepted_kwargs` as a tuple of `Symbol`s. +The set of keyword-arguments accepted by `update_func[!]` should be provided +to `AffineOperator` via the kwarg `accepted_kwargs` as a `Val` of a tuple of `Symbol`s +for zero-allocation kwarg filtering. For example, `accepted_kwargs = Val((:dtgamma,))`. +Plain tuples like `(:dtgamma,)` are deprecated but still supported. `kwargs` cannot be passed down to `update_func[!]` if `accepted_kwargs` are not provided. diff --git a/src/utils.jl b/src/utils.jl index 632792a1..a1adf802 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -14,7 +14,19 @@ arguments. Required in implementation of lazy `Base.adjoint`, struct NoKwargFilter end function preprocess_update_func(update_func, accepted_kwargs) - _accepted_kwargs = (accepted_kwargs === nothing) ? () : accepted_kwargs + # Convert accepted_kwargs to Val for compile-time kwarg filtering to avoid allocations + _accepted_kwargs = if accepted_kwargs === nothing + Val(()) + elseif accepted_kwargs isa Tuple + # Deprecation: Encourage users to use Val((...)) directly for better performance + @warn """Passing accepted_kwargs as a plain Tuple is deprecated and will be removed in a future version. + Please use Val((...)) instead for zero-allocation kwarg filtering. + Example: accepted_kwargs = Val((:dtgamma,)) instead of accepted_kwargs = (:dtgamma,) + This message will only be shown once per session.""" maxlog=1 + Val(accepted_kwargs) + else + accepted_kwargs # Already a Val or NoKwargFilter + end # accepted_kwargs can be passed as nothing to indicate that we should not filter # (e.g. if the function already accepts all kwargs...). return (_accepted_kwargs isa NoKwargFilter) ? update_func : @@ -46,7 +58,9 @@ end function get_filtered_kwargs(kwargs::Union{AbstractDict, NamedTuple}, ::Val{accepted_kwargs}) where {accepted_kwargs} kwargs_nt = NamedTuple(kwargs) - return NamedTuple{accepted_kwargs}(kwargs_nt) # This creates a new NamedTuple with keys specified by `accepted_kwargs` + # Only extract keys that exist in kwargs_nt to avoid errors + filtered_keys = filter(k -> haskey(kwargs_nt, k), accepted_kwargs) + return NamedTuple{filtered_keys}(kwargs_nt) end function (f::FilterKwargs)(args...; kwargs...) diff --git a/test/scalar.jl b/test/scalar.jl index c4a8bc35..6625fff6 100644 --- a/test/scalar.jl +++ b/test/scalar.jl @@ -224,7 +224,7 @@ end # Test with keyword arguments γ = ScalarOperator(0.0; update_func = (args...; dtgamma) -> dtgamma, - accepted_kwargs = (:dtgamma,)) + accepted_kwargs = Val((:dtgamma,))) dtgamma = rand() # Original tests diff --git a/test/total.jl b/test/total.jl index ca513458..860b8880 100644 --- a/test/total.jl +++ b/test/total.jl @@ -103,7 +103,7 @@ end # Introduce update function for D dependent on kwarg "matrix" D = MatrixOperator(zeros(N, N); update_func! = (A, u, p, t; matrix) -> (A .= p * t * matrix), - accepted_kwargs = (:matrix,)) + accepted_kwargs = Val((:matrix,))) matrix = rand(N, N) diag = rand(N2) @@ -116,7 +116,7 @@ end D1 = DiagonalOperator(zeros(N2); update_func! = (d, u, p, t) -> d .= p) D2 = DiagonalOperator( zeros(N2); update_func! = (d, u, p, t; diag) -> d .= p * t * diag, - accepted_kwargs = (:diag,)) + accepted_kwargs = Val((:diag,))) TT = [T1, T2] DD = Diagonal([D1, D2])