@@ -233,8 +233,9 @@ function make_affect(affect::Vector{Equation}; iv = nothing, algeeqs = Equation[
233233 dvs = OrderedSet ()
234234 params = OrderedSet ()
235235 for eq in affect
236- ! haspre (eq) &&
236+ if ! haspre (eq) && ! ( symbolic_type (eq . rhs) === NotSymbolic ())
237237 @warn " Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x)."
238+ end
238239 collect_vars! (dvs, params, eq, iv; op = Pre)
239240 end
240241 for eq in algeeqs
@@ -299,19 +300,19 @@ function SymbolicContinuousCallbacks(events; algeeqs::Vector{Equation} = Equatio
299300 callbacks
300301end
301302
302- function Base. show (io:: IO , cb:: SymbolicContinuousCallback )
303+ function Base. show (io:: IO , cb:: AbstractCallback )
303304 indent = get (io, :indent , 0 )
304305 iio = IOContext (io, :indent => indent + 1 )
305- print (io, " SymbolicContinuousCallback(" )
306- print (iio, " Equations :" )
306+ is_discrete (cb) ? print (io, " SymbolicDiscreteCallback( " ) : print (io, " SymbolicContinuousCallback(" )
307+ print (iio, " Conditions :" )
307308 show (iio, equations (cb))
308309 print (iio, " ; " )
309310 if affects (cb) != nothing
310311 print (iio, " Affect:" )
311312 show (iio, affects (cb))
312313 print (iio, " , " )
313314 end
314- if affect_negs (cb) != nothing
315+ if ! is_discrete (cb) && affect_negs (cb) != nothing
315316 print (iio, " Negative-edge affect:" )
316317 show (iio, affect_negs (cb))
317318 print (iio, " , " )
@@ -328,19 +329,19 @@ function Base.show(io::IO, cb::SymbolicContinuousCallback)
328329 print (iio, " )" )
329330end
330331
331- function Base. show (io:: IO , mime:: MIME"text/plain" , cb:: SymbolicContinuousCallback )
332+ function Base. show (io:: IO , mime:: MIME"text/plain" , cb:: AbstractCallback )
332333 indent = get (io, :indent , 0 )
333334 iio = IOContext (io, :indent => indent + 1 )
334- println (io, " SymbolicContinuousCallback:" )
335- println (iio, " Equations :" )
335+ is_discrete (cb) ? println (io, " SymbolicDiscreteCallback: " ) : println (io, " SymbolicContinuousCallback:" )
336+ println (iio, " Conditions :" )
336337 show (iio, mime, equations (cb))
337338 print (iio, " \n " )
338339 if affects (cb) != nothing
339340 println (iio, " Affect:" )
340341 show (iio, mime, affects (cb))
341342 print (iio, " \n " )
342343 end
343- if affect_negs (cb) != nothing
344+ if ! is_discrete (cb) && affect_negs (cb) != nothing
344345 print (iio, " Negative-edge affect:\n " )
345346 show (iio, mime, affect_negs (cb))
346347 print (iio, " \n " )
@@ -394,8 +395,8 @@ Arguments:
394395- algeeqs: Algebraic equations of the system that must be satisfied after the callback occurs.
395396"""
396397struct SymbolicDiscreteCallback <: AbstractCallback
397- conditions:: Any
398- affect:: Affect
398+ conditions:: Union{Real, Vector{<:Real}, Vector{Equation}}
399+ affect:: Union{ Affect, Nothing}
399400 initialize:: Union{Affect, Nothing}
400401 finalize:: Union{Affect, Nothing}
401402
@@ -409,6 +410,9 @@ struct SymbolicDiscreteCallback <: AbstractCallback
409410 end # Default affect to nothing
410411end
411412
413+ SymbolicDiscreteCallback (p:: Pair , args... ; kwargs... ) = SymbolicDiscreteCallback (p[1 ], p[2 ])
414+ SymbolicDiscreteCallback (cb:: SymbolicDiscreteCallback , args... ; kwargs... ) = cb
415+
412416"""
413417Generate discrete callbacks.
414418"""
@@ -438,29 +442,6 @@ function is_timed_condition(condition::T) where {T}
438442 end
439443end
440444
441- function Base. show (io:: IO , db:: SymbolicDiscreteCallback )
442- indent = get (io, :indent , 0 )
443- iio = IOContext (io, :indent => indent + 1 )
444- println (io, " SymbolicDiscreteCallback:" )
445- println (iio, " Conditions:" )
446- print (iio, " ; " )
447- if affects (db) != nothing
448- print (iio, " Affect:" )
449- show (iio, affects (db))
450- print (iio, " , " )
451- end
452- if initialize_affects (db) != nothing
453- print (iio, " Initialization affect:" )
454- show (iio, initialize_affects (db))
455- print (iio, " , " )
456- end
457- if finalize_affects (db) != nothing
458- print (iio, " Finalization affect:" )
459- show (iio, finalize_affects (db))
460- end
461- print (iio, " )" )
462- end
463-
464445function vars! (vars, cb:: SymbolicDiscreteCallback ; op = Differential)
465446 if symbolic_type (conditions (cb)) == NotSymbolic
466447 if conditions (cb) isa AbstractArray
@@ -529,7 +510,7 @@ function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCa
529510end
530511
531512function Base. hash (cb:: SymbolicContinuousCallback , s:: UInt )
532- s = foldr (hash, cb. eqs , init = s)
513+ s = foldr (hash, cb. conditions , init = s)
533514 s = hash (cb. affect, s)
534515 s = hash (cb. affect_neg, s)
535516 s = hash (cb. initialize, s)
@@ -538,8 +519,8 @@ function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
538519end
539520
540521function Base. hash (cb:: SymbolicDiscreteCallback , s:: UInt )
541- s = hash ( cb. condition, s)
542- s = hash (cb. affects , s)
522+ s = foldr (hash, cb. conditions, init = s)
523+ s = hash (cb. affect , s)
543524 s = hash (cb. initialize, s)
544525 hash (cb. finalize, s)
545526end
649630"""
650631Compile user-defined functional affect.
651632"""
652- function compile_functional_affect (affect:: FunctionalAffect , cb, sys, dvs, ps; kwargs... )
633+ function compile_functional_affect (affect:: FunctionalAffect , cb, sys; kwargs... )
634+ dvs = unknowns (sys)
635+ ps = parameters (sys)
653636 dvs_ind = Dict (reverse (en) for en in enumerate (dvs))
654637 v_inds = map (sym -> dvs_ind[sym], unknowns (affect))
655638
@@ -686,7 +669,18 @@ is_discrete(cb::Vector{<:AbstractCallback}) = eltype(cb) isa SymbolicDiscreteCal
686669function generate_continuous_callbacks (sys:: AbstractSystem , dvs = unknowns (sys), ps = parameters (sys; initial_parameters = true ); kwargs... )
687670 cbs = continuous_events (sys)
688671 isempty (cbs) && return nothing
689- generate_callback (cbs, sys; kwargs... )
672+ cb_classes = Dict {SciMLBase.RootfindOpt, Vector{SymbolicContinuousCallback}} ()
673+ for cb in cbs
674+ _cbs = get! (() -> SymbolicContinuousCallback[], cb_classes, cb. rootfind)
675+ push! (_cbs, cb)
676+ end
677+ cb_classes = sort! (OrderedDict (cb_classes))
678+ compiled_callbacks = [generate_callback (cb, sys; kwargs... ) for (rf, cb) in cb_classes]
679+ if length (compiled_callbacks) == 1
680+ return only (compiled_callbacks)
681+ else
682+ return CallbackSet (compiled_callbacks... )
683+ end
690684end
691685
692686function generate_discrete_callbacks (sys:: AbstractSystem , dvs = unknowns (sys), ps = parameters (sys; initial_parameters = true ); kwargs... )
@@ -716,9 +710,9 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
716710 finals = []
717711 for cb in cbs
718712 affect = compile_affect (cb. affect, cb, sys, default = (args... ) -> ())
719-
720713 push! (affects, affect)
721- push! (affect_negs, compile_affect (cb. affect_neg, cb, sys, default = affect))
714+ affect_neg = (cb. affect_neg === cb. affect) ? affect : compile_affect (cb. affect_neg, cb, sys, default = (args... ) -> ())
715+ push! (affect_negs, affect_neg)
722716 push! (inits, compile_affect (cb. initialize, cb, sys, default = nothing ))
723717 push! (finals, compile_affect (cb. finalize, cb, sys, default = nothing ))
724718 end
@@ -728,8 +722,6 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
728722 eq2affect = reduce (vcat,
729723 [fill (i, num_eqs[i]) for i in eachindex (affects)])
730724 eqs = reduce (vcat, eqs)
731- @assert length (eq2affect) == length (eqs)
732- @assert maximum (eq2affect) == length (affects)
733725
734726 affect = function (integ, idx)
735727 affects[eq2affect[idx]](integ)
@@ -744,7 +736,7 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
744736
745737 return VectorContinuousCallback (
746738 trigger, affect, affect_neg, length (eqs); initialize, finalize,
747- rootfind = cbs[1 ]. rootfind, initializealg = SciMLBase. NoInit)
739+ rootfind = cbs[1 ]. rootfind, initializealg = SciMLBase. NoInit () )
748740end
749741
750742function generate_callback (cb, sys; kwargs... )
@@ -762,16 +754,16 @@ function generate_callback(cb, sys; kwargs...)
762754 if is_discrete (cb)
763755 if is_timed && conditions (cb) isa AbstractVector
764756 return PresetTimeCallback (trigger, affect; initialize,
765- finalize, initializealg = SciMLBase. NoInit)
757+ finalize, initializealg = SciMLBase. NoInit () )
766758 elseif is_timed
767759 return PeriodicCallback (affect, trigger; initialize, finalize)
768760 else
769761 return DiscreteCallback (trigger, affect; initialize,
770- finalize, initializealg = SciMLBase. NoInit)
762+ finalize, initializealg = SciMLBase. NoInit () )
771763 end
772764 else
773765 return ContinuousCallback (trigger, affect, affect_neg; initialize, finalize,
774- rootfind = cb. rootfind, initializealg = SciMLBase. NoInit)
766+ rootfind = cb. rootfind, initializealg = SciMLBase. NoInit () )
775767 end
776768end
777769
@@ -793,27 +785,25 @@ Notes
793785"""
794786function compile_affect (
795787 aff:: Union{Nothing, Affect} , cb:: AbstractCallback , sys:: AbstractSystem ; default = nothing , kwargs... )
788+ isnothing (aff) && return default
789+
796790 save_idxs = if ! (has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing )
797791 Int[]
798792 else
799793 get (ic. callback_to_clocks, cb, Int[])
800794 end
801795
802- isnothing (aff) && return default
803-
804- ps = parameters (aff)
805- dvs = unknowns (aff)
806- dvs_to_modify = setdiff (dvs, getfield .(observed (sys), :lhs ))
807-
808796 if aff isa AffectSystem
809797 affsys = system (aff)
810798 aff_map = aff_to_sys (aff)
811799 sys_map = Dict ([v => k for (k, v) in aff_map])
812800 reinit = has_alg_eqs (sys)
801+ ps_to_modify = discretes (aff)
802+ dvs_to_modify = setdiff (unknowns (aff), getfield .(observed (sys), :lhs ))
813803
814804 function affect! (integrator)
815805 pmap = Pair[]
816- for pre_p in previous_vals (aff )
806+ for pre_p in parameters (affsys )
817807 p = only (arguments (unwrap (pre_p)))
818808 pval = isparameter (p) ? integrator. ps[p] : integrator[p]
819809 push! (pmap, pre_p => pval)
@@ -825,17 +815,17 @@ function compile_affect(
825815 for u in dvs_to_modify
826816 integrator[u] = affsol[sys_map[u]]
827817 end
828- for p in discretes (aff)
818+ for p in ps_to_modify
829819 integrator. ps[p] = affsol[sys_map[p]]
830820 end
831821 for idx in save_idxs
832- SciMLBase. save_discretes! (integ , idx)
822+ SciMLBase. save_discretes! (integrator , idx)
833823 end
834824
835825 sys isa JumpSystem && reset_aggregated_jumps! (integrator)
836826 end
837827 elseif aff isa FunctionalAffect || aff isa ImperativeAffect
838- compile_functional_affect (aff, cb, sys, dvs, ps ; kwargs... )
828+ compile_functional_affect (aff, cb, sys; kwargs... )
839829 end
840830end
841831
0 commit comments