@@ -106,15 +106,25 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
106106 + `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`.
107107 + `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
108108 + `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
109+
110+ DAEs will be reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`) after callbacks are applied.
111+ This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate
112+ that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of
113+ initialization.
109114"""
110115struct SymbolicContinuousCallback
111116 eqs:: Vector{Equation}
112117 affect:: Union{Vector{Equation}, FunctionalAffect}
113118 affect_neg:: Union{Vector{Equation}, FunctionalAffect, Nothing}
114119 rootfind:: SciMLBase.RootfindOpt
115- function SymbolicContinuousCallback (; eqs:: Vector{Equation} , affect = NULL_AFFECT,
116- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
117- new (eqs, make_affect (affect), make_affect (affect_neg), rootfind)
120+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
121+ function SymbolicContinuousCallback (;
122+ eqs:: Vector{Equation} ,
123+ affect = NULL_AFFECT,
124+ affect_neg = affect,
125+ rootfind = SciMLBase. LeftRootFind,
126+ reinitializealg = SciMLBase. CheckInit ())
127+ new (eqs, make_affect (affect), make_affect (affect_neg), rootfind, reinitializealg)
118128 end # Default affect to nothing
119129end
120130make_affect (affect) = affect
@@ -183,6 +193,12 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback})
183193 mapreduce (affect_negs, vcat, cbs, init = Equation[])
184194end
185195
196+ reinitialization_alg (cb:: SymbolicContinuousCallback ) = cb. reinitializealg
197+ function reinitialization_algs (cbs:: Vector{SymbolicContinuousCallback} )
198+ mapreduce (
199+ reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
200+ end
201+
186202namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
187203namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
188204namespace_affects (:: Nothing , s) = nothing
@@ -225,11 +241,13 @@ struct SymbolicDiscreteCallback
225241 # TODO : Iterative
226242 condition:: Any
227243 affects:: Any
244+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
228245
229- function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT)
246+ function SymbolicDiscreteCallback (
247+ condition, affects = NULL_AFFECT, reinitializealg = SciMLBase. CheckInit ())
230248 c = scalarize_condition (condition)
231249 a = scalarize_affects (affects)
232- new (c, a)
250+ new (c, a, reinitializealg )
233251 end # Default affect to nothing
234252end
235253
@@ -286,6 +304,12 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
286304 reduce (vcat, affects (cb) for cb in cbs; init = [])
287305end
288306
307+ reinitialization_alg (cb:: SymbolicDiscreteCallback ) = cb. reinitializealg
308+ function reinitialization_algs (cbs:: Vector{SymbolicDiscreteCallback} )
309+ mapreduce (
310+ reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
311+ end
312+
289313function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
290314 af = affects (cb)
291315 af = af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
@@ -579,13 +603,15 @@ function generate_single_rootfinding_callback(
579603 initfn = SciMLBase. INITIALIZE_DEFAULT
580604 end
581605 return ContinuousCallback (
582- cond, affect_function. affect, affect_function. affect_neg,
583- rootfind = cb. rootfind, initialize = initfn)
606+ cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
607+ initialize = initfn,
608+ initializealg = reinitialization_alg (cb))
584609end
585610
586611function generate_vector_rootfinding_callback (
587612 cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
588- ps = parameters (sys); rootfind = SciMLBase. RightRootFind, kwargs... )
613+ ps = parameters (sys); rootfind = SciMLBase. RightRootFind,
614+ reinitialization = SciMLBase. CheckInit (), kwargs... )
589615 eqs = map (cb -> flatten_equations (cb. eqs), cbs)
590616 num_eqs = length .(eqs)
591617 # fuse equations to create VectorContinuousCallback
@@ -650,7 +676,8 @@ function generate_vector_rootfinding_callback(
650676 initfn = SciMLBase. INITIALIZE_DEFAULT
651677 end
652678 return VectorContinuousCallback (
653- cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initfn)
679+ cond, affect, affect_neg, length (eqs), rootfind = rootfind,
680+ initialize = initfn, initializealg = reinitialization)
654681end
655682
656683"""
@@ -690,18 +717,24 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
690717 # group the cbs by what rootfind op they use
691718 # groupby would be very useful here, but alas
692719 cb_classes = Dict{
693- @NamedTuple {rootfind:: SciMLBase.RootfindOpt }, Vector{SymbolicContinuousCallback}}()
720+ @NamedTuple {
721+ rootfind:: SciMLBase.RootfindOpt ,
722+ reinitialization:: SciMLBase.DAEInitializationAlgorithm }, Vector{SymbolicContinuousCallback}}()
694723 for cb in cbs
695724 push! (
696- get! (() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb. rootfind,)),
725+ get! (() -> SymbolicContinuousCallback[], cb_classes,
726+ (
727+ rootfind = cb. rootfind,
728+ reinitialization = reinitialization_alg (cb))),
697729 cb)
698730 end
699731
700732 # generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
701733 compiled_callbacks = map (collect (pairs (sort! (
702734 OrderedDict (cb_classes); by = p -> p. rootfind)))) do (equiv_class, cbs_in_class)
703735 return generate_vector_rootfinding_callback (
704- cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, kwargs... )
736+ cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind,
737+ reinitialization = equiv_class. reinitialization, kwargs... )
705738 end
706739 if length (compiled_callbacks) == 1
707740 return compiled_callbacks[]
@@ -772,10 +805,12 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no
772805 end
773806 if cond isa AbstractVector
774807 # Preset Time
775- return PresetTimeCallback (cond, as; initialize = initfn)
808+ return PresetTimeCallback (
809+ cond, as; initialize = initfn, initializealg = reinitialization_alg (cb))
776810 else
777811 # Periodic
778- return PeriodicCallback (as, cond; initialize = initfn)
812+ return PeriodicCallback (
813+ as, cond; initialize = initfn, initializealg = reinitialization_alg (cb))
779814 end
780815end
781816
@@ -800,7 +835,8 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
800835 else
801836 initfn = SciMLBase. INITIALIZE_DEFAULT
802837 end
803- return DiscreteCallback (c, as; initialize = initfn)
838+ return DiscreteCallback (
839+ c, as; initialize = initfn, initializealg = reinitialization_alg (cb))
804840 end
805841end
806842
0 commit comments