@@ -106,15 +106,25 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
106
106
+ `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`.
107
107
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
108
108
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
109
+
110
+ Callbacks that impact a DAE are applied, then the DAE is reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`).
111
+ This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate
112
+ that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of
113
+ initialization.
109
114
"""
110
115
struct SymbolicContinuousCallback
111
116
eqs:: Vector{Equation}
112
117
affect:: Union{Vector{Equation}, FunctionalAffect}
113
118
affect_neg:: Union{Vector{Equation}, FunctionalAffect, Nothing}
114
119
rootfind:: SciMLBase.RootfindOpt
115
- function SymbolicContinuousCallback (; eqs:: Vector{Equation} , affect = NULL_AFFECT,
116
- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
117
- new (eqs, make_affect (affect), make_affect (affect_neg), rootfind)
120
+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
121
+ function SymbolicContinuousCallback (;
122
+ eqs:: Vector{Equation} ,
123
+ affect = NULL_AFFECT,
124
+ affect_neg = affect,
125
+ rootfind = SciMLBase. LeftRootFind,
126
+ reinitializealg= SciMLBase. CheckInit ())
127
+ new (eqs, make_affect (affect), make_affect (affect_neg), rootfind, reinitializealg)
118
128
end # Default affect to nothing
119
129
end
120
130
make_affect (affect) = affect
@@ -183,6 +193,10 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback})
183
193
mapreduce (affect_negs, vcat, cbs, init = Equation[])
184
194
end
185
195
196
+ reinitialization_alg (cb:: SymbolicContinuousCallback ) = cb. reinitializealg
197
+ reinitialization_algs (cbs:: Vector{SymbolicContinuousCallback} ) =
198
+ mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
199
+
186
200
namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
187
201
namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
188
202
namespace_affects (:: Nothing , s) = nothing
@@ -225,11 +239,12 @@ struct SymbolicDiscreteCallback
225
239
# TODO : Iterative
226
240
condition:: Any
227
241
affects:: Any
242
+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
228
243
229
- function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT)
244
+ function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT, reinitializealg = SciMLBase . CheckInit () )
230
245
c = scalarize_condition (condition)
231
246
a = scalarize_affects (affects)
232
- new (c, a)
247
+ new (c, a, reinitializealg )
233
248
end # Default affect to nothing
234
249
end
235
250
@@ -286,6 +301,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
286
301
reduce (vcat, affects (cb) for cb in cbs; init = [])
287
302
end
288
303
304
+ reinitialization_alg (cb:: SymbolicDiscreteCallback ) = cb. reinitializealg
305
+ reinitialization_algs (cbs:: Vector{SymbolicDiscreteCallback} ) =
306
+ mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
307
+
289
308
function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
290
309
af = affects (cb)
291
310
af = af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
@@ -579,13 +598,14 @@ function generate_single_rootfinding_callback(
579
598
initfn = SciMLBase. INITIALIZE_DEFAULT
580
599
end
581
600
return ContinuousCallback (
582
- cond, affect_function. affect, affect_function. affect_neg,
583
- rootfind = cb. rootfind, initialize = initfn)
601
+ cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
602
+ initialize = initfn,
603
+ initializealg = reinitialization_alg (cb))
584
604
end
585
605
586
606
function generate_vector_rootfinding_callback (
587
607
cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
588
- ps = parameters (sys); rootfind = SciMLBase. RightRootFind, kwargs... )
608
+ ps = parameters (sys); rootfind = SciMLBase. RightRootFind, reinitialization = SciMLBase . CheckInit (), kwargs... )
589
609
eqs = map (cb -> flatten_equations (cb. eqs), cbs)
590
610
num_eqs = length .(eqs)
591
611
# fuse equations to create VectorContinuousCallback
@@ -650,7 +670,7 @@ function generate_vector_rootfinding_callback(
650
670
initfn = SciMLBase. INITIALIZE_DEFAULT
651
671
end
652
672
return VectorContinuousCallback (
653
- cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initfn)
673
+ cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initfn, initializealg = reinitialization )
654
674
end
655
675
656
676
"""
@@ -690,18 +710,22 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
690
710
# group the cbs by what rootfind op they use
691
711
# groupby would be very useful here, but alas
692
712
cb_classes = Dict{
693
- @NamedTuple {rootfind:: SciMLBase.RootfindOpt }, Vector{SymbolicContinuousCallback}}()
713
+ @NamedTuple {
714
+ rootfind:: SciMLBase.RootfindOpt ,
715
+ reinitialization:: SciMLBase.DAEInitializationAlgorithm }, Vector{SymbolicContinuousCallback}}()
694
716
for cb in cbs
695
717
push! (
696
- get! (() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb. rootfind,)),
718
+ get! (() -> SymbolicContinuousCallback[], cb_classes, (
719
+ rootfind = cb. rootfind,
720
+ reinitialization = reinitialization_alg (cb))),
697
721
cb)
698
722
end
699
723
700
724
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
701
725
compiled_callbacks = map (collect (pairs (sort! (
702
726
OrderedDict (cb_classes); by = p -> p. rootfind)))) do (equiv_class, cbs_in_class)
703
727
return generate_vector_rootfinding_callback (
704
- cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, kwargs... )
728
+ cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, reinitialization = equiv_class . reinitialization, kwargs... )
705
729
end
706
730
if length (compiled_callbacks) == 1
707
731
return compiled_callbacks[]
@@ -772,10 +796,10 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no
772
796
end
773
797
if cond isa AbstractVector
774
798
# Preset Time
775
- return PresetTimeCallback (cond, as; initialize = initfn)
799
+ return PresetTimeCallback (cond, as; initialize = initfn, initializealg = reinitialization_alg (cb) )
776
800
else
777
801
# Periodic
778
- return PeriodicCallback (as, cond; initialize = initfn)
802
+ return PeriodicCallback (as, cond; initialize = initfn, initializealg = reinitialization_alg (cb) )
779
803
end
780
804
end
781
805
@@ -800,7 +824,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
800
824
else
801
825
initfn = SciMLBase. INITIALIZE_DEFAULT
802
826
end
803
- return DiscreteCallback (c, as; initialize = initfn)
827
+ return DiscreteCallback (c, as; initialize = initfn, initializealg = reinitialization_alg (cb) )
804
828
end
805
829
end
806
830
0 commit comments