@@ -30,10 +30,8 @@ in the returned tuple, in which case the associated field will not be updated.
3030""" 
3131struct  ImperativeAffect
3232    f:: Any 
33-     obs:: Vector 
34-     obs_syms:: Vector{Symbol} 
35-     modified:: Vector 
36-     mod_syms:: Vector{Symbol} 
33+     observed:: NamedTuple 
34+     modified:: NamedTuple 
3735    ctx:: Any 
3836    skip_checks:: Bool 
3937end 
@@ -43,10 +41,7 @@ function ImperativeAffect(f;
4341        modified:: NamedTuple  =  NamedTuple {()} (()),
4442        ctx =  nothing ,
4543        skip_checks =  false )
46-     ImperativeAffect (f,
47-         collect (values (observed)), collect (keys (observed)),
48-         collect (values (modified)), collect (keys (modified)),
49-         ctx, skip_checks)
44+     ImperativeAffect (f, observed, modified, ctx, skip_checks)
5045end 
5146function  ImperativeAffect (f, modified:: NamedTuple ;
5247        observed:: NamedTuple  =  NamedTuple {()} (()), ctx =  nothing , skip_checks =  false )
@@ -68,61 +63,54 @@ function ImperativeAffect(; f, kwargs...)
6863end 
6964
7065function  Base. show (io:: IO , mfa:: ImperativeAffect )
71-     obs_vals  =  join ( map ((ob, nm)  ->   " $ob  =>  $nm " ,  mfa. obs, mfa . obs_syms),  " ,  " ) 
72-     mod_vals  =  join ( map ((md, nm)  ->   " $md  =>  $nm " ,  mfa. modified, mfa . mod_syms),  " ,  " ) 
66+     obs  =  mfa. observed 
67+     mod  =  mfa. modified
7368    affect =  mfa. f
7469    print (io,
75-         " ImperativeAffect(observed: [$obs_vals  ], modified: [$mod_vals  ], affect:$affect )" 
70+         " ImperativeAffect(observed: [$(obs)  ], modified: [$(mod)  ], affect:$affect )" 
7671end 
7772func (f:: ImperativeAffect ) =  f. f
7873context (a:: ImperativeAffect ) =  a. ctx
79- observed (a:: ImperativeAffect ) =  a. obs
80- observed_syms (a:: ImperativeAffect ) =  a. obs_syms
8174function  discretes (a:: ImperativeAffect )
8275    Iterators. filter (ModelingToolkit. isparameter,
8376        Iterators. flatten (Iterators. map (
8477            x ->  symbolic_type (x) ==  NotSymbolic () &&  x isa  AbstractArray ?  x :  [x],
8578            a. modified)))
8679end 
87- modified (a:: ImperativeAffect ) =  a. modified
88- modified_syms (a:: ImperativeAffect ) =  a. mod_syms
8980
9081function  Base.:(== )(a1:: ImperativeAffect , a2:: ImperativeAffect )
91-     isequal (a1. f, a2. f) &&  isequal (a1. obs , a2. obs)  &&   isequal (a1 . modified, a2 . modified ) && 
92-         isequal (a1. obs_syms , a2. obs_syms)  &&   isequal (a1 . mod_syms, a2 . mod_syms ) && 
82+     isequal (a1. f, a2. f) &&  isequal (a1. observed , a2. observed ) && 
83+         isequal (a1. modified , a2. modified ) && 
9384        isequal (a1. ctx, a2. ctx)
9485end 
9586
9687function  Base. hash (a:: ImperativeAffect , s:: UInt )
9788    s =  hash (a. f, s)
98-     s =  hash (a. obs, s)
99-     s =  hash (a. obs_syms, s)
89+     s =  hash (a. observed, s)
10090    s =  hash (a. modified, s)
101-     s =  hash (a. mod_syms, s)
10291    hash (a. ctx, s)
10392end 
10493
10594namespace_affects (af:: ImperativeAffect , s) =  namespace_affect (af, s)
106- function   namespace_affect (affect :: ImperativeAffect , s) 
107-     rmn  =  [] 
108-     for  modded  in   modified (affect )
109-          if   symbolic_type (modded)  ==   NotSymbolic ()  &&  modded  isa  AbstractArray 
110-             res  =  [] 
111-              for  m  in  modded 
112-                  push! (res,  renamespace (s, m)) 
113-              end 
114-             push! (rmn, res )
95+ 
96+ function   _namespace_nt (nt :: NamedTuple , s :: AbstractSystem ) 
97+     return   NamedTuple {keys(nt)} ( _namespace_nt ( values (nt), s) )
98+ end 
99+ 
100+ function   _namespace_nt (nt :: Union{AbstractArray, Tuple} , s :: AbstractSystem ) 
101+     return   map ( length (nt))  do  v 
102+         if   symbolic_type (v)  ==   NotSymbolic () 
103+             _namespace_nt (v, s )
115104        else 
116-             push! (rmn,  renamespace (s, modded) )
105+             renamespace (s, v )
117106        end 
118107    end 
119-     ImperativeAffect (func (affect),
120-         namespace_expr .(observed (affect), (s,)),
121-         observed_syms (affect),
122-         rmn,
123-         modified_syms (affect),
124-         context (affect),
125-         affect. skip_checks)
108+ end 
109+ 
110+ function  namespace_affect (affect:: ImperativeAffect , s)
111+     obs =  _namespace_nt (affect. observed, s)
112+     mod =  _namespace_nt (affect. modified, s)
113+     ImperativeAffect (affect. f, obs, mod, affect. ctx, affect. skip_checks)
126114end 
127115
128116function  invalid_variables (sys, expr)
@@ -139,21 +127,6 @@ function unassignable_variables(sys, expr)
139127        x ->  ! any (isequal (x), assignable_syms), written)
140128end 
141129
142- @generated  function  _generated_writeback (integ, setters:: NamedTuple{NS1, <:Tuple} ,
143-         values:: NamedTuple{NS2, <:Tuple} ) where  {NS1, NS2}
144-     setter_exprs =  []
145-     for  name in  NS2
146-         if  ! (name in  NS1)
147-             missing_name =  " Tried to write back to $name  from affect; only declared states ($NS1 ) may be written to." 
148-             error (missing_name)
149-         end 
150-         push! (setter_exprs, :(setters.$ name (integ, values.$ name)))
151-     end 
152-     return  :(begin 
153-         $ (setter_exprs... )
154-     end )
155- end 
156- 
157130function  check_assignable (sys, sym)
158131    if  symbolic_type (sym) ==  ScalarSymbolic ()
159132        is_variable (sys, sym) ||  is_parameter (sys, sym)
@@ -167,6 +140,41 @@ function check_assignable(sys, sym)
167140    end 
168141end 
169142
143+ function  _nt_check_valid (nt:: NamedTuple , s:: AbstractSystem , isobserved:: Bool )
144+     _nt_check_valid (values (nt), s, isobserved)
145+ end 
146+ 
147+ function  _nt_check_valid (nt:: Union{Tuple, AbstractArray} , s:: AbstractSystem , isobserved:: Bool )
148+     for  v in  nt
149+         if  symbolic_type (v) ==  NotSymbolic ()
150+             _nt_check_valid (v, s, isobserved)
151+             continue 
152+         end 
153+         if  ! isobserved &&  ! check_assignable (s, v)
154+             @warn  """ 
155+             Expression $v  cannot be assigned to; currently only unknowns and parameters may \ 
156+             be updated by an affect. 
157+             """  
158+         end 
159+         invalid =  invalid_variables (s, v)
160+         isempty (invalid) &&  continue 
161+         name =  isobserved ?  " Observed" :  " Modified" 
162+         error (""" 
163+         $name  expression $(v)  in affect refers to missing variable(s) $(invalid) ; \ 
164+         the variables may not have been added (e.g. if a component is missing). 
165+         """  )
166+     end 
167+ end 
168+ 
169+ function  _nt_check_overlap (nta:: NamedTuple , ntb:: NamedTuple )
170+     common =  intersect (keys (nta), keys (ntb))
171+     isempty (common) &&  return 
172+     @warn  """ 
173+     The symbols $common  are declared as both observed and modified; this is a code smell \ 
174+     because it becomes easy to confuse them and assign/not assign a value. 
175+     """  
176+ end 
177+ 
170178function  compile_functional_affect (
171179        affect:: ImperativeAffect , sys; reset_jumps =  false , kwargs... )
172180    #= 
@@ -176,93 +184,27 @@ function compile_functional_affect(
176184        call the affect method 
177185        unpack and apply the resulting values 
178186    =#  
179-     function  check_dups (syms, exprs) #  = (syms_dedup, exprs_dedup)
180-         seen =  Set {Symbol} ()
181-         syms_dedup =  []
182-         exprs_dedup =  []
183-         for  (sym, exp) in  Iterators. zip (syms, exprs)
184-             if  ! in (sym, seen)
185-                 push! (syms_dedup, sym)
186-                 push! (exprs_dedup, exp)
187-                 push! (seen, sym)
188-             elseif  ! affect. skip_checks
189-                 @warn  " Expression $(expr)  is aliased as $sym , which has already been used. The first definition will be used." 
190-             end 
191-         end 
192-         return  (syms_dedup, exprs_dedup)
193-     end 
194187
195-     dvs =  unknowns (sys)
196-     ps =  parameters (sys)
197- 
198-     obs_exprs =  observed (affect)
199-     if  ! affect. skip_checks
200-         for  oexpr in  obs_exprs
201-             invalid_vars =  invalid_variables (sys, oexpr)
202-             if  length (invalid_vars) >  0 
203-                 error (" Observed equation $(oexpr)  in affect refers to missing variable(s) $(invalid_vars) ; the variables may not have been added (e.g. if a component is missing)." 
204-             end 
205-         end 
206-     end 
207-     obs_syms =  observed_syms (affect)
208-     obs_syms, obs_exprs =  check_dups (obs_syms, obs_exprs)
209- 
210-     mod_exprs =  modified (affect)
211188    if  ! affect. skip_checks
212-         for  mexpr in  mod_exprs
213-             if  ! check_assignable (sys, mexpr)
214-                 @warn  (" Expression $mexpr  cannot be assigned to; currently only unknowns and parameters may be updated by an affect." 
215-             end 
216-             invalid_vars =  unassignable_variables (sys, mexpr)
217-             if  length (invalid_vars) >  0 
218-                 error (" Modified equation $(mexpr)  in affect refers to missing variable(s) $(invalid_vars) ; the variables may not have been added (e.g. if a component is missing) or they may have been reduced away." 
219-             end 
220-         end 
221-     end 
222-     mod_syms =  modified_syms (affect)
223-     mod_syms, mod_exprs =  check_dups (mod_syms, mod_exprs)
224- 
225-     overlapping_syms =  intersect (mod_syms, obs_syms)
226-     if  length (overlapping_syms) >  0  &&  ! affect. skip_checks
227-         @warn  " The symbols $overlapping_syms  are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value." 
189+         _nt_check_valid (affect. observed, sys, true )
190+         _nt_check_valid (affect. modified, sys, false )
191+         _nt_check_overlap (affect. observed, affect. modified)
228192    end 
229193
230194    #  sanity checks done! now build the data and update function for observed values
231-     mkzero (sz) = 
232-         if  sz ===  ()
233-             0.0 
234-         else 
235-             zeros (sz)
236-         end 
237-     obs_fun =  build_explicit_observed_function (
238-         sys, Symbolics. scalarize .(obs_exprs);
239-         mkarray =  (es, _) ->  MakeTuple (es))
240-     obs_sym_tuple =  (obs_syms... ,)
241- 
242-     #  okay so now to generate the stuff to assign it back into the system
243-     mod_pairs =  mod_exprs .=>  mod_syms
244-     mod_names =  (mod_syms... ,)
245-     mod_og_val_fun =  build_explicit_observed_function (
246-         sys, Symbolics. scalarize .(first .(mod_pairs));
247-         mkarray =  (es, _) ->  MakeTuple (es))
195+     let  user_affect =  func (affect), ctx =  context (affect),
196+         obs_getter =  getsym (sys, affect. observed),
197+         mod_getter =  getsym (sys, affect. modified),
198+         mod_setter =  setsym (sys, affect. modified),
199+         reset_jumps =  reset_jumps
248200
249-     upd_funs =  NamedTuple {mod_names} ((setu .((sys,), first .(mod_pairs))... ,))
250- 
251-     let  user_affect =  func (affect), ctx =  context (affect), reset_jumps =  reset_jumps
252201        @inline  function  (integ)
253-             #  update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
254-             modvals =  mod_og_val_fun (integ. u, integ. p, integ. t)
255-             upd_component_array =  NamedTuple {mod_names} (modvals)
256- 
257-             #  update the observed values
258-             obs_component_array =  NamedTuple {obs_sym_tuple} (obs_fun (
259-                 integ. u, integ. p, integ. t))
202+             mod =  mod_getter (integ)
203+             obs =  obs_getter (integ)
260204
261205            #  let the user do their thing
262-             upd_vals =  user_affect (upd_component_array, obs_component_array, ctx, integ)
263- 
264-             #  write the new values back to the integrator
265-             _generated_writeback (integ, upd_funs, upd_vals)
206+             upd_vals =  user_affect (mod, obs, ctx, integ)
207+             mod_setter (integ, upd_vals)
266208
267209            reset_jumps &&  reset_aggregated_jumps! (integ)
268210        end 
@@ -271,19 +213,22 @@ end
271213
272214scalarize_affects (affects:: ImperativeAffect ) =  affects
273215
274- function  vars! (vars, aff:: ImperativeAffect ; op =  Differential)
275-     for  var in  Iterators. flatten ((observed (aff), modified (aff)))
276-         if  symbolic_type (var) ==  NotSymbolic ()
277-             if  var isa  AbstractArray
278-                 for  v in  var
279-                     v =  unwrap (v)
280-                     vars! (vars, v)
281-                 end 
282-             end 
283-         else 
284-             var =  unwrap (var)
285-             vars! (vars, var)
216+ function  _vars_nt! (vars, nt:: NamedTuple , op)
217+     _vars_nt! (vars, values (nt), op)
218+ end 
219+ 
220+ function  _vars_nt! (vars, nt:: Union{AbstractArray, Tuple} , op)
221+     for  v in  nt
222+         if  symbolic_type (v) ==  NotSymbolic ()
223+             _vars_nt! (vars, v, op)
224+             continue 
286225        end 
226+         vars! (vars, v; op)
287227    end 
228+ end 
229+ 
230+ function  vars! (vars, aff:: ImperativeAffect ; op =  Differential)
231+     _vars_nt! (vars, aff. observed, op)
232+     _vars_nt! (vars, aff. modified, op)
288233    return  vars
289234end 
0 commit comments