Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ using SciMLOperators

γ = ScalarOperator(0.0;
update_func = (a, u, p, t; my_special_scaling) -> my_special_scaling,
accepted_kwargs = (:my_special_scaling,))
accepted_kwargs = Val((:my_special_scaling,)))

# Update coefficients, then apply operator
update_coefficients!(γ, nothing, nothing, nothing; my_special_scaling = 7.0)
Expand Down
4 changes: 2 additions & 2 deletions src/func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`.
- `u` - Prototype of the state struct passed to the operator during evaluation, i.e. `L(u, p, t)`. `u` is set to `nothing` if no value is provided.
- `p` - Prototype of parameter struct passed to the operator during evaluation, i.e. `L(u, p, t)`. `p` is set to `nothing` if no value is provided.
- `t` - Protype of scalar time variable passed to the operator during evaluation. `t` is set to `zero(T)` if no value is provided.
- `accepted_kwargs` - `Tuple` of `Symbol`s corresponding to the keyword arguments accepted by `op*`, and `update_coefficients[!]`. For example, if `op` accepts kwarg `scale`, as in `op(u, p, t; scale)`, then `accepted_kwargs = (:scale,)`.
- `accepted_kwargs` - `Val` of a `Tuple` of `Symbol`s for zero-allocation kwarg filtering. Corresponds to the keyword arguments accepted by `op*`, and `update_coefficients[!]`. For example, if `op` accepts kwarg `scale`, as in `op(u, p, t; scale)`, then `accepted_kwargs = Val((:scale,))`. Plain tuples like `(:scale,)` are deprecated but still supported.
- `T` - `eltype` of the operator. If no value is provided, the constructor inferrs the value from types of `input`, and `output`
- `isinplace` - `true` if the operator can be used is a mutating way with in-place allocations. This trait is inferred if no value is provided.
- `outofplace` - `true` if the operator can be used is a non-mutating way with in-place allocations. This trait is inferred if no value is provided.
Expand Down Expand Up @@ -451,7 +451,7 @@ function Base.copy(L::FunctionOperator)
isdefined(L, :p) && L.p !== nothing ? deepcopy(L.p) : nothing,
L.t,
L.cache === nothing ? nothing : deepcopy(L.cache),
typeof(L).parameters[end-1], # iType
typeof(L).parameters[end - 1], # iType
typeof(L).parameters[end] # oType
)
end
Expand Down
22 changes: 14 additions & 8 deletions src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ or

update_func!(A::AbstractMatrix, u, p, t; <accepted kwargs>) -> [modifies A]

The set of keyword-arguments accepted by `update_func[!]` must be provided
to `MatrixOperator` via the kwarg `accepted_kwargs` as a tuple of `Symbol`s.
The set of keyword-arguments accepted by `update_func[!]` should be provided
to `MatrixOperator` via the kwarg `accepted_kwargs` as a `Val` of a tuple of `Symbol`s
for zero-allocation kwarg filtering. For example, `accepted_kwargs = Val((:dtgamma,))`.
Plain tuples like `(:dtgamma,)` are deprecated but still supported.
`kwargs` cannot be passed down to `update_func[!]` if `accepted_kwargs`
are not provided.

Expand All @@ -38,7 +40,7 @@ p = rand(4, 4)
t = rand()

mat_update = (A, u, p, t; scale = 0.0) -> t * p
M = MatrixOperator(0.0; update_func = mat_update, accepted_kwargs = (:scale,))
M = MatrixOperator(0.0; update_func = mat_update, accepted_kwargs = Val((:scale,)))

L = M * M + 3I
L = cache_operator(L, v)
Expand All @@ -65,7 +67,7 @@ p = rand(4) # Must be non-nothing
t = rand()

mat_update! = (A, u, p, t; scale = 0.0) -> (A .= t * p * u' * scale)
M = MatrixOperator(zeros(4, 4); update_func! = mat_update!, accepted_kwargs = (:scale,))
M = MatrixOperator(zeros(4, 4); update_func! = mat_update!, accepted_kwargs = Val((:scale,)))
L = M * M + 3I
L = cache_operator(L, v)

Expand Down Expand Up @@ -291,8 +293,10 @@ or

update_func!(diag::AbstractVecOrMat, u, p, t; <accepted kwargs>) -> [modifies diag]

The set of keyword-arguments accepted by `update_func[!]` must be provided
to `MatrixOperator` via the kwarg `accepted_kwargs` as a tuple of `Symbol`s.
The set of keyword-arguments accepted by `update_func[!]` should be provided
to `DiagonalOperator` via the kwarg `accepted_kwargs` as a `Val` of a tuple of `Symbol`s
for zero-allocation kwarg filtering. For example, `accepted_kwargs = Val((:dtgamma,))`.
Plain tuples like `(:dtgamma,)` are deprecated but still supported.
`kwargs` cannot be passed down to `update_func[!]` if `accepted_kwargs`
are not provided.

Expand Down Expand Up @@ -501,8 +505,10 @@ and `B`, `b` are expected to have an appropriate size so that
`A * v + B * b` makes sense. Specifically, `size(A, 1) == size(B, 1)`, and
`size(v, 2) == size(b, 2)`.

The set of keyword-arguments accepted by `update_func[!]` must be provided
to `AffineOperator` via the kwarg `accepted_kwargs` as a tuple of `Symbol`s.
The set of keyword-arguments accepted by `update_func[!]` should be provided
to `AffineOperator` via the kwarg `accepted_kwargs` as a `Val` of a tuple of `Symbol`s
for zero-allocation kwarg filtering. For example, `accepted_kwargs = Val((:dtgamma,))`.
Plain tuples like `(:dtgamma,)` are deprecated but still supported.
`kwargs` cannot be passed down to `update_func[!]` if `accepted_kwargs`
are not provided.

Expand Down
18 changes: 16 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,19 @@ arguments. Required in implementation of lazy `Base.adjoint`,
struct NoKwargFilter end

function preprocess_update_func(update_func, accepted_kwargs)
_accepted_kwargs = (accepted_kwargs === nothing) ? () : accepted_kwargs
# Convert accepted_kwargs to Val for compile-time kwarg filtering to avoid allocations
_accepted_kwargs = if accepted_kwargs === nothing
Val(())
elseif accepted_kwargs isa Tuple
# Deprecation: Encourage users to use Val((...)) directly for better performance
@warn """Passing accepted_kwargs as a plain Tuple is deprecated and will be removed in a future version.
Please use Val((...)) instead for zero-allocation kwarg filtering.
Example: accepted_kwargs = Val((:dtgamma,)) instead of accepted_kwargs = (:dtgamma,)
This message will only be shown once per session.""" maxlog=1
Val(accepted_kwargs)
else
accepted_kwargs # Already a Val or NoKwargFilter
end
# accepted_kwargs can be passed as nothing to indicate that we should not filter
# (e.g. if the function already accepts all kwargs...).
return (_accepted_kwargs isa NoKwargFilter) ? update_func :
Expand Down Expand Up @@ -46,7 +58,9 @@ end
function get_filtered_kwargs(kwargs::Union{AbstractDict, NamedTuple},
::Val{accepted_kwargs}) where {accepted_kwargs}
kwargs_nt = NamedTuple(kwargs)
return NamedTuple{accepted_kwargs}(kwargs_nt) # This creates a new NamedTuple with keys specified by `accepted_kwargs`
# Only extract keys that exist in kwargs_nt to avoid errors
filtered_keys = filter(k -> haskey(kwargs_nt, k), accepted_kwargs)
return NamedTuple{filtered_keys}(kwargs_nt)
end

function (f::FilterKwargs)(args...; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion test/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ end

# Test with keyword arguments
γ = ScalarOperator(0.0; update_func = (args...; dtgamma) -> dtgamma,
accepted_kwargs = (:dtgamma,))
accepted_kwargs = Val((:dtgamma,)))

dtgamma = rand()
# Original tests
Expand Down
4 changes: 2 additions & 2 deletions test/total.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ end
# Introduce update function for D dependent on kwarg "matrix"
D = MatrixOperator(zeros(N, N);
update_func! = (A, u, p, t; matrix) -> (A .= p * t * matrix),
accepted_kwargs = (:matrix,))
accepted_kwargs = Val((:matrix,)))

matrix = rand(N, N)
diag = rand(N2)
Expand All @@ -116,7 +116,7 @@ end
D1 = DiagonalOperator(zeros(N2); update_func! = (d, u, p, t) -> d .= p)
D2 = DiagonalOperator(
zeros(N2); update_func! = (d, u, p, t; diag) -> d .= p * t * diag,
accepted_kwargs = (:diag,))
accepted_kwargs = Val((:diag,)))

TT = [T1, T2]
DD = Diagonal([D1, D2])
Expand Down
Loading