@@ -114,6 +114,8 @@ initialization.
114114"""
115115struct SymbolicContinuousCallback
116116 eqs:: Vector{Equation}
117+ initialize:: Union{Vector{Equation}, FunctionalAffect}
118+ finalize:: Union{Vector{Equation}, FunctionalAffect}
117119 affect:: Union{Vector{Equation}, FunctionalAffect}
118120 affect_neg:: Union{Vector{Equation}, FunctionalAffect, Nothing}
119121 rootfind:: SciMLBase.RootfindOpt
@@ -122,9 +124,12 @@ struct SymbolicContinuousCallback
122124 eqs:: Vector{Equation} ,
123125 affect = NULL_AFFECT,
124126 affect_neg = affect,
127+ initialize = NULL_AFFECT,
128+ finalize = NULL_AFFECT,
125129 rootfind = SciMLBase. LeftRootFind,
126130 reinitializealg = SciMLBase. CheckInit ())
127- new (eqs, make_affect (affect), make_affect (affect_neg), rootfind, reinitializealg)
131+ new (eqs, initialize, finalize, make_affect (affect),
132+ make_affect (affect_neg), rootfind, reinitializealg)
128133 end # Default affect to nothing
129134end
130135make_affect (affect) = affect
@@ -133,17 +138,80 @@ make_affect(affect::NamedTuple) = FunctionalAffect(; affect...)
133138
134139function Base.:(== )(e1:: SymbolicContinuousCallback , e2:: SymbolicContinuousCallback )
135140 isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect) &&
141+ isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize) &&
136142 isequal (e1. affect_neg, e2. affect_neg) && isequal (e1. rootfind, e2. rootfind)
137143end
138144Base. isempty (cb:: SymbolicContinuousCallback ) = isempty (cb. eqs)
139145function Base. hash (cb:: SymbolicContinuousCallback , s:: UInt )
146+ hash_affect (affect:: AbstractVector , s) = foldr (hash, affect, init = s)
147+ hash_affect (affect, s) = hash (affect, s)
140148 s = foldr (hash, cb. eqs, init = s)
141- s = cb. affect isa AbstractVector ? foldr (hash, cb. affect, init = s) : hash (cb. affect, s)
142- s = cb. affect_neg isa AbstractVector ? foldr (hash, cb. affect_neg, init = s) :
143- hash (cb. affect_neg, s)
149+ s = hash_affect (cb. affect, s)
150+ s = hash_affect (cb. affect_neg, s)
151+ s = hash_affect (cb. initialize, s)
152+ s = hash_affect (cb. finalize, s)
153+ s = hash (cb. reinitializealg, s)
144154 hash (cb. rootfind, s)
145155end
146156
157+ function Base. show (io:: IO , cb:: SymbolicContinuousCallback )
158+ indent = get (io, :indent , 0 )
159+ iio = IOContext (io, :indent => indent + 1 )
160+ print (io, " SymbolicContinuousCallback(" )
161+ print (iio, " Equations:" )
162+ show (iio, equations (cb))
163+ print (iio, " ; " )
164+ if affects (cb) != NULL_AFFECT
165+ print (iio, " Affect:" )
166+ show (iio, affects (cb))
167+ print (iio, " , " )
168+ end
169+ if affect_negs (cb) != NULL_AFFECT
170+ print (iio, " Negative-edge affect:" )
171+ show (iio, affect_negs (cb))
172+ print (iio, " , " )
173+ end
174+ if initialize_affects (cb) != NULL_AFFECT
175+ print (iio, " Initialization affect:" )
176+ show (iio, initialize_affects (cb))
177+ print (iio, " , " )
178+ end
179+ if finalize_affects (cb) != NULL_AFFECT
180+ print (iio, " Finalization affect:" )
181+ show (iio, finalize_affects (cb))
182+ end
183+ print (iio, " )" )
184+ end
185+
186+ function Base. show (io:: IO , mime:: MIME"text/plain" , cb:: SymbolicContinuousCallback )
187+ indent = get (io, :indent , 0 )
188+ iio = IOContext (io, :indent => indent + 1 )
189+ println (io, " SymbolicContinuousCallback:" )
190+ println (iio, " Equations:" )
191+ show (iio, mime, equations (cb))
192+ print (iio, " \n " )
193+ if affects (cb) != NULL_AFFECT
194+ println (iio, " Affect:" )
195+ show (iio, mime, affects (cb))
196+ print (iio, " \n " )
197+ end
198+ if affect_negs (cb) != NULL_AFFECT
199+ println (iio, " Negative-edge affect:" )
200+ show (iio, mime, affect_negs (cb))
201+ print (iio, " \n " )
202+ end
203+ if initialize_affects (cb) != NULL_AFFECT
204+ println (iio, " Initialization affect:" )
205+ show (iio, mime, initialize_affects (cb))
206+ print (iio, " \n " )
207+ end
208+ if finalize_affects (cb) != NULL_AFFECT
209+ println (iio, " Finalization affect:" )
210+ show (iio, mime, finalize_affects (cb))
211+ print (iio, " \n " )
212+ end
213+ end
214+
147215to_equation_vector (eq:: Equation ) = [eq]
148216to_equation_vector (eqs:: Vector{Equation} ) = eqs
149217function to_equation_vector (eqs:: Vector{Any} )
@@ -156,15 +224,18 @@ function SymbolicContinuousCallback(args...)
156224end # wrap eq in vector
157225SymbolicContinuousCallback (p:: Pair ) = SymbolicContinuousCallback (p[1 ], p[2 ])
158226SymbolicContinuousCallback (cb:: SymbolicContinuousCallback ) = cb # passthrough
159- function SymbolicContinuousCallback (eqs:: Equation , affect = NULL_AFFECT;
160- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
227+ function SymbolicContinuousCallback (eqs:: Equation , affect = NULL_AFFECT;
228+ initialize= NULL_AFFECT, finalize= NULL_AFFECT,
229+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
161230 SymbolicContinuousCallback (
162- eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind)
231+ eqs = [eqs], affect = affect, affect_neg = affect_neg,
232+ initialize= initialize, finalize= finalize, rootfind = rootfind)
163233end
164234function SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT;
165- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
235+ affect_neg = affect, initialize= NULL_AFFECT, finalize= NULL_AFFECT,
236+ rootfind = SciMLBase. LeftRootFind)
166237 SymbolicContinuousCallback (
167- eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind)
238+ eqs = eqs, affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize, rootfind = rootfind)
168239end
169240
170241SymbolicContinuousCallbacks (cb:: SymbolicContinuousCallback ) = [cb]
@@ -199,15 +270,28 @@ function reinitialization_algs(cbs::Vector{SymbolicContinuousCallback})
199270 reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
200271end
201272
273+ initialize_affects (cb:: SymbolicContinuousCallback ) = cb. initialize
274+ function initialize_affects (cbs:: Vector{SymbolicContinuousCallback} )
275+ mapreduce (initialize_affects, vcat, cbs, init = Equation[])
276+ end
277+
278+ finalize_affects (cb:: SymbolicContinuousCallback ) = cb. initialize
279+ function finalize_affects (cbs:: Vector{SymbolicContinuousCallback} )
280+ mapreduce (finalize_affects, vcat, cbs, init = Equation[])
281+ end
282+
202283namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
203284namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
204285namespace_affects (:: Nothing , s) = nothing
205286
206287function namespace_callback (cb:: SymbolicContinuousCallback , s):: SymbolicContinuousCallback
207- SymbolicContinuousCallback (
208- namespace_equation .(equations (cb), (s,)),
209- namespace_affects (affects (cb), s);
210- affect_neg = namespace_affects (affect_negs (cb), s))
288+ SymbolicContinuousCallback (;
289+ eqs = namespace_equation .(equations (cb), (s,)),
290+ affect = namespace_affects (affects (cb), s),
291+ affect_neg = namespace_affects (affect_negs (cb), s),
292+ initialize = namespace_affects (initialize_affects (cb), s),
293+ finalize = namespace_affects (finalize_affects (cb), s),
294+ rootfind = cb. rootfind)
211295end
212296
213297"""
@@ -589,22 +673,25 @@ function generate_single_rootfinding_callback(
589673 rf_oop (u, parameter_values (integ), t)
590674 end
591675 end
592-
676+ user_initfun = (affect_function. initialize == NULL_AFFECT) ? SciMLBase. INITIALIZE_DEFAULT :
677+ (c, u, t, i) -> affect_function. initialize (i)
593678 if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing &&
594679 (save_idxs = get (ic. callback_to_clocks, cb, nothing )) != = nothing
595680 initfn = let save_idxs = save_idxs
596681 function (cb, u, t, integrator)
682+ user_initfun (cb, u, t, integrator)
597683 for idx in save_idxs
598684 SciMLBase. save_discretes! (integrator, idx)
599685 end
600686 end
601687 end
602688 else
603- initfn = SciMLBase . INITIALIZE_DEFAULT
689+ initfn = user_initfun
604690 end
605691 return ContinuousCallback (
606692 cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
607693 initialize = initfn,
694+ finalize = (affect_function. finalize == NULL_AFFECT) ? SciMLBase. FINALIZE_DEFAULT : (c, u, t, i) -> affect_function. finalize (i),
608695 initializealg = reinitialization_alg (cb))
609696end
610697
@@ -626,13 +713,14 @@ function generate_vector_rootfinding_callback(
626713 _, rf_ip = generate_custom_function (
627714 sys, rhss, dvs, ps; expression = Val{false }, kwargs... )
628715
629- affect_functions = @NamedTuple {affect:: Function , affect_neg:: Union{Function, Nothing} }[compile_affect_fn (
630- cb,
631- sys,
632- dvs,
633- ps,
634- kwargs)
635- for cb in cbs]
716+ affect_functions = @NamedTuple {
717+ affect:: Function ,
718+ affect_neg:: Union{Function, Nothing} ,
719+ initialize:: Union{Function, Nothing} ,
720+ finalize:: Union{Function, Nothing} }[
721+ compile_affect_fn (cb, sys, dvs, ps, kwargs)
722+ for cb in cbs]
723+
636724 cond = function (out, u, t, integ)
637725 rf_ip (out, u, parameter_values (integ), t)
638726 end
@@ -658,26 +746,31 @@ function generate_vector_rootfinding_callback(
658746 affect_neg (integ)
659747 end
660748 end
661- if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
662- save_idxs = mapreduce (
663- cb -> get (ic. callback_to_clocks, cb, Int[]), vcat, cbs; init = Int[])
664- initfn = if isempty (save_idxs)
665- SciMLBase. INITIALIZE_DEFAULT
749+ function handle_optional_setup_fn (funs, default)
750+ if all (isnothing, funs)
751+ return default
666752 else
667- let save_idxs = save_idxs
668- function (cb, u, t, integrator)
669- for idx in save_idxs
670- SciMLBase. save_discretes! (integrator, idx)
753+ return let funs = funs
754+ function (cb, u, t, integ)
755+ for func in funs
756+ if isnothing (func)
757+ continue
758+ else
759+ func (integ)
760+ end
671761 end
672762 end
673763 end
674764 end
675- else
676- initfn = SciMLBase. INITIALIZE_DEFAULT
677765 end
766+
767+ initialize = handle_optional_setup_fn (
768+ map (fn -> fn. initialize, affect_functions), SciMLBase. INITIALIZE_DEFAULT)
769+ finalize = handle_optional_setup_fn (
770+ map (fn -> fn. finalize, affect_functions), SciMLBase. FINALIZE_DEFAULT)
678771 return VectorContinuousCallback (
679772 cond, affect, affect_neg, length (eqs), rootfind = rootfind,
680- initialize = initfn , initializealg = reinitialization)
773+ initialize = initialize, finalize = finalize , initializealg = reinitialization)
681774end
682775
683776"""
@@ -687,15 +780,21 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
687780 eq_aff = affects (cb)
688781 eq_neg_aff = affect_negs (cb)
689782 affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
783+ function compile_optional_affect (aff, default = nothing )
784+ if isnothing (aff) || aff == default
785+ return nothing
786+ else
787+ return compile_affect (aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
788+ end
789+ end
690790 if eq_neg_aff === eq_aff
691791 affect_neg = affect
692- elseif isnothing (eq_neg_aff)
693- affect_neg = nothing
694792 else
695- affect_neg = compile_affect (
696- eq_neg_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
793+ affect_neg = compile_optional_affect (eq_neg_aff)
697794 end
698- (affect = affect, affect_neg = affect_neg)
795+ initialize = compile_optional_affect (initialize_affects (cb), NULL_AFFECT)
796+ finalize = compile_optional_affect (finalize_affects (cb), NULL_AFFECT)
797+ (affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
699798end
700799
701800function generate_rootfinding_callback (cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
0 commit comments