Skip to content

Commit f1f4853

Browse files
Merge pull request #313 from ChrisRackauckas-Claude/fix-accepted-kwargs-allocations
Fix allocation issue with accepted_kwargs by using Val types
2 parents 1bc1531 + 8715f69 commit f1f4853

File tree

6 files changed

+36
-16
lines changed

6 files changed

+36
-16
lines changed

docs/src/interface.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ using SciMLOperators
6767
6868
γ = ScalarOperator(0.0;
6969
update_func = (a, u, p, t; my_special_scaling) -> my_special_scaling,
70-
accepted_kwargs = (:my_special_scaling,))
70+
accepted_kwargs = Val((:my_special_scaling,)))
7171
7272
# Update coefficients, then apply operator
7373
update_coefficients!(γ, nothing, nothing, nothing; my_special_scaling = 7.0)

src/func.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`.
204204
- `u` - Prototype of the state struct passed to the operator during evaluation, i.e. `L(u, p, t)`. `u` is set to `nothing` if no value is provided.
205205
- `p` - Prototype of parameter struct passed to the operator during evaluation, i.e. `L(u, p, t)`. `p` is set to `nothing` if no value is provided.
206206
- `t` - Protype of scalar time variable passed to the operator during evaluation. `t` is set to `zero(T)` if no value is provided.
207-
- `accepted_kwargs` - `Tuple` of `Symbol`s corresponding to the keyword arguments accepted by `op*`, and `update_coefficients[!]`. For example, if `op` accepts kwarg `scale`, as in `op(u, p, t; scale)`, then `accepted_kwargs = (:scale,)`.
207+
- `accepted_kwargs` - `Val` of a `Tuple` of `Symbol`s for zero-allocation kwarg filtering. Corresponds to the keyword arguments accepted by `op*`, and `update_coefficients[!]`. For example, if `op` accepts kwarg `scale`, as in `op(u, p, t; scale)`, then `accepted_kwargs = Val((:scale,))`. Plain tuples like `(:scale,)` are deprecated but still supported.
208208
- `T` - `eltype` of the operator. If no value is provided, the constructor inferrs the value from types of `input`, and `output`
209209
- `isinplace` - `true` if the operator can be used is a mutating way with in-place allocations. This trait is inferred if no value is provided.
210210
- `outofplace` - `true` if the operator can be used is a non-mutating way with in-place allocations. This trait is inferred if no value is provided.
@@ -451,7 +451,7 @@ function Base.copy(L::FunctionOperator)
451451
isdefined(L, :p) && L.p !== nothing ? deepcopy(L.p) : nothing,
452452
L.t,
453453
L.cache === nothing ? nothing : deepcopy(L.cache),
454-
typeof(L).parameters[end-1], # iType
454+
typeof(L).parameters[end - 1], # iType
455455
typeof(L).parameters[end] # oType
456456
)
457457
end

src/matrix.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ or
1414
1515
update_func!(A::AbstractMatrix, u, p, t; <accepted kwargs>) -> [modifies A]
1616
17-
The set of keyword-arguments accepted by `update_func[!]` must be provided
18-
to `MatrixOperator` via the kwarg `accepted_kwargs` as a tuple of `Symbol`s.
17+
The set of keyword-arguments accepted by `update_func[!]` should be provided
18+
to `MatrixOperator` via the kwarg `accepted_kwargs` as a `Val` of a tuple of `Symbol`s
19+
for zero-allocation kwarg filtering. For example, `accepted_kwargs = Val((:dtgamma,))`.
20+
Plain tuples like `(:dtgamma,)` are deprecated but still supported.
1921
`kwargs` cannot be passed down to `update_func[!]` if `accepted_kwargs`
2022
are not provided.
2123
@@ -38,7 +40,7 @@ p = rand(4, 4)
3840
t = rand()
3941
4042
mat_update = (A, u, p, t; scale = 0.0) -> t * p
41-
M = MatrixOperator(0.0; update_func = mat_update, accepted_kwargs = (:scale,))
43+
M = MatrixOperator(0.0; update_func = mat_update, accepted_kwargs = Val((:scale,)))
4244
4345
L = M * M + 3I
4446
L = cache_operator(L, v)
@@ -65,7 +67,7 @@ p = rand(4) # Must be non-nothing
6567
t = rand()
6668
6769
mat_update! = (A, u, p, t; scale = 0.0) -> (A .= t * p * u' * scale)
68-
M = MatrixOperator(zeros(4, 4); update_func! = mat_update!, accepted_kwargs = (:scale,))
70+
M = MatrixOperator(zeros(4, 4); update_func! = mat_update!, accepted_kwargs = Val((:scale,)))
6971
L = M * M + 3I
7072
L = cache_operator(L, v)
7173
@@ -291,8 +293,10 @@ or
291293
292294
update_func!(diag::AbstractVecOrMat, u, p, t; <accepted kwargs>) -> [modifies diag]
293295
294-
The set of keyword-arguments accepted by `update_func[!]` must be provided
295-
to `MatrixOperator` via the kwarg `accepted_kwargs` as a tuple of `Symbol`s.
296+
The set of keyword-arguments accepted by `update_func[!]` should be provided
297+
to `DiagonalOperator` via the kwarg `accepted_kwargs` as a `Val` of a tuple of `Symbol`s
298+
for zero-allocation kwarg filtering. For example, `accepted_kwargs = Val((:dtgamma,))`.
299+
Plain tuples like `(:dtgamma,)` are deprecated but still supported.
296300
`kwargs` cannot be passed down to `update_func[!]` if `accepted_kwargs`
297301
are not provided.
298302
@@ -501,8 +505,10 @@ and `B`, `b` are expected to have an appropriate size so that
501505
`A * v + B * b` makes sense. Specifically, `size(A, 1) == size(B, 1)`, and
502506
`size(v, 2) == size(b, 2)`.
503507
504-
The set of keyword-arguments accepted by `update_func[!]` must be provided
505-
to `AffineOperator` via the kwarg `accepted_kwargs` as a tuple of `Symbol`s.
508+
The set of keyword-arguments accepted by `update_func[!]` should be provided
509+
to `AffineOperator` via the kwarg `accepted_kwargs` as a `Val` of a tuple of `Symbol`s
510+
for zero-allocation kwarg filtering. For example, `accepted_kwargs = Val((:dtgamma,))`.
511+
Plain tuples like `(:dtgamma,)` are deprecated but still supported.
506512
`kwargs` cannot be passed down to `update_func[!]` if `accepted_kwargs`
507513
are not provided.
508514

src/utils.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,19 @@ arguments. Required in implementation of lazy `Base.adjoint`,
1414
struct NoKwargFilter end
1515

1616
function preprocess_update_func(update_func, accepted_kwargs)
17-
_accepted_kwargs = (accepted_kwargs === nothing) ? () : accepted_kwargs
17+
# Convert accepted_kwargs to Val for compile-time kwarg filtering to avoid allocations
18+
_accepted_kwargs = if accepted_kwargs === nothing
19+
Val(())
20+
elseif accepted_kwargs isa Tuple
21+
# Deprecation: Encourage users to use Val((...)) directly for better performance
22+
@warn """Passing accepted_kwargs as a plain Tuple is deprecated and will be removed in a future version.
23+
Please use Val((...)) instead for zero-allocation kwarg filtering.
24+
Example: accepted_kwargs = Val((:dtgamma,)) instead of accepted_kwargs = (:dtgamma,)
25+
This message will only be shown once per session.""" maxlog=1
26+
Val(accepted_kwargs)
27+
else
28+
accepted_kwargs # Already a Val or NoKwargFilter
29+
end
1830
# accepted_kwargs can be passed as nothing to indicate that we should not filter
1931
# (e.g. if the function already accepts all kwargs...).
2032
return (_accepted_kwargs isa NoKwargFilter) ? update_func :
@@ -46,7 +58,9 @@ end
4658
function get_filtered_kwargs(kwargs::Union{AbstractDict, NamedTuple},
4759
::Val{accepted_kwargs}) where {accepted_kwargs}
4860
kwargs_nt = NamedTuple(kwargs)
49-
return NamedTuple{accepted_kwargs}(kwargs_nt) # This creates a new NamedTuple with keys specified by `accepted_kwargs`
61+
# Only extract keys that exist in kwargs_nt to avoid errors
62+
filtered_keys = filter(k -> haskey(kwargs_nt, k), accepted_kwargs)
63+
return NamedTuple{filtered_keys}(kwargs_nt)
5064
end
5165

5266
function (f::FilterKwargs)(args...; kwargs...)

test/scalar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ end
224224

225225
# Test with keyword arguments
226226
γ = ScalarOperator(0.0; update_func = (args...; dtgamma) -> dtgamma,
227-
accepted_kwargs = (:dtgamma,))
227+
accepted_kwargs = Val((:dtgamma,)))
228228

229229
dtgamma = rand()
230230
# Original tests

test/total.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ end
103103
# Introduce update function for D dependent on kwarg "matrix"
104104
D = MatrixOperator(zeros(N, N);
105105
update_func! = (A, u, p, t; matrix) -> (A .= p * t * matrix),
106-
accepted_kwargs = (:matrix,))
106+
accepted_kwargs = Val((:matrix,)))
107107

108108
matrix = rand(N, N)
109109
diag = rand(N2)
@@ -116,7 +116,7 @@ end
116116
D1 = DiagonalOperator(zeros(N2); update_func! = (d, u, p, t) -> d .= p)
117117
D2 = DiagonalOperator(
118118
zeros(N2); update_func! = (d, u, p, t; diag) -> d .= p * t * diag,
119-
accepted_kwargs = (:diag,))
119+
accepted_kwargs = Val((:diag,)))
120120

121121
TT = [T1, T2]
122122
DD = Diagonal([D1, D2])

0 commit comments

Comments
 (0)