@@ -106,15 +106,25 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
106106 + `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`.
107107 + `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
108108 + `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
109+
110+ Callbacks that impact a DAE are applied, then the DAE is reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`).
111+ This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate
112+ that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of
113+ initialization.
109114"""
110115struct SymbolicContinuousCallback
111116 eqs:: Vector{Equation}
112117 affect:: Union{Vector{Equation}, FunctionalAffect}
113118 affect_neg:: Union{Vector{Equation}, FunctionalAffect, Nothing}
114119 rootfind:: SciMLBase.RootfindOpt
115- function SymbolicContinuousCallback (; eqs:: Vector{Equation} , affect = NULL_AFFECT,
116- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
117- new (eqs, make_affect (affect), make_affect (affect_neg), rootfind)
120+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
121+ function SymbolicContinuousCallback (;
122+ eqs:: Vector{Equation} ,
123+ affect = NULL_AFFECT,
124+ affect_neg = affect,
125+ rootfind = SciMLBase. LeftRootFind,
126+ reinitializealg= SciMLBase. CheckInit ())
127+ new (eqs, make_affect (affect), make_affect (affect_neg), rootfind, reinitializealg)
118128 end # Default affect to nothing
119129end
120130make_affect (affect) = affect
@@ -183,6 +193,10 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback})
183193 mapreduce (affect_negs, vcat, cbs, init = Equation[])
184194end
185195
196+ reinitialization_alg (cb:: SymbolicContinuousCallback ) = cb. reinitializealg
197+ reinitialization_algs (cbs:: Vector{SymbolicContinuousCallback} ) =
198+ mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
199+
186200namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
187201namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
188202namespace_affects (:: Nothing , s) = nothing
@@ -225,11 +239,12 @@ struct SymbolicDiscreteCallback
225239 # TODO : Iterative
226240 condition:: Any
227241 affects:: Any
242+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
228243
229- function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT)
244+ function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT, reinitializealg = SciMLBase . CheckInit () )
230245 c = scalarize_condition (condition)
231246 a = scalarize_affects (affects)
232- new (c, a)
247+ new (c, a, reinitializealg )
233248 end # Default affect to nothing
234249end
235250
@@ -286,6 +301,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
286301 reduce (vcat, affects (cb) for cb in cbs; init = [])
287302end
288303
304+ reinitialization_alg (cb:: SymbolicDiscreteCallback ) = cb. reinitializealg
305+ reinitialization_algs (cbs:: Vector{SymbolicDiscreteCallback} ) =
306+ mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
307+
289308function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
290309 af = affects (cb)
291310 af = af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
@@ -579,13 +598,14 @@ function generate_single_rootfinding_callback(
579598 initfn = SciMLBase. INITIALIZE_DEFAULT
580599 end
581600 return ContinuousCallback (
582- cond, affect_function. affect, affect_function. affect_neg,
583- rootfind = cb. rootfind, initialize = initfn)
601+ cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
602+ initialize = initfn,
603+ initializealg = reinitialization_alg (cb))
584604end
585605
586606function generate_vector_rootfinding_callback (
587607 cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
588- ps = parameters (sys); rootfind = SciMLBase. RightRootFind, kwargs... )
608+ ps = parameters (sys); rootfind = SciMLBase. RightRootFind, reinitialization = SciMLBase . CheckInit (), kwargs... )
589609 eqs = map (cb -> flatten_equations (cb. eqs), cbs)
590610 num_eqs = length .(eqs)
591611 # fuse equations to create VectorContinuousCallback
@@ -650,7 +670,7 @@ function generate_vector_rootfinding_callback(
650670 initfn = SciMLBase. INITIALIZE_DEFAULT
651671 end
652672 return VectorContinuousCallback (
653- cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initfn)
673+ cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initfn, initializealg = reinitialization )
654674end
655675
656676"""
@@ -690,18 +710,22 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
690710 # group the cbs by what rootfind op they use
691711 # groupby would be very useful here, but alas
692712 cb_classes = Dict{
693- @NamedTuple {rootfind:: SciMLBase.RootfindOpt }, Vector{SymbolicContinuousCallback}}()
713+ @NamedTuple {
714+ rootfind:: SciMLBase.RootfindOpt ,
715+ reinitialization:: SciMLBase.DAEInitializationAlgorithm }, Vector{SymbolicContinuousCallback}}()
694716 for cb in cbs
695717 push! (
696- get! (() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb. rootfind,)),
718+ get! (() -> SymbolicContinuousCallback[], cb_classes, (
719+ rootfind = cb. rootfind,
720+ reinitialization = reinitialization_alg (cb))),
697721 cb)
698722 end
699723
700724 # generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
701725 compiled_callbacks = map (collect (pairs (sort! (
702726 OrderedDict (cb_classes); by = p -> p. rootfind)))) do (equiv_class, cbs_in_class)
703727 return generate_vector_rootfinding_callback (
704- cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, kwargs... )
728+ cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, reinitialization = equiv_class . reinitialization, kwargs... )
705729 end
706730 if length (compiled_callbacks) == 1
707731 return compiled_callbacks[]
@@ -772,10 +796,10 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no
772796 end
773797 if cond isa AbstractVector
774798 # Preset Time
775- return PresetTimeCallback (cond, as; initialize = initfn)
799+ return PresetTimeCallback (cond, as; initialize = initfn, initializealg = reinitialization_alg (cb) )
776800 else
777801 # Periodic
778- return PeriodicCallback (as, cond; initialize = initfn)
802+ return PeriodicCallback (as, cond; initialize = initfn, initializealg = reinitialization_alg (cb) )
779803 end
780804end
781805
@@ -800,7 +824,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
800824 else
801825 initfn = SciMLBase. INITIALIZE_DEFAULT
802826 end
803- return DiscreteCallback (c, as; initialize = initfn)
827+ return DiscreteCallback (c, as; initialize = initfn, initializealg = reinitialization_alg (cb) )
804828 end
805829end
806830
0 commit comments