Skip to content

Commit 8edef14

Browse files
committed
Implement initialize and finalize affects for symbolic callbacks
1 parent 61cf676 commit 8edef14

File tree

1 file changed

+154
-52
lines changed

1 file changed

+154
-52
lines changed

src/systems/callbacks.jl

Lines changed: 154 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -104,28 +104,38 @@ The affect function updates the value at `x` in `modified` to be the result of e
104104
modified::Vector
105105
mod_syms::Vector{Symbol}
106106
ctx::Any
107+
skip_checks::Bool
107108
end
108109

109110
function MutatingFunctionalAffect(f::Function;
110111
observed::NamedTuple = NamedTuple{()}(()),
111112
modified::NamedTuple = NamedTuple{()}(()),
112-
ctx = nothing)
113-
MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)),
114-
collect(values(modified)), collect(keys(modified)), ctx)
113+
ctx = nothing,
114+
skip_checks = false)
115+
MutatingFunctionalAffect(f,
116+
collect(values(observed)), collect(keys(observed)),
117+
collect(values(modified)), collect(keys(modified)),
118+
ctx, skip_checks)
115119
end
116120
function MutatingFunctionalAffect(f::Function, modified::NamedTuple;
117-
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing)
118-
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx)
121+
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks=false)
122+
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
119123
end
120124
function MutatingFunctionalAffect(
121-
f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing)
122-
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx)
125+
f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks=false)
126+
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
123127
end
124128
function MutatingFunctionalAffect(
125-
f::Function, modified::NamedTuple, observed::NamedTuple, ctx)
126-
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx)
129+
f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks=false)
130+
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
127131
end
128132

133+
function Base.show(io::IO, mfa::MutatingFunctionalAffect)
134+
obs_vals = join(map((ob,nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
135+
mod_vals = join(map((md,nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")
136+
affect = mfa.f
137+
print(io, "MutatingFunctionalAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)")
138+
end
129139
func(f::MutatingFunctionalAffect) = f.f
130140
context(a::MutatingFunctionalAffect) = a.ctx
131141
observed(a::MutatingFunctionalAffect) = a.obs
@@ -208,31 +218,101 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
208218
"""
209219
struct SymbolicContinuousCallback
210220
eqs::Vector{Equation}
221+
initialize::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
222+
finalize::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
211223
affect::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
212224
affect_neg::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing}
213225
rootfind::SciMLBase.RootfindOpt
214-
function SymbolicContinuousCallback(; eqs::Vector{Equation}, affect = NULL_AFFECT,
215-
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
216-
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind)
226+
function SymbolicContinuousCallback(;
227+
eqs::Vector{Equation},
228+
affect = NULL_AFFECT,
229+
affect_neg = affect,
230+
rootfind = SciMLBase.LeftRootFind,
231+
initialize=NULL_AFFECT,
232+
finalize=NULL_AFFECT)
233+
new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind)
217234
end # Default affect to nothing
218235
end
219236
make_affect(affect) = affect
220237
make_affect(affect::Tuple) = FunctionalAffect(affect...)
221238
make_affect(affect::NamedTuple) = FunctionalAffect(; affect...)
222239

223240
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
224-
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) &&
241+
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) &&
242+
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) &&
225243
isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)
226244
end
227245
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
228246
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
247+
hash_affect(affect::AbstractVector, s) = foldr(hash, affect, init = s)
248+
hash_affect(affect, s) = hash(cb.affect, s)
229249
s = foldr(hash, cb.eqs, init = s)
230-
s = cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
231-
s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) :
232-
hash(cb.affect_neg, s)
250+
s = hash_affect(cb.affect, s)
251+
s = hash_affect(cb.affect_neg, s)
252+
s = hash_affect(cb.initialize, s)
253+
s = hash_affect(cb.finalize, s)
233254
hash(cb.rootfind, s)
234255
end
235256

257+
258+
function Base.show(io::IO, cb::SymbolicContinuousCallback)
259+
indent = get(io, :indent, 0)
260+
iio = IOContext(io, :indent => indent+1)
261+
print(io, "SymbolicContinuousCallback(")
262+
print(iio, "Equations:")
263+
show(iio, equations(cb))
264+
print(iio, "; ")
265+
if affects(cb) != NULL_AFFECT
266+
print(iio, "Affect:")
267+
show(iio, affects(cb))
268+
print(iio, ", ")
269+
end
270+
if affect_negs(cb) != NULL_AFFECT
271+
print(iio, "Negative-edge affect:")
272+
show(iio, affect_negs(cb))
273+
print(iio, ", ")
274+
end
275+
if initialize_affects(cb) != NULL_AFFECT
276+
print(iio, "Initialization affect:")
277+
show(iio, initialize_affects(cb))
278+
print(iio, ", ")
279+
end
280+
if finalize_affects(cb) != NULL_AFFECT
281+
print(iio, "Finalization affect:")
282+
show(iio, finalize_affects(cb))
283+
end
284+
print(iio, ")")
285+
end
286+
287+
function Base.show(io::IO, mime::MIME"text/plain", cb::SymbolicContinuousCallback)
288+
indent = get(io, :indent, 0)
289+
iio = IOContext(io, :indent => indent+1)
290+
println(io, "SymbolicContinuousCallback:")
291+
println(iio, "Equations:")
292+
show(iio, mime, equations(cb))
293+
print(iio, "\n")
294+
if affects(cb) != NULL_AFFECT
295+
println(iio, "Affect:")
296+
show(iio, mime, affects(cb))
297+
print(iio, "\n")
298+
end
299+
if affect_negs(cb) != NULL_AFFECT
300+
println(iio, "Negative-edge affect:")
301+
show(iio, mime, affect_negs(cb))
302+
print(iio, "\n")
303+
end
304+
if initialize_affects(cb) != NULL_AFFECT
305+
println(iio, "Initialization affect:")
306+
show(iio, mime, initialize_affects(cb))
307+
print(iio, "\n")
308+
end
309+
if finalize_affects(cb) != NULL_AFFECT
310+
println(iio, "Finalization affect:")
311+
show(iio, mime, finalize_affects(cb))
312+
print(iio, "\n")
313+
end
314+
end
315+
236316
to_equation_vector(eq::Equation) = [eq]
237317
to_equation_vector(eqs::Vector{Equation}) = eqs
238318
function to_equation_vector(eqs::Vector{Any})
@@ -246,14 +326,14 @@ end # wrap eq in vector
246326
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
247327
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
248328
function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT;
249-
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
329+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT)
250330
SymbolicContinuousCallback(
251-
eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind)
331+
eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize=initialize, finalize=finalize)
252332
end
253333
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT;
254-
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
334+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT)
255335
SymbolicContinuousCallback(
256-
eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind)
336+
eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize=initialize, finalize=finalize)
257337
end
258338

259339
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
@@ -282,6 +362,16 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback})
282362
mapreduce(affect_negs, vcat, cbs, init = Equation[])
283363
end
284364

365+
initialize_affects(cb::SymbolicContinuousCallback) = cb.initialize
366+
function initialize_affects(cbs::Vector{SymbolicContinuousCallback})
367+
mapreduce(initialize_affects, vcat, cbs, init = Equation[])
368+
end
369+
370+
finalize_affects(cb::SymbolicContinuousCallback) = cb.initialize
371+
function finalize_affects(cbs::Vector{SymbolicContinuousCallback})
372+
mapreduce(finalize_affects, vcat, cbs, init = Equation[])
373+
end
374+
285375
namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
286376
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
287377
namespace_affects(af::MutatingFunctionalAffect, s) = namespace_affect(af, s)
@@ -292,6 +382,8 @@ function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuo
292382
eqs = namespace_equation.(equations(cb), (s,)),
293383
affect = namespace_affects(affects(cb), s),
294384
affect_neg = namespace_affects(affect_negs(cb), s),
385+
initialize = namespace_affects(initialize_affects(cb), s),
386+
finalize = namespace_affects(finalize_affects(cb), s),
295387
rootfind = cb.rootfind)
296388
end
297389

@@ -681,8 +773,9 @@ function generate_single_rootfinding_callback(
681773
initfn = SciMLBase.INITIALIZE_DEFAULT
682774
end
683775
return ContinuousCallback(
684-
cond, affect_function.affect, affect_function.affect_neg,
685-
rootfind = cb.rootfind, initialize = initfn)
776+
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind,
777+
initialize = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i),
778+
finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i))
686779
end
687780

688781
function generate_vector_rootfinding_callback(
@@ -702,13 +795,12 @@ function generate_vector_rootfinding_callback(
702795
_, rf_ip = generate_custom_function(
703796
sys, rhss, dvs, ps; expression = Val{false}, kwargs...)
704797

705-
affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn(
706-
cb,
707-
sys,
708-
dvs,
709-
ps,
710-
kwargs)
711-
for cb in cbs]
798+
affect_functions = @NamedTuple{
799+
affect::Function,
800+
affect_neg::Union{Function, Nothing},
801+
initialize::Union{Function, Nothing},
802+
finalize::Union{Function, Nothing}}[
803+
compile_affect_fn(cb, sys, dvs, ps, kwargs) for cb in cbs]
712804
cond = function (out, u, t, integ)
713805
rf_ip(out, u, parameter_values(integ), t)
714806
end
@@ -734,25 +826,27 @@ function generate_vector_rootfinding_callback(
734826
affect_neg(integ)
735827
end
736828
end
737-
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
738-
save_idxs = mapreduce(
739-
cb -> get(ic.callback_to_clocks, cb, Int[]), vcat, cbs; init = Int[])
740-
initfn = if isempty(save_idxs)
741-
SciMLBase.INITIALIZE_DEFAULT
829+
function handle_optional_setup_fn(funs, default)
830+
if all(isnothing, funs)
831+
return default
742832
else
743-
let save_idxs = save_idxs
744-
function (cb, u, t, integrator)
745-
for idx in save_idxs
746-
SciMLBase.save_discretes!(integrator, idx)
833+
return let funs = funs
834+
function (cb, u, t, integ)
835+
for func in funs
836+
if isnothing(func)
837+
continue
838+
else
839+
func(integ)
840+
end
747841
end
748842
end
749843
end
750844
end
751-
else
752-
initfn = SciMLBase.INITIALIZE_DEFAULT
753845
end
846+
initialize = handle_optional_setup_fn(map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT)
847+
finalize = handle_optional_setup_fn(map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT)
754848
return VectorContinuousCallback(
755-
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initfn)
849+
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, finalize = finalize)
756850
end
757851

758852
"""
@@ -762,15 +856,23 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
762856
eq_aff = affects(cb)
763857
eq_neg_aff = affect_negs(cb)
764858
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
859+
function compile_optional_affect(aff)
860+
if isnothing(aff)
861+
return nothing
862+
else
863+
affspr = compile_affect(aff, cb, sys, dvs, ps; expression = Val{true}, kwargs...)
864+
@show affspr
865+
return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
866+
end
867+
end
765868
if eq_neg_aff === eq_aff
766869
affect_neg = affect
767-
elseif isnothing(eq_neg_aff)
768-
affect_neg = nothing
769870
else
770-
affect_neg = compile_affect(
771-
eq_neg_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
871+
affect_neg = compile_optional_affect(eq_neg_aff)
772872
end
773-
(affect = affect, affect_neg = affect_neg)
873+
initialize = compile_optional_affect(initialize_affects(cb))
874+
finalize = compile_optional_affect(finalize_affects(cb))
875+
(affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
774876
end
775877

776878
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
@@ -877,7 +979,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
877979
push!(syms_dedup, sym)
878980
push!(exprs_dedup, exp)
879981
push!(seen, sym)
880-
else
982+
elseif !affect.skip_checks
881983
@warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used."
882984
end
883985
end
@@ -887,7 +989,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
887989
obs_exprs = observed(affect)
888990
for oexpr in obs_exprs
889991
invalid_vars = invalid_variables(sys, oexpr)
890-
if length(invalid_vars) > 0
992+
if length(invalid_vars) > 0 && !affect.skip_checks
891993
error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
892994
end
893995
end
@@ -897,11 +999,11 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
897999

8981000
mod_exprs = modified(affect)
8991001
for mexpr in mod_exprs
900-
if !is_observed(sys, mexpr) && parameter_index(sys, mexpr) === nothing
901-
error("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.")
1002+
if !is_observed(sys, mexpr) && parameter_index(sys, mexpr) === nothing && !affect.skip_checks
1003+
@warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.")
9021004
end
9031005
invalid_vars = unassignable_variables(sys, mexpr)
904-
if length(invalid_vars) > 0
1006+
if length(invalid_vars) > 0 && !affect.skip_checks
9051007
error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
9061008
end
9071009
end
@@ -911,7 +1013,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
9111013
sys, mod_exprs; return_inplace = true)
9121014

9131015
overlapping_syms = intersect(mod_syms, obs_syms)
914-
if length(overlapping_syms) > 0
1016+
if length(overlapping_syms) > 0 && !affect.skip_checks
9151017
@warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
9161018
end
9171019

0 commit comments

Comments
 (0)