@@ -30,10 +30,8 @@ in the returned tuple, in which case the associated field will not be updated.
30
30
"""
31
31
struct ImperativeAffect
32
32
f:: Any
33
- obs:: Vector
34
- obs_syms:: Vector{Symbol}
35
- modified:: Vector
36
- mod_syms:: Vector{Symbol}
33
+ observed:: NamedTuple
34
+ modified:: NamedTuple
37
35
ctx:: Any
38
36
skip_checks:: Bool
39
37
end
@@ -43,10 +41,7 @@ function ImperativeAffect(f;
43
41
modified:: NamedTuple = NamedTuple {()} (()),
44
42
ctx = nothing ,
45
43
skip_checks = false )
46
- ImperativeAffect (f,
47
- collect (values (observed)), collect (keys (observed)),
48
- collect (values (modified)), collect (keys (modified)),
49
- ctx, skip_checks)
44
+ ImperativeAffect (f, observed, modified, ctx, skip_checks)
50
45
end
51
46
function ImperativeAffect (f, modified:: NamedTuple ;
52
47
observed:: NamedTuple = NamedTuple {()} (()), ctx = nothing , skip_checks = false )
@@ -68,61 +63,54 @@ function ImperativeAffect(; f, kwargs...)
68
63
end
69
64
70
65
function Base. show (io:: IO , mfa:: ImperativeAffect )
71
- obs_vals = join ( map ((ob, nm) -> " $ob => $nm " , mfa. obs, mfa . obs_syms), " , " )
72
- mod_vals = join ( map ((md, nm) -> " $md => $nm " , mfa. modified, mfa . mod_syms), " , " )
66
+ obs = mfa. observed
67
+ mod = mfa. modified
73
68
affect = mfa. f
74
69
print (io,
75
- " ImperativeAffect(observed: [$obs_vals ], modified: [$mod_vals ], affect:$affect )" )
70
+ " ImperativeAffect(observed: [$(obs) ], modified: [$(mod) ], affect:$affect )" )
76
71
end
77
72
func (f:: ImperativeAffect ) = f. f
78
73
context (a:: ImperativeAffect ) = a. ctx
79
- observed (a:: ImperativeAffect ) = a. obs
80
- observed_syms (a:: ImperativeAffect ) = a. obs_syms
81
74
function discretes (a:: ImperativeAffect )
82
75
Iterators. filter (ModelingToolkit. isparameter,
83
76
Iterators. flatten (Iterators. map (
84
77
x -> symbolic_type (x) == NotSymbolic () && x isa AbstractArray ? x : [x],
85
78
a. modified)))
86
79
end
87
- modified (a:: ImperativeAffect ) = a. modified
88
- modified_syms (a:: ImperativeAffect ) = a. mod_syms
89
80
90
81
function Base.:(== )(a1:: ImperativeAffect , a2:: ImperativeAffect )
91
- isequal (a1. f, a2. f) && isequal (a1. obs , a2. obs) && isequal (a1 . modified, a2 . modified ) &&
92
- isequal (a1. obs_syms , a2. obs_syms) && isequal (a1 . mod_syms, a2 . mod_syms ) &&
82
+ isequal (a1. f, a2. f) && isequal (a1. observed , a2. observed ) &&
83
+ isequal (a1. modified , a2. modified ) &&
93
84
isequal (a1. ctx, a2. ctx)
94
85
end
95
86
96
87
function Base. hash (a:: ImperativeAffect , s:: UInt )
97
88
s = hash (a. f, s)
98
- s = hash (a. obs, s)
99
- s = hash (a. obs_syms, s)
89
+ s = hash (a. observed, s)
100
90
s = hash (a. modified, s)
101
- s = hash (a. mod_syms, s)
102
91
hash (a. ctx, s)
103
92
end
104
93
105
94
namespace_affects (af:: ImperativeAffect , s) = namespace_affect (af, s)
106
- function namespace_affect (affect :: ImperativeAffect , s)
107
- rmn = []
108
- for modded in modified (affect )
109
- if symbolic_type (modded) == NotSymbolic () && modded isa AbstractArray
110
- res = []
111
- for m in modded
112
- push! (res, renamespace (s, m))
113
- end
114
- push! (rmn, res )
95
+
96
+ function _namespace_nt (nt :: NamedTuple , s :: AbstractSystem )
97
+ return NamedTuple {keys(nt)} ( _namespace_nt ( values (nt), s) )
98
+ end
99
+
100
+ function _namespace_nt (nt :: Union{AbstractArray, Tuple} , s :: AbstractSystem )
101
+ return map ( length (nt)) do v
102
+ if symbolic_type (v) == NotSymbolic ()
103
+ _namespace_nt (v, s )
115
104
else
116
- push! (rmn, renamespace (s, modded) )
105
+ renamespace (s, v )
117
106
end
118
107
end
119
- ImperativeAffect (func (affect),
120
- namespace_expr .(observed (affect), (s,)),
121
- observed_syms (affect),
122
- rmn,
123
- modified_syms (affect),
124
- context (affect),
125
- affect. skip_checks)
108
+ end
109
+
110
+ function namespace_affect (affect:: ImperativeAffect , s)
111
+ obs = _namespace_nt (affect. observed, s)
112
+ mod = _namespace_nt (affect. modified, s)
113
+ ImperativeAffect (affect. f, obs, mod, affect. ctx, affect. skip_checks)
126
114
end
127
115
128
116
function invalid_variables (sys, expr)
@@ -139,21 +127,6 @@ function unassignable_variables(sys, expr)
139
127
x -> ! any (isequal (x), assignable_syms), written)
140
128
end
141
129
142
- @generated function _generated_writeback (integ, setters:: NamedTuple{NS1, <:Tuple} ,
143
- values:: NamedTuple{NS2, <:Tuple} ) where {NS1, NS2}
144
- setter_exprs = []
145
- for name in NS2
146
- if ! (name in NS1)
147
- missing_name = " Tried to write back to $name from affect; only declared states ($NS1 ) may be written to."
148
- error (missing_name)
149
- end
150
- push! (setter_exprs, :(setters.$ name (integ, values.$ name)))
151
- end
152
- return :(begin
153
- $ (setter_exprs... )
154
- end )
155
- end
156
-
157
130
function check_assignable (sys, sym)
158
131
if symbolic_type (sym) == ScalarSymbolic ()
159
132
is_variable (sys, sym) || is_parameter (sys, sym)
@@ -167,6 +140,41 @@ function check_assignable(sys, sym)
167
140
end
168
141
end
169
142
143
+ function _nt_check_valid (nt:: NamedTuple , s:: AbstractSystem , isobserved:: Bool )
144
+ _nt_check_valid (values (nt), s, isobserved)
145
+ end
146
+
147
+ function _nt_check_valid (nt:: Union{Tuple, AbstractArray} , s:: AbstractSystem , isobserved:: Bool )
148
+ for v in nt
149
+ if symbolic_type (v) == NotSymbolic ()
150
+ _nt_check_valid (v, s, isobserved)
151
+ continue
152
+ end
153
+ if ! isobserved && ! check_assignable (s, v)
154
+ @warn """
155
+ Expression $v cannot be assigned to; currently only unknowns and parameters may \
156
+ be updated by an affect.
157
+ """
158
+ end
159
+ invalid = invalid_variables (s, v)
160
+ isempty (invalid) && continue
161
+ name = isobserved ? " Observed" : " Modified"
162
+ error ("""
163
+ $name expression $(v) in affect refers to missing variable(s) $(invalid) ; \
164
+ the variables may not have been added (e.g. if a component is missing).
165
+ """ )
166
+ end
167
+ end
168
+
169
+ function _nt_check_overlap (nta:: NamedTuple , ntb:: NamedTuple )
170
+ common = intersect (keys (nta), keys (ntb))
171
+ isempty (common) && return
172
+ @warn """
173
+ The symbols $common are declared as both observed and modified; this is a code smell \
174
+ because it becomes easy to confuse them and assign/not assign a value.
175
+ """
176
+ end
177
+
170
178
function compile_functional_affect (
171
179
affect:: ImperativeAffect , sys; reset_jumps = false , kwargs... )
172
180
#=
@@ -176,93 +184,27 @@ function compile_functional_affect(
176
184
call the affect method
177
185
unpack and apply the resulting values
178
186
=#
179
- function check_dups (syms, exprs) # = (syms_dedup, exprs_dedup)
180
- seen = Set {Symbol} ()
181
- syms_dedup = []
182
- exprs_dedup = []
183
- for (sym, exp) in Iterators. zip (syms, exprs)
184
- if ! in (sym, seen)
185
- push! (syms_dedup, sym)
186
- push! (exprs_dedup, exp)
187
- push! (seen, sym)
188
- elseif ! affect. skip_checks
189
- @warn " Expression $(expr) is aliased as $sym , which has already been used. The first definition will be used."
190
- end
191
- end
192
- return (syms_dedup, exprs_dedup)
193
- end
194
187
195
- dvs = unknowns (sys)
196
- ps = parameters (sys)
197
-
198
- obs_exprs = observed (affect)
199
- if ! affect. skip_checks
200
- for oexpr in obs_exprs
201
- invalid_vars = invalid_variables (sys, oexpr)
202
- if length (invalid_vars) > 0
203
- error (" Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars) ; the variables may not have been added (e.g. if a component is missing)." )
204
- end
205
- end
206
- end
207
- obs_syms = observed_syms (affect)
208
- obs_syms, obs_exprs = check_dups (obs_syms, obs_exprs)
209
-
210
- mod_exprs = modified (affect)
211
188
if ! affect. skip_checks
212
- for mexpr in mod_exprs
213
- if ! check_assignable (sys, mexpr)
214
- @warn (" Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect." )
215
- end
216
- invalid_vars = unassignable_variables (sys, mexpr)
217
- if length (invalid_vars) > 0
218
- error (" Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars) ; the variables may not have been added (e.g. if a component is missing) or they may have been reduced away." )
219
- end
220
- end
221
- end
222
- mod_syms = modified_syms (affect)
223
- mod_syms, mod_exprs = check_dups (mod_syms, mod_exprs)
224
-
225
- overlapping_syms = intersect (mod_syms, obs_syms)
226
- if length (overlapping_syms) > 0 && ! affect. skip_checks
227
- @warn " The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
189
+ _nt_check_valid (affect. observed, sys, true )
190
+ _nt_check_valid (affect. modified, sys, false )
191
+ _nt_check_overlap (affect. observed, affect. modified)
228
192
end
229
193
230
194
# sanity checks done! now build the data and update function for observed values
231
- mkzero (sz) =
232
- if sz === ()
233
- 0.0
234
- else
235
- zeros (sz)
236
- end
237
- obs_fun = build_explicit_observed_function (
238
- sys, Symbolics. scalarize .(obs_exprs);
239
- mkarray = (es, _) -> MakeTuple (es))
240
- obs_sym_tuple = (obs_syms... ,)
241
-
242
- # okay so now to generate the stuff to assign it back into the system
243
- mod_pairs = mod_exprs .=> mod_syms
244
- mod_names = (mod_syms... ,)
245
- mod_og_val_fun = build_explicit_observed_function (
246
- sys, Symbolics. scalarize .(first .(mod_pairs));
247
- mkarray = (es, _) -> MakeTuple (es))
195
+ let user_affect = func (affect), ctx = context (affect),
196
+ obs_getter = getsym (sys, affect. observed),
197
+ mod_getter = getsym (sys, affect. modified),
198
+ mod_setter = setsym (sys, affect. modified),
199
+ reset_jumps = reset_jumps
248
200
249
- upd_funs = NamedTuple {mod_names} ((setu .((sys,), first .(mod_pairs))... ,))
250
-
251
- let user_affect = func (affect), ctx = context (affect), reset_jumps = reset_jumps
252
201
@inline function (integ)
253
- # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
254
- modvals = mod_og_val_fun (integ. u, integ. p, integ. t)
255
- upd_component_array = NamedTuple {mod_names} (modvals)
256
-
257
- # update the observed values
258
- obs_component_array = NamedTuple {obs_sym_tuple} (obs_fun (
259
- integ. u, integ. p, integ. t))
202
+ mod = mod_getter (integ)
203
+ obs = obs_getter (integ)
260
204
261
205
# let the user do their thing
262
- upd_vals = user_affect (upd_component_array, obs_component_array, ctx, integ)
263
-
264
- # write the new values back to the integrator
265
- _generated_writeback (integ, upd_funs, upd_vals)
206
+ upd_vals = user_affect (mod, obs, ctx, integ)
207
+ mod_setter (integ, upd_vals)
266
208
267
209
reset_jumps && reset_aggregated_jumps! (integ)
268
210
end
@@ -271,19 +213,22 @@ end
271
213
272
214
scalarize_affects (affects:: ImperativeAffect ) = affects
273
215
274
- function vars! (vars, aff:: ImperativeAffect ; op = Differential)
275
- for var in Iterators. flatten ((observed (aff), modified (aff)))
276
- if symbolic_type (var) == NotSymbolic ()
277
- if var isa AbstractArray
278
- for v in var
279
- v = unwrap (v)
280
- vars! (vars, v)
281
- end
282
- end
283
- else
284
- var = unwrap (var)
285
- vars! (vars, var)
216
+ function _vars_nt! (vars, nt:: NamedTuple , op)
217
+ _vars_nt! (vars, values (nt), op)
218
+ end
219
+
220
+ function _vars_nt! (vars, nt:: Union{AbstractArray, Tuple} , op)
221
+ for v in nt
222
+ if symbolic_type (v) == NotSymbolic ()
223
+ _vars_nt! (vars, v, op)
224
+ continue
286
225
end
226
+ vars! (vars, v; op)
287
227
end
228
+ end
229
+
230
+ function vars! (vars, aff:: ImperativeAffect ; op = Differential)
231
+ _vars_nt! (vars, aff. observed, op)
232
+ _vars_nt! (vars, aff. modified, op)
288
233
return vars
289
234
end
0 commit comments