@@ -200,7 +200,9 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
200200 + `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
201201* A [`ImperativeAffect`](@ref); refer to its documentation for details.
202202
203- DAEs will automatically be reinitialized.
203+ `reinitializealg` is used to set how the system will be reinitialized after the callback.
204+ - Symbolic affects have reinitialization built in. In this case the algorithm will default to SciMLBase.NoInit(), and should **not** be provided.
205+ - Functional and imperative affects will default to SciMLBase.CheckInit(), which will error if the system is not properly reinitialized after the callback. If your system is a DAE, pass in an algorithm like SciMLBase.BrownBasicFullInit() to properly re-initialize.
204206
205207Initial and final affects can also be specified identically to positive and negative edge affects. Initialization affects
206208will run as soon as the solver starts, while finalization affects will be executed after termination.
@@ -212,6 +214,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
212214 initialize:: Union{Affect, Nothing}
213215 finalize:: Union{Affect, Nothing}
214216 rootfind:: Union{Nothing, SciMLBase.RootfindOpt}
217+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
215218
216219 function SymbolicContinuousCallback (
217220 conditions:: Union{Equation, Vector{Equation}} ,
@@ -221,13 +224,21 @@ struct SymbolicContinuousCallback <: AbstractCallback
221224 initialize = nothing ,
222225 finalize = nothing ,
223226 rootfind = SciMLBase. LeftRootFind,
227+ reinitializealg = nothing ,
224228 iv = nothing ,
225229 algeeqs = Equation[])
226230 conditions = (conditions isa AbstractVector) ? conditions : [conditions]
231+
232+ if isnothing (reinitializealg)
233+ any (a -> (a isa FunctionalAffect || a isa ImperativeAffect), [affect, affect_neg, initialize, finalize]) ?
234+ reinitializealg = SciMLBase. CheckInit () :
235+ reinitializealg = SciMLBase. NoInit ()
236+ end
237+
227238 new (conditions, make_affect (affect; iv, algeeqs, discrete_parameters),
228239 make_affect (affect_neg; iv, algeeqs, discrete_parameters),
229240 make_affect (initialize; iv, algeeqs, discrete_parameters), make_affect (
230- finalize; iv, algeeqs, discrete_parameters), rootfind)
241+ finalize; iv, algeeqs, discrete_parameters), rootfind, reinitializealg )
231242 end # Default affect to nothing
232243end
233244
@@ -424,16 +435,22 @@ struct SymbolicDiscreteCallback <: AbstractCallback
424435 affect:: Union{Affect, Nothing}
425436 initialize:: Union{Affect, Nothing}
426437 finalize:: Union{Affect, Nothing}
438+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
427439
428440 function SymbolicDiscreteCallback (
429441 condition, affect = nothing ;
430442 initialize = nothing , finalize = nothing , iv = nothing ,
431- algeeqs = Equation[], discrete_parameters = Any[])
443+ algeeqs = Equation[], discrete_parameters = Any[], reinitializealg = nothing )
432444 c = is_timed_condition (condition) ? condition : value (scalarize (condition))
433445
446+ if isnothing (reinitializealg)
447+ any (a -> (a isa FunctionalAffect || a isa ImperativeAffect), [affect, affect_neg, initialize, finalize]) ?
448+ reinitializealg = SciMLBase. CheckInit () :
449+ reinitializealg = SciMLBase. NoInit ()
450+ end
434451 new (c, make_affect (affect; iv, algeeqs, discrete_parameters),
435452 make_affect (initialize; iv, algeeqs, discrete_parameters),
436- make_affect (finalize; iv, algeeqs, discrete_parameters))
453+ make_affect (finalize; iv, algeeqs, discrete_parameters), reinitializealg )
437454 end # Default affect to nothing
438455end
439456
@@ -525,7 +542,8 @@ function Base.hash(cb::AbstractCallback, s::UInt)
525542 ! is_discrete (cb) && (s = hash (affect_negs (cb), s))
526543 s = hash (initialize_affects (cb), s)
527544 s = hash (finalize_affects (cb), s)
528- ! is_discrete (cb) ? hash (cb. rootfind, s) : s
545+ ! is_discrete (cb) && (s = hash (cb. rootfind, s))
546+ hash (cb. reinitializealg, s)
529547end
530548
531549# ##########################
562580function Base.:(== )(e1:: AbstractCallback , e2:: AbstractCallback )
563581 (is_discrete (e1) === is_discrete (e2)) || return false
564582 (isequal (e1. conditions, e2. conditions) && isequal (e1. affect, e2. affect) &&
565- isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize)) ||
583+ isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize)) && isequal (e1 . reinitializealg, e2 . reinitializealg) ||
566584 return false
567585 is_discrete (e1) ||
568586 (isequal (e1. affect_neg, e2. affect_neg) && isequal (e1. rootfind, e2. rootfind))
@@ -664,15 +682,15 @@ function generate_continuous_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
664682 ps = parameters (sys; initial_parameters = true ); kwargs... )
665683 cbs = continuous_events (sys)
666684 isempty (cbs) && return nothing
667- cb_classes = Dict {SciMLBase.RootfindOpt, Vector{SymbolicContinuousCallback}} ()
685+ cb_classes = Dict {Tuple{ SciMLBase.RootfindOpt, SciMLBase.DAEReinitializationAlg} , Vector{SymbolicContinuousCallback}} ()
668686
669687 # Sort the callbacks by their rootfinding method
670688 for cb in cbs
671- _cbs = get! (() -> SymbolicContinuousCallback[], cb_classes, cb. rootfind)
689+ _cbs = get! (() -> SymbolicContinuousCallback[], cb_classes, ( cb. rootfind, cb . reinitializealg) )
672690 push! (_cbs, cb)
673691 end
674- cb_classes = sort! (OrderedDict (cb_classes))
675- compiled_callbacks = [generate_callback (cb, sys; kwargs... ) for (rf , cb) in cb_classes]
692+ sort! (OrderedDict (cb_classes), by = cb -> cb . rootfind )
693+ compiled_callbacks = [generate_callback (cb, sys; kwargs... ) for ((rf, reinit) , cb) in cb_classes]
676694 if length (compiled_callbacks) == 1
677695 return only (compiled_callbacks)
678696 else
@@ -741,7 +759,7 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
741759
742760 return VectorContinuousCallback (
743761 trigger, affect, affect_neg, length (eqs); initialize, finalize,
744- rootfind = cbs[1 ]. rootfind, initializealg = SciMLBase . NoInit () )
762+ rootfind = cbs[1 ]. rootfind, initializealg = cbs[ 1 ] . reinitializealg )
745763end
746764
747765function generate_callback (cb, sys; kwargs... )
@@ -768,16 +786,16 @@ function generate_callback(cb, sys; kwargs...)
768786 if is_discrete (cb)
769787 if is_timed && conditions (cb) isa AbstractVector
770788 return PresetTimeCallback (trigger, affect; initialize,
771- finalize, initializealg = SciMLBase . NoInit () )
789+ finalize, initializealg = cb . reinitializealg )
772790 elseif is_timed
773- return PeriodicCallback (affect, trigger; initialize, finalize, initializealg = SciMLBase . NoInit () )
791+ return PeriodicCallback (affect, trigger; initialize, finalize, initializealg = cb . reinitializealg )
774792 else
775793 return DiscreteCallback (trigger, affect; initialize,
776- finalize, initializealg = SciMLBase . NoInit () )
794+ finalize, initializealg = cb . reinitializealg )
777795 end
778796 else
779797 return ContinuousCallback (trigger, affect, affect_neg; initialize, finalize,
780- rootfind = cb. rootfind, initializealg = SciMLBase . NoInit () )
798+ rootfind = cb. rootfind, initializealg = cb . reinitializealg )
781799 end
782800end
783801
0 commit comments