Skip to content

Commit 531ede2

Browse files
Remove Dict{Symbol, Any} for type stability
1 parent bb2c1d3 commit 531ede2

File tree

3 files changed

+25
-5
lines changed

3 files changed

+25
-5
lines changed

src/func.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,15 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`.
183183
* `issymmetric` - `true` if the operator is linear and symmetric. Defaults to `false`.
184184
* `ishermitian` - `true` if the operator is linear and hermitian. Defaults to `false`.
185185
* `isposdef` - `true` if the operator is linear and positive-definite. Defaults to `false`.
186+
* `kwargs` - Keyword arguments for cache initialization. If `accepted_kwargs` is provided, the corresponding keyword arguments must be passed.
186187
"""
187188
function FunctionOperator(op,
188189
input::AbstractArray,
189190
output::AbstractArray = input; op_adjoint = nothing,
190191
op_inverse = nothing,
191192
op_adjoint_inverse = nothing, p = nothing,
192193
t::Union{Number, Nothing} = nothing,
193-
accepted_kwargs::NTuple{N, Symbol} = (),
194+
accepted_kwargs::Union{Nothing, Val, NTuple{N, Symbol}} = nothing,
194195

195196
# traits
196197
T::Union{Type{<:Number}, Nothing} = nothing,
@@ -207,7 +208,8 @@ function FunctionOperator(op,
207208
opnorm = nothing,
208209
issymmetric::Union{Bool, Val} = Val(false),
209210
ishermitian::Union{Bool, Val} = Val(false),
210-
isposdef::Bool = false) where {N}
211+
isposdef::Bool = false,
212+
kwargs...) where {N}
211213

212214
# establish types
213215

@@ -316,12 +318,21 @@ function FunctionOperator(op,
316318
_op_adjoint_inverse = op_adjoint_inverse
317319
end
318320

321+
if accepted_kwargs === nothing
322+
accepted_kwargs = Val(())
323+
kwargs = NamedTuple()
324+
else
325+
length(kwargs) != 0 ||
326+
throw(ArgumentError("No keyword arguments provided. When `accepted_kwargs` is provided, the corresponding keyword arguments must be passed for cache initialization."))
327+
kwargs = get_filtered_kwargs(kwargs, accepted_kwargs)
328+
end
329+
319330
traits = (; islinear, isconvertible, isconstant, opnorm,
320331
issymmetric = _unwrap_val(issymmetric), ishermitian = _unwrap_val(ishermitian),
321332
isposdef, isinplace = _unwrap_val(_isinplace),
322333
outofplace = _unwrap_val(_outofplace), has_mul5 = _unwrap_val(_has_mul5),
323334
ifcache = _unwrap_val(ifcache), T = _T, batch, size = _size, sizes,
324-
accepted_kwargs, kwargs = Dict{Symbol, Any}())
335+
accepted_kwargs, kwargs = kwargs)
325336

326337
L = FunctionOperator{_unwrap_val(_isinplace), _unwrap_val(_outofplace),
327338
_unwrap_val(_has_mul5), _T, typeof(op), typeof(_op_adjoint), typeof(op_inverse),
@@ -355,7 +366,7 @@ function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
355366
# filter and update kwargs
356367
filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs)
357368

358-
L = set_traits(L, merge(L.traits, (; kwargs = Dict{Symbol, Any}(filtered_kwargs))))
369+
L = set_traits(L, merge(L.traits, (; kwargs = filtered_kwargs)))
359370

360371
isconstant(L) && return L
361372

@@ -374,7 +385,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
374385

375386
# filter and update kwargs
376387
filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs)
377-
L.traits = (; L.traits..., kwargs = Dict{Symbol, Any}(filtered_kwargs))
388+
L.traits = merge(L.traits, (; kwargs = filtered_kwargs))
378389

379390
isconstant(L) && return
380391

src/utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ function get_filtered_kwargs(kwargs::AbstractDict,
4242
accepted_kwargs::NTuple{N, Symbol}) where {N}
4343
(kw => kwargs[kw] for kw in accepted_kwargs if haskey(kwargs, kw))
4444
end
45+
function get_filtered_kwargs(kwargs::Union{AbstractDict, NamedTuple},
46+
::Val{accepted_kwargs}) where {accepted_kwargs}
47+
kwargs_nt = NamedTuple(kwargs)
48+
return NamedTuple{accepted_kwargs}(kwargs_nt) # This creates a new NamedTuple with keys specified by `accepted_kwargs`
49+
end
4550

4651
function (f::FilterKwargs)(args...; kwargs...)
4752
filtered_kwargs = get_filtered_kwargs(kwargs, f.accepted_kwargs)

test/func.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,10 @@ end
257257
f(u, p, t; scale = 1.0) = Diagonal(p * t * scale) * u
258258

259259
L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true,
260+
accepted_kwargs = (:scale,), scale = 1.0)
261+
262+
@test_throws ArgumentError FunctionOperator(
263+
f, u, u; p = zero(p), t = zero(t), batch = true,
260264
accepted_kwargs = (:scale,))
261265

262266
@test size(L) == (N, N)

0 commit comments

Comments
 (0)