@@ -114,6 +114,8 @@ initialization.
114
114
"""
115
115
struct SymbolicContinuousCallback
116
116
eqs:: Vector{Equation}
117
+ initialize:: Union{Vector{Equation}, FunctionalAffect}
118
+ finalize:: Union{Vector{Equation}, FunctionalAffect}
117
119
affect:: Union{Vector{Equation}, FunctionalAffect}
118
120
affect_neg:: Union{Vector{Equation}, FunctionalAffect, Nothing}
119
121
rootfind:: SciMLBase.RootfindOpt
@@ -122,9 +124,12 @@ struct SymbolicContinuousCallback
122
124
eqs:: Vector{Equation} ,
123
125
affect = NULL_AFFECT,
124
126
affect_neg = affect,
127
+ initialize = NULL_AFFECT,
128
+ finalize = NULL_AFFECT,
125
129
rootfind = SciMLBase. LeftRootFind,
126
130
reinitializealg = SciMLBase. CheckInit ())
127
- new (eqs, make_affect (affect), make_affect (affect_neg), rootfind, reinitializealg)
131
+ new (eqs, initialize, finalize, make_affect (affect),
132
+ make_affect (affect_neg), rootfind, reinitializealg)
128
133
end # Default affect to nothing
129
134
end
130
135
make_affect (affect) = affect
@@ -133,17 +138,80 @@ make_affect(affect::NamedTuple) = FunctionalAffect(; affect...)
133
138
134
139
function Base.:(== )(e1:: SymbolicContinuousCallback , e2:: SymbolicContinuousCallback )
135
140
isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect) &&
141
+ isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize) &&
136
142
isequal (e1. affect_neg, e2. affect_neg) && isequal (e1. rootfind, e2. rootfind)
137
143
end
138
144
Base. isempty (cb:: SymbolicContinuousCallback ) = isempty (cb. eqs)
139
145
function Base. hash (cb:: SymbolicContinuousCallback , s:: UInt )
146
+ hash_affect (affect:: AbstractVector , s) = foldr (hash, affect, init = s)
147
+ hash_affect (affect, s) = hash (affect, s)
140
148
s = foldr (hash, cb. eqs, init = s)
141
- s = cb. affect isa AbstractVector ? foldr (hash, cb. affect, init = s) : hash (cb. affect, s)
142
- s = cb. affect_neg isa AbstractVector ? foldr (hash, cb. affect_neg, init = s) :
143
- hash (cb. affect_neg, s)
149
+ s = hash_affect (cb. affect, s)
150
+ s = hash_affect (cb. affect_neg, s)
151
+ s = hash_affect (cb. initialize, s)
152
+ s = hash_affect (cb. finalize, s)
153
+ s = hash (cb. reinitializealg, s)
144
154
hash (cb. rootfind, s)
145
155
end
146
156
157
+ function Base. show (io:: IO , cb:: SymbolicContinuousCallback )
158
+ indent = get (io, :indent , 0 )
159
+ iio = IOContext (io, :indent => indent + 1 )
160
+ print (io, " SymbolicContinuousCallback(" )
161
+ print (iio, " Equations:" )
162
+ show (iio, equations (cb))
163
+ print (iio, " ; " )
164
+ if affects (cb) != NULL_AFFECT
165
+ print (iio, " Affect:" )
166
+ show (iio, affects (cb))
167
+ print (iio, " , " )
168
+ end
169
+ if affect_negs (cb) != NULL_AFFECT
170
+ print (iio, " Negative-edge affect:" )
171
+ show (iio, affect_negs (cb))
172
+ print (iio, " , " )
173
+ end
174
+ if initialize_affects (cb) != NULL_AFFECT
175
+ print (iio, " Initialization affect:" )
176
+ show (iio, initialize_affects (cb))
177
+ print (iio, " , " )
178
+ end
179
+ if finalize_affects (cb) != NULL_AFFECT
180
+ print (iio, " Finalization affect:" )
181
+ show (iio, finalize_affects (cb))
182
+ end
183
+ print (iio, " )" )
184
+ end
185
+
186
+ function Base. show (io:: IO , mime:: MIME"text/plain" , cb:: SymbolicContinuousCallback )
187
+ indent = get (io, :indent , 0 )
188
+ iio = IOContext (io, :indent => indent + 1 )
189
+ println (io, " SymbolicContinuousCallback:" )
190
+ println (iio, " Equations:" )
191
+ show (iio, mime, equations (cb))
192
+ print (iio, " \n " )
193
+ if affects (cb) != NULL_AFFECT
194
+ println (iio, " Affect:" )
195
+ show (iio, mime, affects (cb))
196
+ print (iio, " \n " )
197
+ end
198
+ if affect_negs (cb) != NULL_AFFECT
199
+ println (iio, " Negative-edge affect:" )
200
+ show (iio, mime, affect_negs (cb))
201
+ print (iio, " \n " )
202
+ end
203
+ if initialize_affects (cb) != NULL_AFFECT
204
+ println (iio, " Initialization affect:" )
205
+ show (iio, mime, initialize_affects (cb))
206
+ print (iio, " \n " )
207
+ end
208
+ if finalize_affects (cb) != NULL_AFFECT
209
+ println (iio, " Finalization affect:" )
210
+ show (iio, mime, finalize_affects (cb))
211
+ print (iio, " \n " )
212
+ end
213
+ end
214
+
147
215
to_equation_vector (eq:: Equation ) = [eq]
148
216
to_equation_vector (eqs:: Vector{Equation} ) = eqs
149
217
function to_equation_vector (eqs:: Vector{Any} )
@@ -156,15 +224,18 @@ function SymbolicContinuousCallback(args...)
156
224
end # wrap eq in vector
157
225
SymbolicContinuousCallback (p:: Pair ) = SymbolicContinuousCallback (p[1 ], p[2 ])
158
226
SymbolicContinuousCallback (cb:: SymbolicContinuousCallback ) = cb # passthrough
159
- function SymbolicContinuousCallback (eqs:: Equation , affect = NULL_AFFECT;
160
- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
227
+ function SymbolicContinuousCallback (eqs:: Equation , affect = NULL_AFFECT;
228
+ initialize= NULL_AFFECT, finalize= NULL_AFFECT,
229
+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
161
230
SymbolicContinuousCallback (
162
- eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind)
231
+ eqs = [eqs], affect = affect, affect_neg = affect_neg,
232
+ initialize= initialize, finalize= finalize, rootfind = rootfind)
163
233
end
164
234
function SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT;
165
- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
235
+ affect_neg = affect, initialize= NULL_AFFECT, finalize= NULL_AFFECT,
236
+ rootfind = SciMLBase. LeftRootFind)
166
237
SymbolicContinuousCallback (
167
- eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind)
238
+ eqs = eqs, affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize, rootfind = rootfind)
168
239
end
169
240
170
241
SymbolicContinuousCallbacks (cb:: SymbolicContinuousCallback ) = [cb]
@@ -199,15 +270,28 @@ function reinitialization_algs(cbs::Vector{SymbolicContinuousCallback})
199
270
reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
200
271
end
201
272
273
+ initialize_affects (cb:: SymbolicContinuousCallback ) = cb. initialize
274
+ function initialize_affects (cbs:: Vector{SymbolicContinuousCallback} )
275
+ mapreduce (initialize_affects, vcat, cbs, init = Equation[])
276
+ end
277
+
278
+ finalize_affects (cb:: SymbolicContinuousCallback ) = cb. initialize
279
+ function finalize_affects (cbs:: Vector{SymbolicContinuousCallback} )
280
+ mapreduce (finalize_affects, vcat, cbs, init = Equation[])
281
+ end
282
+
202
283
namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
203
284
namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
204
285
namespace_affects (:: Nothing , s) = nothing
205
286
206
287
function namespace_callback (cb:: SymbolicContinuousCallback , s):: SymbolicContinuousCallback
207
- SymbolicContinuousCallback (
208
- namespace_equation .(equations (cb), (s,)),
209
- namespace_affects (affects (cb), s);
210
- affect_neg = namespace_affects (affect_negs (cb), s))
288
+ SymbolicContinuousCallback (;
289
+ eqs = namespace_equation .(equations (cb), (s,)),
290
+ affect = namespace_affects (affects (cb), s),
291
+ affect_neg = namespace_affects (affect_negs (cb), s),
292
+ initialize = namespace_affects (initialize_affects (cb), s),
293
+ finalize = namespace_affects (finalize_affects (cb), s),
294
+ rootfind = cb. rootfind)
211
295
end
212
296
213
297
"""
@@ -589,22 +673,25 @@ function generate_single_rootfinding_callback(
589
673
rf_oop (u, parameter_values (integ), t)
590
674
end
591
675
end
592
-
676
+ user_initfun = (affect_function. initialize == NULL_AFFECT) ? SciMLBase. INITIALIZE_DEFAULT :
677
+ (c, u, t, i) -> affect_function. initialize (i)
593
678
if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing &&
594
679
(save_idxs = get (ic. callback_to_clocks, cb, nothing )) != = nothing
595
680
initfn = let save_idxs = save_idxs
596
681
function (cb, u, t, integrator)
682
+ user_initfun (cb, u, t, integrator)
597
683
for idx in save_idxs
598
684
SciMLBase. save_discretes! (integrator, idx)
599
685
end
600
686
end
601
687
end
602
688
else
603
- initfn = SciMLBase . INITIALIZE_DEFAULT
689
+ initfn = user_initfun
604
690
end
605
691
return ContinuousCallback (
606
692
cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
607
693
initialize = initfn,
694
+ finalize = (affect_function. finalize == NULL_AFFECT) ? SciMLBase. FINALIZE_DEFAULT : (c, u, t, i) -> affect_function. finalize (i),
608
695
initializealg = reinitialization_alg (cb))
609
696
end
610
697
@@ -626,13 +713,14 @@ function generate_vector_rootfinding_callback(
626
713
_, rf_ip = generate_custom_function (
627
714
sys, rhss, dvs, ps; expression = Val{false }, kwargs... )
628
715
629
- affect_functions = @NamedTuple {affect:: Function , affect_neg:: Union{Function, Nothing} }[compile_affect_fn (
630
- cb,
631
- sys,
632
- dvs,
633
- ps,
634
- kwargs)
635
- for cb in cbs]
716
+ affect_functions = @NamedTuple {
717
+ affect:: Function ,
718
+ affect_neg:: Union{Function, Nothing} ,
719
+ initialize:: Union{Function, Nothing} ,
720
+ finalize:: Union{Function, Nothing} }[
721
+ compile_affect_fn (cb, sys, dvs, ps, kwargs)
722
+ for cb in cbs]
723
+
636
724
cond = function (out, u, t, integ)
637
725
rf_ip (out, u, parameter_values (integ), t)
638
726
end
@@ -658,26 +746,31 @@ function generate_vector_rootfinding_callback(
658
746
affect_neg (integ)
659
747
end
660
748
end
661
- if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
662
- save_idxs = mapreduce (
663
- cb -> get (ic. callback_to_clocks, cb, Int[]), vcat, cbs; init = Int[])
664
- initfn = if isempty (save_idxs)
665
- SciMLBase. INITIALIZE_DEFAULT
749
+ function handle_optional_setup_fn (funs, default)
750
+ if all (isnothing, funs)
751
+ return default
666
752
else
667
- let save_idxs = save_idxs
668
- function (cb, u, t, integrator)
669
- for idx in save_idxs
670
- SciMLBase. save_discretes! (integrator, idx)
753
+ return let funs = funs
754
+ function (cb, u, t, integ)
755
+ for func in funs
756
+ if isnothing (func)
757
+ continue
758
+ else
759
+ func (integ)
760
+ end
671
761
end
672
762
end
673
763
end
674
764
end
675
- else
676
- initfn = SciMLBase. INITIALIZE_DEFAULT
677
765
end
766
+
767
+ initialize = handle_optional_setup_fn (
768
+ map (fn -> fn. initialize, affect_functions), SciMLBase. INITIALIZE_DEFAULT)
769
+ finalize = handle_optional_setup_fn (
770
+ map (fn -> fn. finalize, affect_functions), SciMLBase. FINALIZE_DEFAULT)
678
771
return VectorContinuousCallback (
679
772
cond, affect, affect_neg, length (eqs), rootfind = rootfind,
680
- initialize = initfn , initializealg = reinitialization)
773
+ initialize = initialize, finalize = finalize , initializealg = reinitialization)
681
774
end
682
775
683
776
"""
@@ -687,15 +780,21 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
687
780
eq_aff = affects (cb)
688
781
eq_neg_aff = affect_negs (cb)
689
782
affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
783
+ function compile_optional_affect (aff, default = nothing )
784
+ if isnothing (aff) || aff == default
785
+ return nothing
786
+ else
787
+ return compile_affect (aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
788
+ end
789
+ end
690
790
if eq_neg_aff === eq_aff
691
791
affect_neg = affect
692
- elseif isnothing (eq_neg_aff)
693
- affect_neg = nothing
694
792
else
695
- affect_neg = compile_affect (
696
- eq_neg_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
793
+ affect_neg = compile_optional_affect (eq_neg_aff)
697
794
end
698
- (affect = affect, affect_neg = affect_neg)
795
+ initialize = compile_optional_affect (initialize_affects (cb), NULL_AFFECT)
796
+ finalize = compile_optional_affect (finalize_affects (cb), NULL_AFFECT)
797
+ (affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
699
798
end
700
799
701
800
function generate_rootfinding_callback (cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
0 commit comments