@@ -216,6 +216,11 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
216216 + `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
217217 + `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
218218* A [`MutatingFunctionalAffect`](@ref); refer to its documentation for details.
219+
220+ Callbacks that impact a DAE are applied, then the DAE is reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`).
221+ This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate
222+ that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of
223+ initialization.
219224"""
220225struct SymbolicContinuousCallback
221226 eqs:: Vector{Equation}
@@ -224,14 +229,16 @@ struct SymbolicContinuousCallback
224229 affect:: Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
225230 affect_neg:: Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing}
226231 rootfind:: SciMLBase.RootfindOpt
232+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
227233 function SymbolicContinuousCallback (;
228234 eqs:: Vector{Equation} ,
229235 affect = NULL_AFFECT,
230236 affect_neg = affect,
231237 rootfind = SciMLBase. LeftRootFind,
232238 initialize= NULL_AFFECT,
233- finalize= NULL_AFFECT)
234- new (eqs, initialize, finalize, make_affect (affect), make_affect (affect_neg), rootfind)
239+ finalize= NULL_AFFECT,
240+ reinitializealg= SciMLBase. CheckInit ())
241+ new (eqs, initialize, finalize, make_affect (affect), make_affect (affect_neg), rootfind, reinitializealg)
235242 end # Default affect to nothing
236243end
237244make_affect (affect) = affect
@@ -373,6 +380,10 @@ function finalize_affects(cbs::Vector{SymbolicContinuousCallback})
373380 mapreduce (finalize_affects, vcat, cbs, init = Equation[])
374381end
375382
383+ reinitialization_alg (cb:: SymbolicContinuousCallback ) = cb. reinitializealg
384+ reinitialization_algs (cbs:: Vector{SymbolicContinuousCallback} ) =
385+ mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
386+
376387namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
377388namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
378389namespace_affects (af:: MutatingFunctionalAffect , s) = namespace_affect (af, s)
@@ -419,11 +430,12 @@ struct SymbolicDiscreteCallback
419430 # TODO : Iterative
420431 condition:: Any
421432 affects:: Any
433+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
422434
423- function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT)
435+ function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT, reinitializealg = SciMLBase . CheckInit () )
424436 c = scalarize_condition (condition)
425437 a = scalarize_affects (affects)
426- new (c, a)
438+ new (c, a, reinitializealg )
427439 end # Default affect to nothing
428440end
429441
@@ -481,6 +493,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
481493 reduce (vcat, affects (cb) for cb in cbs; init = [])
482494end
483495
496+ reinitialization_alg (cb:: SymbolicDiscreteCallback ) = cb. reinitializealg
497+ reinitialization_algs (cbs:: Vector{SymbolicDiscreteCallback} ) =
498+ mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
499+
484500function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
485501 af = affects (cb)
486502 af = af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
@@ -776,12 +792,13 @@ function generate_single_rootfinding_callback(
776792 return ContinuousCallback (
777793 cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
778794 initialize = isnothing (affect_function. initialize) ? SciMLBase. INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function. initialize (i),
779- finalize = isnothing (affect_function. finalize) ? SciMLBase. FINALIZE_DEFAULT : (c, u, t, i) -> affect_function. finalize (i))
795+ finalize = isnothing (affect_function. finalize) ? SciMLBase. FINALIZE_DEFAULT : (c, u, t, i) -> affect_function. finalize (i),
796+ initializealg = reinitialization_alg (cb))
780797end
781798
782799function generate_vector_rootfinding_callback (
783800 cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
784- ps = parameters (sys); rootfind = SciMLBase. RightRootFind, kwargs... )
801+ ps = parameters (sys); rootfind = SciMLBase. RightRootFind, reinitialization = SciMLBase . CheckInit (), kwargs... )
785802 eqs = map (cb -> flatten_equations (cb. eqs), cbs)
786803 num_eqs = length .(eqs)
787804 # fuse equations to create VectorContinuousCallback
@@ -847,7 +864,7 @@ function generate_vector_rootfinding_callback(
847864 initialize = handle_optional_setup_fn (map (fn -> fn. initialize, affect_functions), SciMLBase. INITIALIZE_DEFAULT)
848865 finalize = handle_optional_setup_fn (map (fn -> fn. finalize, affect_functions), SciMLBase. FINALIZE_DEFAULT)
849866 return VectorContinuousCallback (
850- cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initialize, finalize = finalize)
867+ cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initialize, finalize = finalize, initializealg = reinitialization )
851868end
852869
853870"""
@@ -893,18 +910,22 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
893910 # group the cbs by what rootfind op they use
894911 # groupby would be very useful here, but alas
895912 cb_classes = Dict{
896- @NamedTuple {rootfind:: SciMLBase.RootfindOpt }, Vector{SymbolicContinuousCallback}}()
913+ @NamedTuple {
914+ rootfind:: SciMLBase.RootfindOpt ,
915+ reinitialization:: SciMLBase.DAEInitializationAlgorithm }, Vector{SymbolicContinuousCallback}}()
897916 for cb in cbs
898917 push! (
899- get! (() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb. rootfind,)),
918+ get! (() -> SymbolicContinuousCallback[], cb_classes, (
919+ rootfind = cb. rootfind,
920+ reinitialization = reinitialization_alg (cb))),
900921 cb)
901922 end
902923
903924 # generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
904925 compiled_callbacks = map (collect (pairs (sort! (
905926 OrderedDict (cb_classes); by = p -> p. rootfind)))) do (equiv_class, cbs_in_class)
906927 return generate_vector_rootfinding_callback (
907- cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, kwargs... )
928+ cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, reinitialization = equiv_class . reinitialization, kwargs... )
908929 end
909930 if length (compiled_callbacks) == 1
910931 return compiled_callbacks[]
0 commit comments