Skip to content

Commit a023f7e

Browse files
committed
Add support for initialize and finalize callbacks
1 parent f4d3974 commit a023f7e

File tree

1 file changed

+138
-39
lines changed

1 file changed

+138
-39
lines changed

src/systems/callbacks.jl

Lines changed: 138 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ initialization.
114114
"""
115115
struct SymbolicContinuousCallback
116116
eqs::Vector{Equation}
117+
initialize::Union{Vector{Equation}, FunctionalAffect}
118+
finalize::Union{Vector{Equation}, FunctionalAffect}
117119
affect::Union{Vector{Equation}, FunctionalAffect}
118120
affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing}
119121
rootfind::SciMLBase.RootfindOpt
@@ -122,9 +124,12 @@ struct SymbolicContinuousCallback
122124
eqs::Vector{Equation},
123125
affect = NULL_AFFECT,
124126
affect_neg = affect,
127+
initialize = NULL_AFFECT,
128+
finalize = NULL_AFFECT,
125129
rootfind = SciMLBase.LeftRootFind,
126130
reinitializealg = SciMLBase.CheckInit())
127-
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg)
131+
new(eqs, initialize, finalize, make_affect(affect),
132+
make_affect(affect_neg), rootfind, reinitializealg)
128133
end # Default affect to nothing
129134
end
130135
make_affect(affect) = affect
@@ -133,17 +138,80 @@ make_affect(affect::NamedTuple) = FunctionalAffect(; affect...)
133138

134139
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
135140
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) &&
141+
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) &&
136142
isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)
137143
end
138144
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
139145
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
146+
hash_affect(affect::AbstractVector, s) = foldr(hash, affect, init = s)
147+
hash_affect(affect, s) = hash(affect, s)
140148
s = foldr(hash, cb.eqs, init = s)
141-
s = cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
142-
s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) :
143-
hash(cb.affect_neg, s)
149+
s = hash_affect(cb.affect, s)
150+
s = hash_affect(cb.affect_neg, s)
151+
s = hash_affect(cb.initialize, s)
152+
s = hash_affect(cb.finalize, s)
153+
s = hash(cb.reinitializealg, s)
144154
hash(cb.rootfind, s)
145155
end
146156

157+
function Base.show(io::IO, cb::SymbolicContinuousCallback)
158+
indent = get(io, :indent, 0)
159+
iio = IOContext(io, :indent => indent + 1)
160+
print(io, "SymbolicContinuousCallback(")
161+
print(iio, "Equations:")
162+
show(iio, equations(cb))
163+
print(iio, "; ")
164+
if affects(cb) != NULL_AFFECT
165+
print(iio, "Affect:")
166+
show(iio, affects(cb))
167+
print(iio, ", ")
168+
end
169+
if affect_negs(cb) != NULL_AFFECT
170+
print(iio, "Negative-edge affect:")
171+
show(iio, affect_negs(cb))
172+
print(iio, ", ")
173+
end
174+
if initialize_affects(cb) != NULL_AFFECT
175+
print(iio, "Initialization affect:")
176+
show(iio, initialize_affects(cb))
177+
print(iio, ", ")
178+
end
179+
if finalize_affects(cb) != NULL_AFFECT
180+
print(iio, "Finalization affect:")
181+
show(iio, finalize_affects(cb))
182+
end
183+
print(iio, ")")
184+
end
185+
186+
function Base.show(io::IO, mime::MIME"text/plain", cb::SymbolicContinuousCallback)
187+
indent = get(io, :indent, 0)
188+
iio = IOContext(io, :indent => indent + 1)
189+
println(io, "SymbolicContinuousCallback:")
190+
println(iio, "Equations:")
191+
show(iio, mime, equations(cb))
192+
print(iio, "\n")
193+
if affects(cb) != NULL_AFFECT
194+
println(iio, "Affect:")
195+
show(iio, mime, affects(cb))
196+
print(iio, "\n")
197+
end
198+
if affect_negs(cb) != NULL_AFFECT
199+
println(iio, "Negative-edge affect:")
200+
show(iio, mime, affect_negs(cb))
201+
print(iio, "\n")
202+
end
203+
if initialize_affects(cb) != NULL_AFFECT
204+
println(iio, "Initialization affect:")
205+
show(iio, mime, initialize_affects(cb))
206+
print(iio, "\n")
207+
end
208+
if finalize_affects(cb) != NULL_AFFECT
209+
println(iio, "Finalization affect:")
210+
show(iio, mime, finalize_affects(cb))
211+
print(iio, "\n")
212+
end
213+
end
214+
147215
to_equation_vector(eq::Equation) = [eq]
148216
to_equation_vector(eqs::Vector{Equation}) = eqs
149217
function to_equation_vector(eqs::Vector{Any})
@@ -156,15 +224,18 @@ function SymbolicContinuousCallback(args...)
156224
end # wrap eq in vector
157225
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
158226
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
159-
function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT;
160-
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
227+
function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT;
228+
initialize=NULL_AFFECT, finalize=NULL_AFFECT,
229+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
161230
SymbolicContinuousCallback(
162-
eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind)
231+
eqs = [eqs], affect = affect, affect_neg = affect_neg,
232+
initialize=initialize, finalize=finalize, rootfind = rootfind)
163233
end
164234
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT;
165-
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
235+
affect_neg = affect, initialize=NULL_AFFECT, finalize=NULL_AFFECT,
236+
rootfind = SciMLBase.LeftRootFind)
166237
SymbolicContinuousCallback(
167-
eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind)
238+
eqs = eqs, affect = affect, affect_neg = affect_neg, initialize=initialize, finalize=finalize, rootfind = rootfind)
168239
end
169240

170241
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
@@ -199,15 +270,28 @@ function reinitialization_algs(cbs::Vector{SymbolicContinuousCallback})
199270
reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
200271
end
201272

273+
initialize_affects(cb::SymbolicContinuousCallback) = cb.initialize
274+
function initialize_affects(cbs::Vector{SymbolicContinuousCallback})
275+
mapreduce(initialize_affects, vcat, cbs, init = Equation[])
276+
end
277+
278+
finalize_affects(cb::SymbolicContinuousCallback) = cb.initialize
279+
function finalize_affects(cbs::Vector{SymbolicContinuousCallback})
280+
mapreduce(finalize_affects, vcat, cbs, init = Equation[])
281+
end
282+
202283
namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
203284
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
204285
namespace_affects(::Nothing, s) = nothing
205286

206287
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
207-
SymbolicContinuousCallback(
208-
namespace_equation.(equations(cb), (s,)),
209-
namespace_affects(affects(cb), s);
210-
affect_neg = namespace_affects(affect_negs(cb), s))
288+
SymbolicContinuousCallback(;
289+
eqs = namespace_equation.(equations(cb), (s,)),
290+
affect = namespace_affects(affects(cb), s),
291+
affect_neg = namespace_affects(affect_negs(cb), s),
292+
initialize = namespace_affects(initialize_affects(cb), s),
293+
finalize = namespace_affects(finalize_affects(cb), s),
294+
rootfind = cb.rootfind)
211295
end
212296

213297
"""
@@ -589,22 +673,25 @@ function generate_single_rootfinding_callback(
589673
rf_oop(u, parameter_values(integ), t)
590674
end
591675
end
592-
676+
user_initfun = (affect_function.initialize == NULL_AFFECT) ? SciMLBase.INITIALIZE_DEFAULT :
677+
(c, u, t, i) -> affect_function.initialize(i)
593678
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
594679
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
595680
initfn = let save_idxs = save_idxs
596681
function (cb, u, t, integrator)
682+
user_initfun(cb, u, t, integrator)
597683
for idx in save_idxs
598684
SciMLBase.save_discretes!(integrator, idx)
599685
end
600686
end
601687
end
602688
else
603-
initfn = SciMLBase.INITIALIZE_DEFAULT
689+
initfn = user_initfun
604690
end
605691
return ContinuousCallback(
606692
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind,
607693
initialize = initfn,
694+
finalize = (affect_function.finalize == NULL_AFFECT) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i),
608695
initializealg = reinitialization_alg(cb))
609696
end
610697

@@ -626,13 +713,14 @@ function generate_vector_rootfinding_callback(
626713
_, rf_ip = generate_custom_function(
627714
sys, rhss, dvs, ps; expression = Val{false}, kwargs...)
628715

629-
affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn(
630-
cb,
631-
sys,
632-
dvs,
633-
ps,
634-
kwargs)
635-
for cb in cbs]
716+
affect_functions = @NamedTuple{
717+
affect::Function,
718+
affect_neg::Union{Function, Nothing},
719+
initialize::Union{Function, Nothing},
720+
finalize::Union{Function, Nothing}}[
721+
compile_affect_fn(cb, sys, dvs, ps, kwargs)
722+
for cb in cbs]
723+
636724
cond = function (out, u, t, integ)
637725
rf_ip(out, u, parameter_values(integ), t)
638726
end
@@ -658,26 +746,31 @@ function generate_vector_rootfinding_callback(
658746
affect_neg(integ)
659747
end
660748
end
661-
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
662-
save_idxs = mapreduce(
663-
cb -> get(ic.callback_to_clocks, cb, Int[]), vcat, cbs; init = Int[])
664-
initfn = if isempty(save_idxs)
665-
SciMLBase.INITIALIZE_DEFAULT
749+
function handle_optional_setup_fn(funs, default)
750+
if all(isnothing, funs)
751+
return default
666752
else
667-
let save_idxs = save_idxs
668-
function (cb, u, t, integrator)
669-
for idx in save_idxs
670-
SciMLBase.save_discretes!(integrator, idx)
753+
return let funs = funs
754+
function (cb, u, t, integ)
755+
for func in funs
756+
if isnothing(func)
757+
continue
758+
else
759+
func(integ)
760+
end
671761
end
672762
end
673763
end
674764
end
675-
else
676-
initfn = SciMLBase.INITIALIZE_DEFAULT
677765
end
766+
767+
initialize = handle_optional_setup_fn(
768+
map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT)
769+
finalize = handle_optional_setup_fn(
770+
map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT)
678771
return VectorContinuousCallback(
679772
cond, affect, affect_neg, length(eqs), rootfind = rootfind,
680-
initialize = initfn, initializealg = reinitialization)
773+
initialize = initialize, finalize = finalize, initializealg = reinitialization)
681774
end
682775

683776
"""
@@ -687,15 +780,21 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
687780
eq_aff = affects(cb)
688781
eq_neg_aff = affect_negs(cb)
689782
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
783+
function compile_optional_affect(aff, default = nothing)
784+
if isnothing(aff) || aff == default
785+
return nothing
786+
else
787+
return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
788+
end
789+
end
690790
if eq_neg_aff === eq_aff
691791
affect_neg = affect
692-
elseif isnothing(eq_neg_aff)
693-
affect_neg = nothing
694792
else
695-
affect_neg = compile_affect(
696-
eq_neg_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
793+
affect_neg = compile_optional_affect(eq_neg_aff)
697794
end
698-
(affect = affect, affect_neg = affect_neg)
795+
initialize = compile_optional_affect(initialize_affects(cb), NULL_AFFECT)
796+
finalize = compile_optional_affect(finalize_affects(cb), NULL_AFFECT)
797+
(affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
699798
end
700799

701800
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),

0 commit comments

Comments
 (0)