@@ -217,7 +217,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
217217 end # Default affect to nothing
218218end
219219
220- SymbolicContinuousCallback (p:: Pair , args... ; kwargs... ) = SymbolicContinuousCallback (p[1 ], p[2 ])
220+ SymbolicContinuousCallback (p:: Pair , args... ; kwargs... ) = SymbolicContinuousCallback (p[1 ], p[2 ], args ... ; kwargs ... )
221221SymbolicContinuousCallback (cb:: SymbolicContinuousCallback , args... ; kwargs... ) = cb
222222
223223make_affect (affect:: Nothing ; kwargs... ) = nothing
@@ -395,7 +395,7 @@ Arguments:
395395- algeeqs: Algebraic equations of the system that must be satisfied after the callback occurs.
396396"""
397397struct SymbolicDiscreteCallback <: AbstractCallback
398- conditions:: Union{Real, Vector{<:Real}, Vector{Equation}}
398+ conditions:: Any
399399 affect:: Union{Affect, Nothing}
400400 initialize:: Union{Affect, Nothing}
401401 finalize:: Union{Affect, Nothing}
@@ -410,7 +410,7 @@ struct SymbolicDiscreteCallback <: AbstractCallback
410410 end # Default affect to nothing
411411end
412412
413- SymbolicDiscreteCallback (p:: Pair , args... ; kwargs... ) = SymbolicDiscreteCallback (p[1 ], p[2 ])
413+ SymbolicDiscreteCallback (p:: Pair , args... ; kwargs... ) = SymbolicDiscreteCallback (p[1 ], p[2 ], args ... ; kwargs ... )
414414SymbolicDiscreteCallback (cb:: SymbolicDiscreteCallback , args... ; kwargs... ) = cb
415415
416416"""
630630"""
631631Compile user-defined functional affect.
632632"""
633- function compile_functional_affect (affect:: FunctionalAffect , cb, sys; kwargs... )
633+ function compile_functional_affect (affect:: FunctionalAffect , sys; kwargs... )
634634 dvs = unknowns (sys)
635635 ps = parameters (sys)
636636 dvs_ind = Dict (reverse (en) for en in enumerate (dvs))
@@ -639,11 +639,9 @@ function compile_functional_affect(affect::FunctionalAffect, cb, sys; kwargs...)
639639 if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
640640 p_inds = [(pind = parameter_index (sys, sym)) === nothing ? sym : pind
641641 for sym in parameters (affect)]
642- save_idxs = get (ic. callback_to_clocks, cb, Int[])
643642 else
644643 ps_ind = Dict (reverse (en) for en in enumerate (ps))
645644 p_inds = map (sym -> get (ps_ind, sym, sym), parameters (affect))
646- save_idxs = Int[]
647645 end
648646 # HACK: filter out eliminated symbols. Not clear this is the right thing to do
649647 # (MTK should keep these symbols)
@@ -652,13 +650,9 @@ function compile_functional_affect(affect::FunctionalAffect, cb, sys; kwargs...)
652650 p = filter (x -> ! isnothing (x[2 ]), collect (zip (parameters_syms (affect), p_inds))) |>
653651 NamedTuple
654652
655- let u = u, p = p, user_affect = func (affect), ctx = context (affect),
656- save_idxs = save_idxs
657- function (integ)
653+ let u = u, p = p, user_affect = func (affect), ctx = context (affect)
654+ (integ) -> begin
658655 user_affect (integ, u, p, ctx)
659- for idx in save_idxs
660- SciMLBase. save_discretes! (integ, idx)
661- end
662656 end
663657 end
664658end
@@ -670,6 +664,8 @@ function generate_continuous_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
670664 cbs = continuous_events (sys)
671665 isempty (cbs) && return nothing
672666 cb_classes = Dict {SciMLBase.RootfindOpt, Vector{SymbolicContinuousCallback}} ()
667+
668+ # Sort the callbacks by their rootfinding method
673669 for cb in cbs
674670 _cbs = get! (() -> SymbolicContinuousCallback[], cb_classes, cb. rootfind)
675671 push! (_cbs, cb)
@@ -709,12 +705,12 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
709705 inits = []
710706 finals = []
711707 for cb in cbs
712- affect = compile_affect (cb. affect, cb, sys, default = (args ... ) -> () )
708+ affect = compile_affect (cb. affect, cb, sys, default = nothing )
713709 push! (affects, affect)
714- affect_neg = (cb. affect_neg === cb. affect) ? affect : compile_affect (cb. affect_neg, cb, sys, default = (args ... ) -> () )
710+ affect_neg = (cb. affect_neg == cb. affect) ? affect : compile_affect (cb. affect_neg, cb, sys, default = nothing )
715711 push! (affect_negs, affect_neg)
716- push! (inits, compile_affect (cb. initialize, cb, sys, default = nothing ))
717- push! (finals, compile_affect (cb. finalize, cb, sys, default = nothing ))
712+ push! (inits, compile_affect (cb. initialize, cb, sys; default = nothing , is_init = true ))
713+ push! (finals, compile_affect (cb. finalize, cb, sys; default = nothing ))
718714 end
719715
720716 # Since there may be different number of conditions and affects,
@@ -746,10 +742,16 @@ function generate_callback(cb, sys; kwargs...)
746742
747743 trigger = is_timed ? conditions (cb) : compile_condition (cb, sys, dvs, ps; kwargs... )
748744 affect = compile_affect (cb. affect, cb, sys, default = (args... ) -> ())
749- affect_neg = hasfield (typeof (cb), :affect_neg ) ?
750- compile_affect (cb. affect_neg, cb, sys, default = affect) : nothing
751- initialize = compile_affect (cb. initialize, cb, sys, default = SciMLBase. INITIALIZE_DEFAULT)
752- finalize = compile_affect (cb. finalize, cb, sys, default = SciMLBase. FINALIZE_DEFAULT)
745+ affect_neg = if is_discrete (cb)
746+ nothing
747+ else
748+ (cb. affect == cb. affect_neg) ? affect : compile_affect (cb. affect_neg, cb, sys, default = nothing )
749+ end
750+ init = compile_affect (cb. initialize, cb, sys, default = SciMLBase. INITIALIZE_DEFAULT, is_init = true )
751+ final = compile_affect (cb. finalize, cb, sys, default = SciMLBase. FINALIZE_DEFAULT)
752+
753+ initialize = isnothing (cb. initialize) ? init : ((c, u, t, i) -> init (i))
754+ finalize = isnothing (cb. finalize) ? final : ((c, u, t, i) -> final (i))
753755
754756 if is_discrete (cb)
755757 if is_timed && conditions (cb) isa AbstractVector
@@ -784,32 +786,81 @@ Notes
784786 - `kwargs` are passed through to `Symbolics.build_function`.
785787"""
786788function compile_affect (
787- aff:: Union{Nothing, Affect} , cb:: AbstractCallback , sys:: AbstractSystem ; default = nothing , kwargs... )
788- isnothing (aff) && return default
789-
789+ aff:: Union{Nothing, Affect} , cb:: AbstractCallback , sys:: AbstractSystem ; default = nothing , is_init = false , kwargs... )
790790 save_idxs = if ! (has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing )
791791 Int[]
792792 else
793793 get (ic. callback_to_clocks, cb, Int[])
794794 end
795795
796- if aff isa AffectSystem
797- affsys = system (aff)
798- aff_map = aff_to_sys (aff)
799- sys_map = Dict ([v => k for (k, v) in aff_map])
800- reinit = has_alg_eqs (sys)
801- ps_to_modify = discretes (aff)
802- dvs_to_modify = setdiff (unknowns (aff), getfield .(observed (sys), :lhs ))
796+ f = if isnothing (aff)
797+ default
798+ elseif aff isa AffectSystem
799+ compile_equational_affect (aff, sys)
800+ elseif aff isa FunctionalAffect || aff isa ImperativeAffect
801+ compile_functional_affect (aff, sys; kwargs... )
802+ end
803+ wrap_save_discretes (f, save_idxs; is_init)
804+ end
805+
806+ # Init can be: user defined function, nothing, or INITIALIZE_DEFAULT
807+ function wrap_save_discretes (f, save_idxs; is_init = false )
808+ if isempty (save_idxs) || f === SciMLBase. FINALIZE_DEFAULT || (isnothing (f) && ! is_init)
809+ return f
810+ elseif f === SciMLBase. INITIALIZE_DEFAULT
811+ let save_idxs = save_idxs
812+ (c, u, t, i) -> begin
813+ f (c, u, t, i)
814+ for idx in save_idxs
815+ SciMLBase. save_discretes! (i, idx)
816+ end
817+ end
818+ end
819+ else
820+ let save_idxs = save_idxs
821+ (i) -> begin
822+ isnothing (f) || f (i)
823+ for idx in save_idxs
824+ SciMLBase. save_discretes! (i, idx)
825+ end
826+ end
827+ end
828+ end
829+ end
830+
831+ """
832+ Initialize and Finalize for VectorContinuousCallback.
833+ """
834+ function compile_vector_optional_affect (funs, default)
835+ all (isnothing, funs) && return default
836+ return let funs = funs
837+ function (cb, u, t, integ)
838+ for func in funs
839+ isnothing (func) ? continue : func (integ)
840+ end
841+ end
842+ end
843+ end
844+
845+ function compile_equational_affect (aff:: AffectSystem , sys; kwargs... )
846+ affsys = system (aff)
847+ aff_map = aff_to_sys (aff)
848+ sys_map = Dict ([v => k for (k, v) in aff_map])
849+ ps_to_modify = discretes (aff)
850+ dvs_to_modify = setdiff (unknowns (aff), getfield .(observed (sys), :lhs ))
851+ # TODO : Add an optimization for systems without algebraic equations
803852
804- function affect! (integrator)
853+ return let dvs_to_modify = dvs_to_modify, aff_map = aff_map, sys_map = sys_map, affsys = affsys, ps_to_modify = ps_to_modify
854+
855+ @inline function affect! (integrator)
805856 pmap = Pair[]
806857 for pre_p in parameters (affsys)
807858 p = only (arguments (unwrap (pre_p)))
808859 pval = isparameter (p) ? integrator. ps[p] : integrator[p]
809860 push! (pmap, pre_p => pval)
810861 end
811862 guesses = Pair[u => integrator[aff_map[u]] for u in unknowns (affsys)]
812- affprob = ImplicitDiscreteProblem (affsys, Pair[], (0 , 1 ), pmap; guesses, build_initializeprob = reinit )
863+ affprob = ImplicitDiscreteProblem (affsys, Pair[], (0 , 1 ), pmap; guesses, build_initializeprob = false )
813864
814865 affsol = init (affprob, SimpleIDSolve ())
815866 for u in dvs_to_modify
@@ -818,28 +869,9 @@ function compile_affect(
818869 for p in ps_to_modify
819870 integrator. ps[p] = affsol[sys_map[p]]
820871 end
821- for idx in save_idxs
822- SciMLBase. save_discretes! (integrator, idx)
823- end
824872
825873 sys isa JumpSystem && reset_aggregated_jumps! (integrator)
826874 end
827- elseif aff isa FunctionalAffect || aff isa ImperativeAffect
828- compile_functional_affect (aff, cb, sys; kwargs... )
829- end
830- end
831-
832- """
833- Initialize and Finalize for VectorContinuousCallback.
834- """
835- function compile_vector_optional_affect (funs, default)
836- all (isnothing, funs) && return default
837- return let funs = funs
838- function (cb, u, t, integ)
839- for func in funs
840- isnothing (func) ? continue : func (integ)
841- end
842- end
843875 end
844876end
845877
0 commit comments