@@ -67,21 +67,31 @@ struct AffectSystem
6767 unknowns:: Vector
6868 parameters:: Vector
6969 discretes:: Vector
70- """ Maps the unknowns in the ImplicitDiscreteSystem to the corresponding parameter or unknown in the parent system."""
71- affu_to_sysu :: Dict
70+ """ Maps the symbols of unknowns/observed in the ImplicitDiscreteSystem to its corresponding unknown/parameter in the parent system."""
71+ aff_to_sys :: Dict
7272end
7373
7474system (a:: AffectSystem ) = a. system
7575discretes (a:: AffectSystem ) = a. discretes
7676unknowns (a:: AffectSystem ) = a. unknowns
7777parameters (a:: AffectSystem ) = a. parameters
78- affu_to_sysu (a:: AffectSystem ) = a. affu_to_sysu
78+ aff_to_sys (a:: AffectSystem ) = a. aff_to_sys
79+ previous_vals (a:: AffectSystem ) = parameters (system (a))
80+ updated_vals (a:: AffectSystem ) = unknowns (system (a))
7981
8082function Base. show (iio:: IO , aff:: AffectSystem )
8183 eqs = vcat (equations (system (aff)), observed (system (aff)))
8284 show (iio, eqs)
8385end
8486
87+ function Base.:(== )(a1:: AffectSystem , a2:: AffectSystem )
88+ isequal (system (a1), system (a2)) &&
89+ isequal (discretes (a1), discretes (a2)) &&
90+ isequal (unknowns (a1), unknowns (a2)) &&
91+ isequal (parameters (a1), parameters (a2)) &&
92+ isequal (aff_to_sys (a1), aff_to_sys (a2))
93+ end
94+
8595"""
8696 Pre(x)
8797
@@ -112,14 +122,14 @@ function (p::Pre)(x)
112122 iscall (x) && operation (x) isa Pre && return x
113123 result = if symbolic_type (x) == ArraySymbolic ()
114124 # create an array for `Pre(array)`
115- Symbolics. array_term (p, toparam (x) )
125+ Symbolics. array_term (p, x )
116126 elseif iscall (x) && operation (x) == getindex
117127 # instead of `Pre(x[1])` create `Pre(x)[1]`
118128 # which allows parameter indexing to handle this case automatically.
119129 arr = arguments (x)[1 ]
120- term (getindex, p (toparam ( arr) ), arguments (x)[2 : end ]. .. )
130+ term (getindex, p (arr), arguments (x)[2 : end ]. .. )
121131 else
122- term (p, toparam (x) )
132+ term (p, x )
123133 end
124134 # the result should be a parameter
125135 result = toparam (result)
@@ -231,25 +241,36 @@ function make_affect(affect::Vector{Equation}; warn = true)
231241 discretes = Any[]
232242 p_as_unknowns = Any[]
233243 for p in params
234- if iscall (p) && (operator (p) isa Pre)
244+ if iscall (p) && (operation (p) isa Pre)
235245 push! (cb_params, p)
236246 elseif iscall (p) && length (arguments (p)) == 1 &&
237247 isequal (only (arguments (p)), iv)
238248 push! (discretes, p)
239249 push! (p_as_unknowns, tovar (p))
240250 else
241251 push! (discretes, p)
242- p = iscall (p) ? wrap (Sym {FnType{Tuple{symtype(iv)}, Real}} (nameof (operation (p)))(iv)) :
243- wrap (Sym {FnType{Tuple{symtype(iv)}, Real}} (nameof (p))(iv))
252+ name = iscall (p) ? nameof (operation (p)) : nameof (p)
253+ p = wrap (Sym {FnType{Tuple{symtype(iv)}, Real}} (name)(iv))
254+ p = setmetadata (p, Symbolics. VariableSource, (:variables , name))
244255 push! (p_as_unknowns, p)
245256 end
246257 end
258+ aff_map = Dict (zip (p_as_unknowns, discretes))
259+ rev_map = Dict ([v => k for (k, v) in aff_map])
260+ affect = Symbolics. substitute (affect, rev_map)
247261 @mtkbuild affectsys = ImplicitDiscreteSystem (
248262 affect, iv, collect (union (unknowns, p_as_unknowns)), cb_params)
249- params = map (x -> only (arguments (unwrap (x))), cb_params)
250- affmap = Dict (zip ([p_as_unknowns, unknowns], [discretes, unknowns]))
263+ params = filter (isparameter, map (x -> only (arguments (unwrap (x))), cb_params))
264+ @show params
265+
266+ for u in unknowns
267+ aff_map[u] = u
268+ end
269+
270+ @show unknowns
271+ @show params
251272
252- return AffectSystem (affectsys, collect (unknowns), params, discretes, affmap )
273+ return AffectSystem (affectsys, collect (unknowns), params, discretes, aff_map )
253274end
254275
255276function make_affect (affect)
@@ -393,17 +414,19 @@ function SymbolicDiscreteCallbacks(events, algeeqs::Vector{Equation} = Equation[
393414
394415 for event in events
395416 cond, affs = event isa Pair ? (event[1 ], event[2 ]) : (event, nothing )
396- if aff isa AbstractVector
397- aff = vcat (aff , algeeqs)
417+ if affs isa AbstractVector
418+ affs = vcat (affs , algeeqs)
398419 end
399- affect = make_affect (aff )
400- push! (callbacks, SymbolicDiscreteCallback (cond, affect, nothing , nothing ))
420+ affect = make_affect (affs )
421+ push! (callbacks, SymbolicDiscreteCallback (cond, affect))
401422 end
402423 callbacks
403424end
404425
405426function is_timed_condition (condition:: T ) where {T}
406- if T <: Real
427+ if T === Num
428+ false
429+ elseif T <: Real
407430 true
408431 elseif T <: AbstractVector
409432 eltype (condition) <: Real
@@ -582,23 +605,31 @@ function compile_condition(cbs::Union{AbstractCallback, Vector{<:AbstractCallbac
582605 condit = substitute (condit, cmap)
583606 end
584607
585- f_oop, f_iip = build_function_wrapper (sys,
586- condit, u, t, p... ; expression = Val{true },
587- p_start = 3 , p_end = length (p) + 2 ,
608+ if ! is_discrete (cbs)
609+ condit = [cond. lhs - cond. rhs for cond in condit]
610+ end
611+
612+ fs = build_function_wrapper (sys,
613+ condit, u, p... , t; expression,
588614 kwargs... )
589615
590- if cbs isa AbstractVector
591- cond (out, u, t, integ) = f_iip (out, u, t, parameter_values (integ))
616+ if expression == Val{true }
617+ fs = eval_or_rgf .(fs; eval_expression, eval_module)
618+ end
619+ is_discrete (cbs) ? (f_oop = fs) : (f_oop, f_iip = fs)
620+
621+ cond = if cbs isa AbstractVector
622+ (out, u, t, integ) -> f_iip (out, u, parameter_values (integ), t)
592623 elseif is_discrete (cbs)
593- cond (u, t, integ) = f_oop (u, t, parameter_values (integ))
624+ (u, t, integ) -> f_oop (u, parameter_values (integ), t )
594625 else
595- cond = function (u, t, integ)
626+ function (u, t, integ)
596627 if DiffEqBase. isinplace (integ. sol. prob)
597628 tmp, = DiffEqBase. get_tmp_cache (integ)
598- f_iip (tmp, u, t, parameter_values (integ))
629+ f_iip (tmp, u, parameter_values (integ), t )
599630 tmp[1 ]
600631 else
601- f_oop (u, t, parameter_values (integ))
632+ f_oop (u, parameter_values (integ), t )
602633 end
603634 end
604635 end
@@ -641,6 +672,7 @@ function compile_functional_affect(affect::FunctionalAffect, cb, sys, dvs, ps; k
641672end
642673
643674is_discrete (cb:: AbstractCallback ) = cb isa SymbolicDiscreteCallback
675+ is_discrete (cb:: Vector{<:AbstractCallback} ) = eltype (cb) isa SymbolicDiscreteCallback
644676
645677function generate_continuous_callbacks (sys:: AbstractSystem , dvs = unknowns (sys), ps = parameters (sys; initial_parameters = true ); kwargs... )
646678 cbs = continuous_events (sys)
@@ -668,27 +700,27 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
668700 return generate_callback (cbs[cb_ind], sys; kwargs... )
669701 end
670702
671- trigger = compile_condition (cbs, sys, dvs, ps ; kwargs... )
703+ trigger = compile_condition (cbs, sys, unknowns (sys), parameters (sys; initial_parameters = true ) ; kwargs... )
672704 affects = []
673705 affect_negs = []
674706 inits = []
675707 finals = []
676708 for cb in cbs
677- affect = compile_affect (cb. affect, cb, sys)
709+ affect = compile_affect (cb. affect, cb, sys, default = (args ... ) -> () )
678710
679711 push! (affects, affect)
680712 push! (affect_negs, compile_affect (cb. affect_neg, cb, sys, default = affect))
681- push! (inits, compile_affect (cb. initialize, cb, sys, default = SciMLBase . INITALIZE_DEFAULT ))
682- push! (finals, compile_affect (cb. finalize, cb, sys, default = SciMLBase . FINALIZE_DEFAULT ))
713+ push! (inits, compile_affect (cb. initialize, cb, sys, default = nothing ))
714+ push! (finals, compile_affect (cb. finalize, cb, sys, default = nothing ))
683715 end
684716
685717 # Since there may be different number of conditions and affects,
686718 # we build a map that translates the condition eq. number to the affect number
687- num_eqs = length .(eqs)
688719 eq2affect = reduce (vcat,
689720 [fill (i, num_eqs[i]) for i in eachindex (affects)])
721+ eqs = reduce (vcat, eqs)
690722 @assert length (eq2affect) == length (eqs)
691- @assert maximum (eq2affect) == length (affect_functions )
723+ @assert maximum (eq2affect) == length (affects )
692724
693725 affect = function (integ, idx)
694726 affects[eq2affect[idx]](integ)
@@ -702,8 +734,8 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
702734 finalize = compile_vector_optional_affect (finals, SciMLBase. FINALIZE_DEFAULT)
703735
704736 return VectorContinuousCallback (
705- trigger, affect, length (cbs); affect_neg, initialize, finalize,
706- rootfind = callback . rootfind, initializealg = SciMLBase. NoInit)
737+ trigger, affect, affect_neg, length (eqs); initialize, finalize,
738+ rootfind = cbs[ 1 ] . rootfind, initializealg = SciMLBase. NoInit)
707739end
708740
709741function generate_callback (cb, sys; kwargs... )
@@ -712,14 +744,14 @@ function generate_callback(cb, sys; kwargs...)
712744 ps = parameters (sys; initial_parameters = true )
713745
714746 trigger = is_timed ? conditions (cb) : compile_condition (cb, sys, dvs, ps; kwargs... )
715- affect = compile_affect (cb. affect, cb, sys)
747+ affect = compile_affect (cb. affect, cb, sys, default = (args ... ) -> () )
716748 affect_neg = hasfield (typeof (cb), :affect_neg ) ?
717749 compile_affect (cb. affect_neg, cb, sys, default = affect) : nothing
718750 initialize = compile_affect (cb. initialize, cb, sys, default = SciMLBase. INITIALIZE_DEFAULT)
719751 finalize = compile_affect (cb. finalize, cb, sys, default = SciMLBase. FINALIZE_DEFAULT)
720752
721753 if is_discrete (cb)
722- if is_timed && condition (cb) isa AbstractVector
754+ if is_timed && conditions (cb) isa AbstractVector
723755 return PresetTimeCallback (trigger, affect; affect_neg, initialize,
724756 finalize, initializealg = SciMLBase. NoInit)
725757 elseif is_timed
@@ -762,22 +794,30 @@ function compile_affect(
762794
763795 ps = parameters (aff)
764796 dvs = unknowns (aff)
797+ @show ps
765798
766799 if aff isa AffectSystem
767- aff_map = affu_to_sysu (aff)
800+ aff_map = aff_to_sys (aff)
801+ sys_map = Dict ([v => k for (k, v) in aff_map])
802+ build_initializeprob = has_alg_eqs (sys)
803+
768804 function affect! (integrator)
769- pmap = []
770- for pre_p in parameters ( system (affect) )
805+ pmap = Pair []
806+ for pre_p in previous_vals (aff )
771807 p = only (arguments (unwrap (pre_p)))
772- push! (pmap, pre_p => integrator[p])
773- end
774- guesses = [u => integrator[aff_map[u]] for u in unknowns (system (affect))]
775- prob = ImplicitDiscreteProblem (system (affect), [], (0 , 1 ), pmap; guesses)
776- sol = init (prob, SimpleIDSolve ())
777- for u in unknowns (system (affect))
778- integrator[aff_map[u]] = sol[u]
808+ pval = isparameter (p) ? integrator. ps[p] : integrator[p]
809+ push! (pmap, pre_p => pval)
779810 end
811+ guesses = Pair[u => integrator[aff_map[u]] for u in updated_vals (aff)]
812+ affprob = ImplicitDiscreteProblem (system (aff), Pair[], (0 , 1 ), pmap; guesses, build_initializeprob)
780813
814+ affsol = init (affprob, SimpleIDSolve ())
815+ for u in unknowns (aff)
816+ integrator[u] = affsol[u]
817+ end
818+ for p in discretes (aff)
819+ integrator. ps[p] = affsol[sys_map[p]]
820+ end
781821 for idx in save_idxs
782822 SciMLBase. save_discretes! (integ, idx)
783823 end
0 commit comments