Skip to content

Commit e89d5a1

Browse files
committed
function op working
1 parent e63e445 commit e89d5a1

File tree

5 files changed

+46
-36
lines changed

5 files changed

+46
-36
lines changed

src/func.jl

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
Matrix free operators (given by a function)
44
"""
5-
mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} <: AbstractSciMLOperator{T}
5+
mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
66
""" Function with signature op(u, p, t) and (if isinplace) op(du, u, p, t) """
77
op::F
88
""" Adjoint operator"""
@@ -17,8 +17,8 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C}
1717
p::P
1818
""" Time """
1919
t::Tt
20-
""" Keyword arguments """
21-
kwargs::K
20+
""" kwargs """
21+
kwargs::Dict{Symbol,Any} # TODO move inside traits later
2222
""" Cache """
2323
cache::C
2424

@@ -30,7 +30,7 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C}
3030
traits,
3131
p,
3232
t,
33-
accepted_kwargs,
33+
kwargs,
3434
cache
3535
)
3636

@@ -51,7 +51,6 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C}
5151
typeof(traits),
5252
typeof(p),
5353
typeof(t),
54-
typeof(accepted_kwargs),
5554
typeof(cache),
5655
}(
5756
op,
@@ -61,7 +60,7 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C}
6160
traits,
6261
p,
6362
t,
64-
accepted_kwargs,
63+
kwargs,
6564
cache,
6665
)
6766
end
@@ -107,7 +106,7 @@ function FunctionOperator(op,
107106

108107
p=nothing,
109108
t::Union{Number,Nothing}=nothing,
110-
accepted_kwargs = (),
109+
accepted_kwargs::NTuple{N,Symbol} = (),
111110

112111
ifcache::Bool = true,
113112

@@ -118,13 +117,14 @@ function FunctionOperator(op,
118117
issymmetric::Bool = false,
119118
ishermitian::Bool = false,
120119
isposdef::Bool = false,
121-
)
120+
) where{N}
122121

123122
# store eltype of input/output for caching with ComposedOperator.
124123
eltypes = eltype.((input, output))
125124
sz = (size(output, 1), size(input, 1))
126125
T = isnothing(T) ? promote_type(eltypes...) : T
127126
t = isnothing(t) ? zero(real(T)) : t
127+
kwargs = Dict{Symbol, Any}()
128128

129129
isinplace = if isnothing(isinplace)
130130
static_hasmethod(op, typeof((output, input, p, t)))
@@ -188,6 +188,7 @@ function FunctionOperator(op,
188188
T = T,
189189
size = sz,
190190
eltypes = eltypes,
191+
accepted_kwargs = accepted_kwargs,
191192
)
192193

193194
L = FunctionOperator(
@@ -198,7 +199,7 @@ function FunctionOperator(op,
198199
traits,
199200
p,
200201
t,
201-
normalize_kwargs(accepted_kwargs),
202+
kwargs,
202203
cache
203204
)
204205

@@ -214,9 +215,10 @@ function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)
214215
@set! L.p = p
215216
@set! L.t = t
216217

217-
isconstant(L) && return L
218+
filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs)
219+
@set! L.kwargs = Dict(filtered_kwargs)
218220

219-
filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in L.kwargs if haskey(kwargs, kwarg))
221+
isconstant(L) && return L
220222

221223
@set! L.op = update_coefficients(L.op, u, p, t; filtered_kwargs...)
222224
@set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t; filtered_kwargs...)
@@ -226,17 +228,18 @@ end
226228

227229
function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
228230

229-
isconstant(L) && return
231+
L.p = p
232+
L.t = t
233+
234+
filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs)
235+
L.kwargs = Dict(filtered_kwargs)
230236

231-
filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in L.kwargs if haskey(kwargs, kwarg))
237+
isconstant(L) && return
232238

233239
for op in getops(L)
234240
update_coefficients!(op, u, p, t; filtered_kwargs...)
235241
end
236242

237-
L.p = p
238-
L.t = t
239-
240243
L
241244
end
242245

@@ -383,8 +386,13 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin
383386
# TODO - FunctionOperator, Base.conj, transpose
384387

385388
# operator application
386-
Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op(u, L.p, L.t)
387-
Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t; L.kwargs...)
389+
function Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip}
390+
L.op(u, L.p, L.t; L.kwargs...)
391+
end
392+
393+
function Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip}
394+
L.op_inverse(u, L.p, L.t; L.kwargs...)
395+
end
388396

389397
function Base.:*(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
390398
_, co = L.cache

src/utils.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@ struct FilterKwargs{F,K}
1818
f::F
1919
accepted_kwargs::K
2020
end
21-
function (f_filter::FilterKwargs)(args...; kwargs...)
22-
# Filter keyword arguments to those accepted by function.
23-
# Avoid throwing errors here if a keyword argument is not provided: defer this to the function call for a more readable error.
24-
filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in f_filter.accepted_kwargs if haskey(kwargs, kwarg))
25-
f_filter.f(args...; filtered_kwargs...)
21+
22+
# Filter keyword arguments to those accepted by function.
23+
# Avoid throwing errors here if a keyword argument is not provided: defer
24+
# this to the function call for a more readable error.
25+
function get_filtered_kwargs(kwargs::Base.Pairs, accepted_kwargs::NTuple{N,Symbol}) where{N}
26+
(kw => kwargs[kw] for kw in accepted_kwargs if haskey(kwargs, kw))
27+
end
28+
29+
function (f::FilterKwargs)(args...; kwargs...)
30+
filtered_kwargs = get_filtered_kwargs(kwargs, f.accepted_kwargs)
31+
f.f(args...; filtered_kwargs...)
2632
end
27-
# automatically convert NamedTuple's, etc. to a normalized kwargs representation (i.e. Base.Pairs)
28-
normalize_kwargs(; kwargs...) = kwargs
29-
normalize_kwargs(kwargs) = normalize_kwargs(; kwargs...)
3033
#

test/func.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,11 @@ end
107107
scale = rand()
108108

109109
# Accept a kwarg "scale" in operator action
110-
f(du,u,p,t; scale) = mul!(du, Diagonal(p*t*scale), u)
111-
f(u, p, t; scale) = Diagonal(p * t * scale) * u
110+
f(du,u,p,t; scale = 1.0) = mul!(du, Diagonal(p*t*scale), u)
111+
f(u, p, t; scale = 1.0) = Diagonal(p * t * scale) * u
112112

113-
L = FunctionOperator(f, u, u; p=zero(p), t=zero(t), accepted_kwargs=(;scale=zero(scale)))
113+
L = FunctionOperator(f, u, u; p=zero(p), t=zero(t),
114+
accepted_kwargs = (:scale,))
114115

115116
ans = @. u * p * t * scale
116117
@test L(u,p,t; scale) ans

test/scalar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ end
151151
# Test scalar operator which expects keyword argument to update,
152152
# modeled in the style of a DiffEq W-operator.
153153
γ = ScalarOperator(0.0; update_func = (args...; dtgamma) -> dtgamma,
154-
accepted_kwargs=(:dtgamma,))
154+
accepted_kwargs = (:dtgamma,))
155155

156156
dtgamma = rand()
157157
@test γ(u,p,t; dtgamma) dtgamma * u

test/total.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ end
8686
C = rand(N,N)
8787
# Introduce update function for D dependent on kwarg "matrix"
8888
D = MatrixOperator(zeros(N,N); update_func=(A, u, p, t; matrix) -> (A .= p*t*matrix),
89-
accepted_kwargs=(:matrix,))
89+
accepted_kwargs = (:matrix,))
9090

9191
u = rand(N2,K)
9292
p = rand()
@@ -99,11 +99,9 @@ end
9999
T1 = (A, B)
100100
T2 = (C, D)
101101

102-
# Introduce update function for D1
103-
D1 = DiagonalOperator(zeros(N2); update_func=(d, u, p, t) -> (d .= p))
104-
# Introduce update funcion for D2 dependent on kwarg "diag"
105-
D2 = DiagonalOperator(zeros(N2); update_func=(d, u, p, t; diag) -> (d .= p*t*diag),
106-
accepted_kwargs=(:diag,))
102+
D1 = DiagonalOperator(zeros(N2); update_func = (d, u, p, t) -> p)
103+
D2 = DiagonalOperator(zeros(N2); update_func = (d, u, p, t; diag) -> p*t*diag,
104+
accepted_kwargs = (:diag,))
107105

108106
TT = [T1, T2]
109107
DD = Diagonal([D1, D2])

0 commit comments

Comments
 (0)