diff --git a/src/basic.jl b/src/basic.jl index 1257588a..666cbe94 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -44,12 +44,13 @@ for op in (:*, :\) end end -function LinearAlgebra.mul!(v::AbstractVecOrMat, ii::IdentityOperator, u::AbstractVecOrMat) +@inline function LinearAlgebra.mul!( + v::AbstractVecOrMat, ii::IdentityOperator, u::AbstractVecOrMat) @assert size(u, 1) == ii.len copy!(v, u) end -function LinearAlgebra.mul!(v::AbstractVecOrMat, +@inline function LinearAlgebra.mul!(v::AbstractVecOrMat, ii::IdentityOperator, u::AbstractVecOrMat, α, diff --git a/src/func.jl b/src/func.jl index e1b2cadc..0bc0da97 100644 --- a/src/func.jl +++ b/src/func.jl @@ -183,6 +183,7 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`. * `issymmetric` - `true` if the operator is linear and symmetric. Defaults to `false`. * `ishermitian` - `true` if the operator is linear and hermitian. Defaults to `false`. * `isposdef` - `true` if the operator is linear and positive-definite. Defaults to `false`. +* `kwargs` - Keyword arguments for cache initialization. If `accepted_kwargs` is provided, the corresponding keyword arguments must be passed. """ function FunctionOperator(op, input::AbstractArray, @@ -190,7 +191,7 @@ function FunctionOperator(op, op_inverse = nothing, op_adjoint_inverse = nothing, p = nothing, t::Union{Number, Nothing} = nothing, - accepted_kwargs::NTuple{N, Symbol} = (), + accepted_kwargs::Union{Nothing, Val, NTuple{N, Symbol}} = nothing, # traits T::Union{Type{<:Number}, Nothing} = nothing, @@ -207,7 +208,8 @@ function FunctionOperator(op, opnorm = nothing, issymmetric::Union{Bool, Val} = Val(false), ishermitian::Union{Bool, Val} = Val(false), - isposdef::Bool = false) where {N} + isposdef::Bool = false, + kwargs...) where {N} # establish types @@ -316,12 +318,21 @@ function FunctionOperator(op, _op_adjoint_inverse = op_adjoint_inverse end + if accepted_kwargs === nothing + accepted_kwargs = Val(()) + kwargs = NamedTuple() + else + length(kwargs) != 0 || + throw(ArgumentError("No keyword arguments provided. When `accepted_kwargs` is provided, the corresponding keyword arguments must be passed for cache initialization.")) + kwargs = get_filtered_kwargs(kwargs, accepted_kwargs) + end + traits = (; islinear, isconvertible, isconstant, opnorm, issymmetric = _unwrap_val(issymmetric), ishermitian = _unwrap_val(ishermitian), isposdef, isinplace = _unwrap_val(_isinplace), outofplace = _unwrap_val(_outofplace), has_mul5 = _unwrap_val(_has_mul5), ifcache = _unwrap_val(ifcache), T = _T, batch, size = _size, sizes, - accepted_kwargs, kwargs = Dict{Symbol, Any}()) + accepted_kwargs, kwargs = kwargs) L = FunctionOperator{_unwrap_val(_isinplace), _unwrap_val(_outofplace), _unwrap_val(_has_mul5), _T, typeof(op), typeof(_op_adjoint), typeof(op_inverse), @@ -355,7 +366,7 @@ function update_coefficients(L::FunctionOperator, u, p, t; kwargs...) # filter and update kwargs filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs) - L = set_traits(L, merge(L.traits, (; kwargs = Dict{Symbol, Any}(filtered_kwargs)))) + L = set_traits(L, merge(L.traits, (; kwargs = filtered_kwargs))) isconstant(L) && return L @@ -374,7 +385,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...) # filter and update kwargs filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs) - L.traits = (; L.traits..., kwargs = Dict{Symbol, Any}(filtered_kwargs)) + L.traits = merge(L.traits, (; kwargs = filtered_kwargs)) isconstant(L) && return diff --git a/src/utils.jl b/src/utils.jl index 1297f64d..67b1440d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -42,6 +42,11 @@ function get_filtered_kwargs(kwargs::AbstractDict, accepted_kwargs::NTuple{N, Symbol}) where {N} (kw => kwargs[kw] for kw in accepted_kwargs if haskey(kwargs, kw)) end +function get_filtered_kwargs(kwargs::Union{AbstractDict, NamedTuple}, + ::Val{accepted_kwargs}) where {accepted_kwargs} + kwargs_nt = NamedTuple(kwargs) + return NamedTuple{accepted_kwargs}(kwargs_nt) # This creates a new NamedTuple with keys specified by `accepted_kwargs` +end function (f::FilterKwargs)(args...; kwargs...) filtered_kwargs = get_filtered_kwargs(kwargs, f.accepted_kwargs) diff --git a/test/func.jl b/test/func.jl index 0f28aeaf..cdc1a52c 100644 --- a/test/func.jl +++ b/test/func.jl @@ -256,8 +256,66 @@ end f(du, u, p, t; scale = 1.0) = mul!(du, Diagonal(p * t * scale), u) f(u, p, t; scale = 1.0) = Diagonal(p * t * scale) * u + for acc_kw in ((:scale,), Val((:scale,))) + L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true, + accepted_kwargs = acc_kw, scale = 1.0) + + @test_throws ArgumentError FunctionOperator( + f, u, u; p = zero(p), t = zero(t), batch = true, + accepted_kwargs = acc_kw) + + @test size(L) == (N, N) + + ans = @. u * p * t * scale + @test L(u, p, t; scale) ≈ ans + v = copy(u) + @test L(v, u, p, t; scale) ≈ ans + + # test that output isn't accidentally mutated by passing an internal cache. + + A = Diagonal(p * t * scale) + u1 = rand(N, K) + u2 = rand(N, K) + + v1 = L * u1 + @test v1 ≈ A * u1 + v2 = L * u2 + @test v2 ≈ A * u2 + @test v1 ≈ A * u1 + @test v1 + v2 ≈ A * (u1 + u2) + + v1 .= 0.0 + v2 .= 0.0 + + mul!(v1, L, u1) + @test v1 ≈ A * u1 + mul!(v2, L, u2) + @test v2 ≈ A * u2 + @test v1 ≈ A * u1 + @test v1 + v2 ≈ A * (u1 + u2) + + v1 = rand(N, K) + w1 = copy(v1) + v2 = rand(N, K) + w2 = copy(v2) + a1, a2, b1, b2 = rand(4) + + mul!(v1, L, u1, a1, b1) + @test v1 ≈ a1 * A * u1 + b1 * w1 + mul!(v2, L, u2, a2, b2) + @test v2 ≈ a2 * A * u2 + b2 * w2 + @test v1 ≈ a1 * A * u1 + b1 * w1 + @test v1 + v2 ≈ (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2) + end + + ## Do the same with Val((:scale,)) + L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true, - accepted_kwargs = (:scale,)) + accepted_kwargs = Val((:scale,)), scale = 1.0) + + @test_throws ArgumentError FunctionOperator( + f, u, u; p = zero(p), t = zero(t), batch = true, + accepted_kwargs = Val((:scale,))) @test size(L) == (N, N)