Skip to content

Commit c613fe3

Browse files
Merge pull request #255 from albertomercurio/patch-1
Remove `Dict{Symbol, Any}` for type stability in `FunctionOperator`
2 parents 9ca6598 + 717d4dd commit c613fe3

File tree

4 files changed

+83
-8
lines changed

4 files changed

+83
-8
lines changed

src/basic.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,13 @@ for op in (:*, :\)
4444
end
4545
end
4646

47-
function LinearAlgebra.mul!(v::AbstractVecOrMat, ii::IdentityOperator, u::AbstractVecOrMat)
47+
@inline function LinearAlgebra.mul!(
48+
v::AbstractVecOrMat, ii::IdentityOperator, u::AbstractVecOrMat)
4849
@assert size(u, 1) == ii.len
4950
copy!(v, u)
5051
end
5152

52-
function LinearAlgebra.mul!(v::AbstractVecOrMat,
53+
@inline function LinearAlgebra.mul!(v::AbstractVecOrMat,
5354
ii::IdentityOperator,
5455
u::AbstractVecOrMat,
5556
α,

src/func.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,15 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`.
183183
* `issymmetric` - `true` if the operator is linear and symmetric. Defaults to `false`.
184184
* `ishermitian` - `true` if the operator is linear and hermitian. Defaults to `false`.
185185
* `isposdef` - `true` if the operator is linear and positive-definite. Defaults to `false`.
186+
* `kwargs` - Keyword arguments for cache initialization. If `accepted_kwargs` is provided, the corresponding keyword arguments must be passed.
186187
"""
187188
function FunctionOperator(op,
188189
input::AbstractArray,
189190
output::AbstractArray = input; op_adjoint = nothing,
190191
op_inverse = nothing,
191192
op_adjoint_inverse = nothing, p = nothing,
192193
t::Union{Number, Nothing} = nothing,
193-
accepted_kwargs::NTuple{N, Symbol} = (),
194+
accepted_kwargs::Union{Nothing, Val, NTuple{N, Symbol}} = nothing,
194195

195196
# traits
196197
T::Union{Type{<:Number}, Nothing} = nothing,
@@ -207,7 +208,8 @@ function FunctionOperator(op,
207208
opnorm = nothing,
208209
issymmetric::Union{Bool, Val} = Val(false),
209210
ishermitian::Union{Bool, Val} = Val(false),
210-
isposdef::Bool = false) where {N}
211+
isposdef::Bool = false,
212+
kwargs...) where {N}
211213

212214
# establish types
213215

@@ -316,12 +318,21 @@ function FunctionOperator(op,
316318
_op_adjoint_inverse = op_adjoint_inverse
317319
end
318320

321+
if accepted_kwargs === nothing
322+
accepted_kwargs = Val(())
323+
kwargs = NamedTuple()
324+
else
325+
length(kwargs) != 0 ||
326+
throw(ArgumentError("No keyword arguments provided. When `accepted_kwargs` is provided, the corresponding keyword arguments must be passed for cache initialization."))
327+
kwargs = get_filtered_kwargs(kwargs, accepted_kwargs)
328+
end
329+
319330
traits = (; islinear, isconvertible, isconstant, opnorm,
320331
issymmetric = _unwrap_val(issymmetric), ishermitian = _unwrap_val(ishermitian),
321332
isposdef, isinplace = _unwrap_val(_isinplace),
322333
outofplace = _unwrap_val(_outofplace), has_mul5 = _unwrap_val(_has_mul5),
323334
ifcache = _unwrap_val(ifcache), T = _T, batch, size = _size, sizes,
324-
accepted_kwargs, kwargs = Dict{Symbol, Any}())
335+
accepted_kwargs, kwargs = kwargs)
325336

326337
L = FunctionOperator{_unwrap_val(_isinplace), _unwrap_val(_outofplace),
327338
_unwrap_val(_has_mul5), _T, typeof(op), typeof(_op_adjoint), typeof(op_inverse),
@@ -355,7 +366,7 @@ function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
355366
# filter and update kwargs
356367
filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs)
357368

358-
L = set_traits(L, merge(L.traits, (; kwargs = Dict{Symbol, Any}(filtered_kwargs))))
369+
L = set_traits(L, merge(L.traits, (; kwargs = filtered_kwargs)))
359370

360371
isconstant(L) && return L
361372

@@ -374,7 +385,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
374385

375386
# filter and update kwargs
376387
filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs)
377-
L.traits = (; L.traits..., kwargs = Dict{Symbol, Any}(filtered_kwargs))
388+
L.traits = merge(L.traits, (; kwargs = filtered_kwargs))
378389

379390
isconstant(L) && return
380391

src/utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ function get_filtered_kwargs(kwargs::AbstractDict,
4242
accepted_kwargs::NTuple{N, Symbol}) where {N}
4343
(kw => kwargs[kw] for kw in accepted_kwargs if haskey(kwargs, kw))
4444
end
45+
function get_filtered_kwargs(kwargs::Union{AbstractDict, NamedTuple},
46+
::Val{accepted_kwargs}) where {accepted_kwargs}
47+
kwargs_nt = NamedTuple(kwargs)
48+
return NamedTuple{accepted_kwargs}(kwargs_nt) # This creates a new NamedTuple with keys specified by `accepted_kwargs`
49+
end
4550

4651
function (f::FilterKwargs)(args...; kwargs...)
4752
filtered_kwargs = get_filtered_kwargs(kwargs, f.accepted_kwargs)

test/func.jl

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,66 @@ end
256256
f(du, u, p, t; scale = 1.0) = mul!(du, Diagonal(p * t * scale), u)
257257
f(u, p, t; scale = 1.0) = Diagonal(p * t * scale) * u
258258

259+
for acc_kw in ((:scale,), Val((:scale,)))
260+
L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true,
261+
accepted_kwargs = acc_kw, scale = 1.0)
262+
263+
@test_throws ArgumentError FunctionOperator(
264+
f, u, u; p = zero(p), t = zero(t), batch = true,
265+
accepted_kwargs = acc_kw)
266+
267+
@test size(L) == (N, N)
268+
269+
ans = @. u * p * t * scale
270+
@test L(u, p, t; scale) ans
271+
v = copy(u)
272+
@test L(v, u, p, t; scale) ans
273+
274+
# test that output isn't accidentally mutated by passing an internal cache.
275+
276+
A = Diagonal(p * t * scale)
277+
u1 = rand(N, K)
278+
u2 = rand(N, K)
279+
280+
v1 = L * u1
281+
@test v1 A * u1
282+
v2 = L * u2
283+
@test v2 A * u2
284+
@test v1 A * u1
285+
@test v1 + v2 A * (u1 + u2)
286+
287+
v1 .= 0.0
288+
v2 .= 0.0
289+
290+
mul!(v1, L, u1)
291+
@test v1 A * u1
292+
mul!(v2, L, u2)
293+
@test v2 A * u2
294+
@test v1 A * u1
295+
@test v1 + v2 A * (u1 + u2)
296+
297+
v1 = rand(N, K)
298+
w1 = copy(v1)
299+
v2 = rand(N, K)
300+
w2 = copy(v2)
301+
a1, a2, b1, b2 = rand(4)
302+
303+
mul!(v1, L, u1, a1, b1)
304+
@test v1 a1 * A * u1 + b1 * w1
305+
mul!(v2, L, u2, a2, b2)
306+
@test v2 a2 * A * u2 + b2 * w2
307+
@test v1 a1 * A * u1 + b1 * w1
308+
@test v1 + v2 (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2)
309+
end
310+
311+
## Do the same with Val((:scale,))
312+
259313
L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true,
260-
accepted_kwargs = (:scale,))
314+
accepted_kwargs = Val((:scale,)), scale = 1.0)
315+
316+
@test_throws ArgumentError FunctionOperator(
317+
f, u, u; p = zero(p), t = zero(t), batch = true,
318+
accepted_kwargs = Val((:scale,)))
261319

262320
@test size(L) == (N, N)
263321

0 commit comments

Comments
 (0)