@@ -11,35 +11,139 @@ function get_discrete_events(sys::AbstractSystem)
1111    getfield (sys, :discrete_events )
1212end 
1313
14- struct  Callback
15-     eqs:: Vector{Equation} 
16-     initialize:: Union{ImplicitDiscreteSystem, FunctionalAffect, ImperativeAffect} 
17-     finalize:: ImplicitDiscreteSystem 
18-     affect:: ImplicitDiscreteSystem 
19-     affect_neg:: ImplicitDiscreteSystem 
20-     rootfind:: Union{Nothing, SciMLBase.RootfindOpt} 
21- end 
14+ abstract type  Callback end 
15+ 
16+ const  Affect =  Union{ImplicitDiscreteSystem, FunctionalAffect, ImperativeAffect}
2217
2318#  Callbacks: 
2419#    mapping (cond) => ImplicitDiscreteSystem
2520function  generate_continuous_callbacks (events, sys)
2621    algeeqs =  alg_equations (sys)
27-     callbacks =  Callback []
28-     for  (cond, aff ) in  events
29-         @mtkbuild  affect =  ImplicitDiscreteSystem ([aff , algeeqs], t)
30-         push! (callbacks, Callback (cond, NULL_AFFECT, NULL_AFFECT, affect, affect, SciMLBase. LeftRootFind)) 
22+     callbacks =  MTKContinuousCallback []
23+     for  (cond, affs ) in  events
24+         @mtkbuild  affect =  ImplicitDiscreteSystem ([affs , algeeqs], t)
25+         push! (callbacks, MTKContinuousCallback (cond, NULL_AFFECT, NULL_AFFECT, affect, affect, SciMLBase. LeftRootFind)) 
3126    end 
3227    callbacks
3328end 
3429
35- function  generate_discrete_callback_system (events, sys)
30+ function  generate_discrete_callbacks (events, sys)
31+     algeeqs =  alg_equations (sys)
32+     callbacks =  MTKDiscreteCallback[]
33+     for  (cond, affs) in  events
34+         @mtkbuild  affect =  ImplicitDiscreteSystem ([affs, algeeqs], t)
35+         push! (callbacks, MTKDiscreteCallback (cond, NULL_AFFECT, NULL_AFFECT, affect)) 
36+     end 
37+     callbacks
3638end 
3739
38- function  generate_callback_function () 
39-     
40+ """ 
41+ Create a DifferentialEquations callback. A set of continuous callbacks becomes a VectorContinuousCallback. 
42+ """ 
43+ function  create_callback (cbs:: Vector{MTKContinuousCallback} , sys; is_discrete =  false )
44+     eqs =  flatten_equations (cbs)
45+     _, f_iip =  generate_custom_function (
46+         sys, [eq. lhs -  eq. rhs for  eq in  eqs], unknowns (sys), parameters (sys); 
47+         expression =  Val{false })
48+     trigger =  (out, u, t, integ) ->  f_iip (out, u, parameter_values (integ), t)
49+ 
50+     affects =  []
51+     affect_negs =  []
52+     inits =  []
53+     finals =  []
54+     for  cb in  cbs
55+         affect =  compile_affect (cb. affect)
56+         push! (affects, affect)
57+         isnothing (cb. affect_neg) ?  push! (affect_negs, affect) :  push! (affect_negs, compile_affect (cb. affect_neg))
58+         push! (inits, compile_affect (cb. initialize, default =  SciMLBase. INITALIZE_DEFAULT))
59+         push! (finals, compile_affect (cb. finalize, default =  SciMLBase. FINALIZE_DEFAULT))
60+     end 
61+ 
62+     #  since there may be different number of conditions and affects,
63+     #  we build a map that translates the condition eq. number to the affect number
64+     num_eqs =  length .(eqs)
65+     eq2affect =  reduce (vcat,
66+         [fill (i, num_eqs[i]) for  i in  eachindex (affects)])
67+     @assert  length (eq2affect) ==  length (eqs)
68+     @assert  maximum (eq2affect) ==  length (affect_functions)
69+ 
70+     affect =  function  (integ, idx)
71+         affects[eq2affect[idx]](integ)
72+     end 
73+     affect_neg =  function  (integ, idx)
74+         f =  affect_negs[eq2affect[idx]]
75+         isnothing (f) &&  return 
76+         f (integ)
77+     end 
78+     initialize =  compile_optional_setup (inits, SciMLBase. INITIALIZE_DEFAULT)
79+     finalize =  compile_optional_setup (finals, SciMLBase. FINALIZE_DEFAULT)
80+ 
81+     return  VectorContinuousCallback (trigger, affect; affect_neg, initialize, finalize, rootfind =  callback. rootfind, initializealg =  SciMLBase. NoInit)
82+ end 
83+ 
84+ function  create_callback (cb, sys; is_discrete =  false )
85+     is_timed =  is_timed_condition (cb)
86+ 
87+     trigger =  if  is_discrete
88+         is_timed ?  condition (cb) : 
89+             compile_condition (callback, sys, unknowns (sys), parameters (sys))
90+         else 
91+             _, f_iip =  generate_custom_function (
92+                 sys, [eq. rhs -  eq. lhs for  eq in  equations (cb)], unknowns (sys), parameters (sys); 
93+                 expression =  Val{false })
94+             (out, u, t, integ) ->  f_iip (out, u, parameter_values (integ), t)
95+         end 
96+ 
97+     affect =  compile_affect (cb. affect) 
98+     affect_neg =  isnothing (cb. affect_neg) ?  affect_fn :  compile_affect (cb. affect_neg)
99+     initialize =  compile_affect (cb. initialize, default =  SciMLBase. INITIALIZE_DEFAULT)
100+     finalize =  compile_affect (cb. finalize, default =  SciMLBase. FINALIZE_DEFAULT)
101+ 
102+     if  is_discrete
103+         if  is_timed &&  condition (cb) isa  AbstractVector
104+             return  PresetTimeCallback (trigger, affect; affect_neg, initialize, finalize, initializealg =  SciMLBase. NoInit)
105+         elseif  is_timed
106+             return  PeriodicCallback (affect, trigger; initialize, finalize)
107+         else 
108+             return  DiscreteCallback (trigger, affect; affect_neg, initialize, finalize, initializealg =  SciMLBase. NoInit)
109+         end 
110+     else 
111+         return  ContinuousCallback (trigger, affect; affect_neg, initialize, finalize, rootfind =  callback. rootfind, initializealg =  SciMLBase. NoInit)
112+     end 
113+ end 
114+ 
115+ function  compile_affect (aff; default =  nothing )
116+     if  aff isa  ImplicitDiscreteSystem
117+         function  affect! (integrator) 
118+             u0map =  [u =>  integrator[u] for  u in  unknowns (aff)]
119+             pmap =  [p =>  integrator[p] for  p in  parameters (aff)]
120+             prob =  ImplicitDiscreteProblem (aff, u0map, (0 , 1 ), pmap)
121+             sol =  solve (prob)
122+             for  u in  unknowns (aff)
123+                 integrator[u] =  sol[u][end ]
124+             end 
125+             for  p in  parameters (aff)
126+                 integrator[p] =  sol[p][end ]
127+             end 
128+         end 
129+     elseif  aff isa  FunctionalAffect ||  aff isa  ImperativeAffect
130+         compile_user_affect (aff, callback, sys, unknowns (sys), parameters (sys))
131+     else 
132+         default
133+     end 
134+ end 
135+ 
136+ function  compile_setup_funcs (funs, default)
137+     all (isnothing, funs) &&  return  default
138+     return  let  funs =  funs
139+         function  (cb, u, t, integ)
140+            for  func in  funs
141+                isnothing (func) ?  continue  :  func (integ)
142+            end 
143+         end 
144+     end 
40145end 
41146
42- # ############ Old implementation ###
43147struct  FunctionalAffect
44148    f:: Any 
45149    sts:: Vector 
@@ -50,6 +154,22 @@ struct FunctionalAffect
50154    ctx:: Any 
51155end 
52156
157+ struct  MTKContinuousCallback <:  Callback 
158+     eqs:: Vector{Equation} 
159+     initialize:: Union{Affect, Nothing} 
160+     finalize:: Union{Affect, Nothing} 
161+     affect:: Affect 
162+     affect_neg:: Union{Affect, Nothing} 
163+     rootfind:: Union{Nothing, SciMLBase.RootfindOpt} 
164+ end 
165+ 
166+ struct  MTKDiscreteCallback <:  Callback 
167+     conds:: Vector{Equation} 
168+     initialize:: Union{Affect, Nothing} 
169+     finalize:: Union{Affect, Nothing} 
170+     affect:: Affect 
171+ end 
172+ 
53173function  FunctionalAffect (f, sts, pars, discretes, ctx =  nothing )
54174    #  sts & pars contain either pairs: resistor.R => R, or Syms: R
55175    vs =  [x isa  Pair ?  x. first :  x for  x in  sts]
@@ -67,7 +187,7 @@ function FunctionalAffect(; f, sts, pars, discretes, ctx = nothing)
67187    FunctionalAffect (f, sts, pars, discretes, ctx)
68188end 
69189
70- func (f :: FunctionalAffect ) =  f . f
190+ func (a :: FunctionalAffect ) =  a . f
71191context (a:: FunctionalAffect ) =  a. ctx
72192parameters (a:: FunctionalAffect ) =  a. pars
73193parameters_syms (a:: FunctionalAffect ) =  a. pars_syms
@@ -699,6 +819,7 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no
699819            outputidxs =  update_inds,
700820            create_bindings =  false ,
701821            kwargs... )
822+         @show  rf_oop
702823        #  applied user-provided function to the generated expression
703824        if  postprocess_affect_expr! != =  nothing 
704825            postprocess_affect_expr! (rf_ip, integ)
@@ -869,13 +990,7 @@ function compile_affect_fn(cb, sys::AbstractTimeDependentSystem, dvs, ps, kwargs
869990    eq_aff =  affects (cb)
870991    eq_neg_aff =  affect_negs (cb)
871992    affect =  compile_affect (eq_aff, cb, sys, dvs, ps; expression =  Val{false }, kwargs... )
872-     function  compile_optional_affect (aff, default =  nothing )
873-         if  isnothing (aff) ||  aff ==  default
874-             return  nothing 
875-         else 
876-             return  compile_affect (aff, cb, sys, dvs, ps; expression =  Val{false }, kwargs... )
877-         end 
878-     end 
993+ 
879994    if  eq_neg_aff ===  eq_aff
880995        affect_neg =  affect
881996    else 
@@ -1017,13 +1132,15 @@ end
10171132function  compile_affect (affect:: FunctionalAffect , cb, sys, dvs, ps; kwargs... )
10181133    compile_user_affect (affect, cb, sys, dvs, ps; kwargs... )
10191134end 
1135+ 
10201136function  _compile_optional_affect (default, aff, cb, sys, dvs, ps; kwargs... )
10211137    if  isnothing (aff) ||  aff ==  default
10221138        return  nothing 
10231139    else 
10241140        return  compile_affect (aff, cb, sys, dvs, ps; expression =  Val{false }, kwargs... )
10251141    end 
10261142end 
1143+ 
10271144function  generate_timed_callback (cb, sys, dvs, ps; postprocess_affect_expr! =  nothing ,
10281145        kwargs... )
10291146    cond =  condition (cb)
0 commit comments