Skip to content

Commit 1f86fa0

Browse files
refactor: use SII for ImperativeAffect
1 parent 4d65458 commit 1f86fa0

File tree

1 file changed

+87
-142
lines changed

1 file changed

+87
-142
lines changed

src/systems/imperative_affect.jl

Lines changed: 87 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,8 @@ in the returned tuple, in which case the associated field will not be updated.
3030
"""
3131
struct ImperativeAffect
3232
f::Any
33-
obs::Vector
34-
obs_syms::Vector{Symbol}
35-
modified::Vector
36-
mod_syms::Vector{Symbol}
33+
observed::NamedTuple
34+
modified::NamedTuple
3735
ctx::Any
3836
skip_checks::Bool
3937
end
@@ -43,10 +41,7 @@ function ImperativeAffect(f;
4341
modified::NamedTuple = NamedTuple{()}(()),
4442
ctx = nothing,
4543
skip_checks = false)
46-
ImperativeAffect(f,
47-
collect(values(observed)), collect(keys(observed)),
48-
collect(values(modified)), collect(keys(modified)),
49-
ctx, skip_checks)
44+
ImperativeAffect(f, observed, modified, ctx, skip_checks)
5045
end
5146
function ImperativeAffect(f, modified::NamedTuple;
5247
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false)
@@ -68,61 +63,54 @@ function ImperativeAffect(; f, kwargs...)
6863
end
6964

7065
function Base.show(io::IO, mfa::ImperativeAffect)
71-
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
72-
mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")
66+
obs = mfa.observed
67+
mod = mfa.modified
7368
affect = mfa.f
7469
print(io,
75-
"ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)")
70+
"ImperativeAffect(observed: [$(obs)], modified: [$(mod)], affect:$affect)")
7671
end
7772
func(f::ImperativeAffect) = f.f
7873
context(a::ImperativeAffect) = a.ctx
79-
observed(a::ImperativeAffect) = a.obs
80-
observed_syms(a::ImperativeAffect) = a.obs_syms
8174
function discretes(a::ImperativeAffect)
8275
Iterators.filter(ModelingToolkit.isparameter,
8376
Iterators.flatten(Iterators.map(
8477
x -> symbolic_type(x) == NotSymbolic() && x isa AbstractArray ? x : [x],
8578
a.modified)))
8679
end
87-
modified(a::ImperativeAffect) = a.modified
88-
modified_syms(a::ImperativeAffect) = a.mod_syms
8980

9081
function Base.:(==)(a1::ImperativeAffect, a2::ImperativeAffect)
91-
isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) &&
92-
isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) &&
82+
isequal(a1.f, a2.f) && isequal(a1.observed, a2.observed) &&
83+
isequal(a1.modified, a2.modified) &&
9384
isequal(a1.ctx, a2.ctx)
9485
end
9586

9687
function Base.hash(a::ImperativeAffect, s::UInt)
9788
s = hash(a.f, s)
98-
s = hash(a.obs, s)
99-
s = hash(a.obs_syms, s)
89+
s = hash(a.observed, s)
10090
s = hash(a.modified, s)
101-
s = hash(a.mod_syms, s)
10291
hash(a.ctx, s)
10392
end
10493

10594
namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s)
106-
function namespace_affect(affect::ImperativeAffect, s)
107-
rmn = []
108-
for modded in modified(affect)
109-
if symbolic_type(modded) == NotSymbolic() && modded isa AbstractArray
110-
res = []
111-
for m in modded
112-
push!(res, renamespace(s, m))
113-
end
114-
push!(rmn, res)
95+
96+
function _namespace_nt(nt::NamedTuple, s::AbstractSystem)
97+
return NamedTuple{keys(nt)}(_namespace_nt(values(nt), s))
98+
end
99+
100+
function _namespace_nt(nt::Union{AbstractArray, Tuple}, s::AbstractSystem)
101+
return map(length(nt)) do v
102+
if symbolic_type(v) == NotSymbolic()
103+
_namespace_nt(v, s)
115104
else
116-
push!(rmn, renamespace(s, modded))
105+
renamespace(s, v)
117106
end
118107
end
119-
ImperativeAffect(func(affect),
120-
namespace_expr.(observed(affect), (s,)),
121-
observed_syms(affect),
122-
rmn,
123-
modified_syms(affect),
124-
context(affect),
125-
affect.skip_checks)
108+
end
109+
110+
function namespace_affect(affect::ImperativeAffect, s)
111+
obs = _namespace_nt(affect.observed, s)
112+
mod = _namespace_nt(affect.modified, s)
113+
ImperativeAffect(affect.f, obs, mod, affect.ctx, affect.skip_checks)
126114
end
127115

128116
function invalid_variables(sys, expr)
@@ -139,21 +127,6 @@ function unassignable_variables(sys, expr)
139127
x -> !any(isequal(x), assignable_syms), written)
140128
end
141129

142-
@generated function _generated_writeback(integ, setters::NamedTuple{NS1, <:Tuple},
143-
values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2}
144-
setter_exprs = []
145-
for name in NS2
146-
if !(name in NS1)
147-
missing_name = "Tried to write back to $name from affect; only declared states ($NS1) may be written to."
148-
error(missing_name)
149-
end
150-
push!(setter_exprs, :(setters.$name(integ, values.$name)))
151-
end
152-
return :(begin
153-
$(setter_exprs...)
154-
end)
155-
end
156-
157130
function check_assignable(sys, sym)
158131
if symbolic_type(sym) == ScalarSymbolic()
159132
is_variable(sys, sym) || is_parameter(sys, sym)
@@ -167,6 +140,41 @@ function check_assignable(sys, sym)
167140
end
168141
end
169142

143+
function _nt_check_valid(nt::NamedTuple, s::AbstractSystem, isobserved::Bool)
144+
_nt_check_valid(values(nt), s, isobserved)
145+
end
146+
147+
function _nt_check_valid(nt::Union{Tuple, AbstractArray}, s::AbstractSystem, isobserved::Bool)
148+
for v in nt
149+
if symbolic_type(v) == NotSymbolic()
150+
_nt_check_valid(v, s, isobserved)
151+
continue
152+
end
153+
if !isobserved && !check_assignable(s, v)
154+
@warn """
155+
Expression $v cannot be assigned to; currently only unknowns and parameters may \
156+
be updated by an affect.
157+
"""
158+
end
159+
invalid = invalid_variables(s, v)
160+
isempty(invalid) && continue
161+
name = isobserved ? "Observed" : "Modified"
162+
error("""
163+
$name expression $(v) in affect refers to missing variable(s) $(invalid); \
164+
the variables may not have been added (e.g. if a component is missing).
165+
""")
166+
end
167+
end
168+
169+
function _nt_check_overlap(nta::NamedTuple, ntb::NamedTuple)
170+
common = intersect(keys(nta), keys(ntb))
171+
isempty(common) && return
172+
@warn """
173+
The symbols $common are declared as both observed and modified; this is a code smell \
174+
because it becomes easy to confuse them and assign/not assign a value.
175+
"""
176+
end
177+
170178
function compile_functional_affect(
171179
affect::ImperativeAffect, sys; reset_jumps = false, kwargs...)
172180
#=
@@ -176,93 +184,27 @@ function compile_functional_affect(
176184
call the affect method
177185
unpack and apply the resulting values
178186
=#
179-
function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup)
180-
seen = Set{Symbol}()
181-
syms_dedup = []
182-
exprs_dedup = []
183-
for (sym, exp) in Iterators.zip(syms, exprs)
184-
if !in(sym, seen)
185-
push!(syms_dedup, sym)
186-
push!(exprs_dedup, exp)
187-
push!(seen, sym)
188-
elseif !affect.skip_checks
189-
@warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used."
190-
end
191-
end
192-
return (syms_dedup, exprs_dedup)
193-
end
194187

195-
dvs = unknowns(sys)
196-
ps = parameters(sys)
197-
198-
obs_exprs = observed(affect)
199-
if !affect.skip_checks
200-
for oexpr in obs_exprs
201-
invalid_vars = invalid_variables(sys, oexpr)
202-
if length(invalid_vars) > 0
203-
error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
204-
end
205-
end
206-
end
207-
obs_syms = observed_syms(affect)
208-
obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs)
209-
210-
mod_exprs = modified(affect)
211188
if !affect.skip_checks
212-
for mexpr in mod_exprs
213-
if !check_assignable(sys, mexpr)
214-
@warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.")
215-
end
216-
invalid_vars = unassignable_variables(sys, mexpr)
217-
if length(invalid_vars) > 0
218-
error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
219-
end
220-
end
221-
end
222-
mod_syms = modified_syms(affect)
223-
mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs)
224-
225-
overlapping_syms = intersect(mod_syms, obs_syms)
226-
if length(overlapping_syms) > 0 && !affect.skip_checks
227-
@warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
189+
_nt_check_valid(affect.observed, sys, true)
190+
_nt_check_valid(affect.modified, sys, false)
191+
_nt_check_overlap(affect.observed, affect.modified)
228192
end
229193

230194
# sanity checks done! now build the data and update function for observed values
231-
mkzero(sz) =
232-
if sz === ()
233-
0.0
234-
else
235-
zeros(sz)
236-
end
237-
obs_fun = build_explicit_observed_function(
238-
sys, Symbolics.scalarize.(obs_exprs);
239-
mkarray = (es, _) -> MakeTuple(es))
240-
obs_sym_tuple = (obs_syms...,)
241-
242-
# okay so now to generate the stuff to assign it back into the system
243-
mod_pairs = mod_exprs .=> mod_syms
244-
mod_names = (mod_syms...,)
245-
mod_og_val_fun = build_explicit_observed_function(
246-
sys, Symbolics.scalarize.(first.(mod_pairs));
247-
mkarray = (es, _) -> MakeTuple(es))
195+
let user_affect = func(affect), ctx = context(affect),
196+
obs_getter = getsym(sys, affect.observed),
197+
mod_getter = getsym(sys, affect.modified),
198+
mod_setter = setsym(sys, affect.modified),
199+
reset_jumps = reset_jumps
248200

249-
upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))
250-
251-
let user_affect = func(affect), ctx = context(affect), reset_jumps = reset_jumps
252201
@inline function (integ)
253-
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
254-
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
255-
upd_component_array = NamedTuple{mod_names}(modvals)
256-
257-
# update the observed values
258-
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(
259-
integ.u, integ.p, integ.t))
202+
mod = mod_getter(integ)
203+
obs = obs_getter(integ)
260204

261205
# let the user do their thing
262-
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)
263-
264-
# write the new values back to the integrator
265-
_generated_writeback(integ, upd_funs, upd_vals)
206+
upd_vals = user_affect(mod, obs, ctx, integ)
207+
mod_setter(integ, upd_vals)
266208

267209
reset_jumps && reset_aggregated_jumps!(integ)
268210
end
@@ -271,19 +213,22 @@ end
271213

272214
scalarize_affects(affects::ImperativeAffect) = affects
273215

274-
function vars!(vars, aff::ImperativeAffect; op = Differential)
275-
for var in Iterators.flatten((observed(aff), modified(aff)))
276-
if symbolic_type(var) == NotSymbolic()
277-
if var isa AbstractArray
278-
for v in var
279-
v = unwrap(v)
280-
vars!(vars, v)
281-
end
282-
end
283-
else
284-
var = unwrap(var)
285-
vars!(vars, var)
216+
function _vars_nt!(vars, nt::NamedTuple, op)
217+
_vars_nt!(vars, values(nt), op)
218+
end
219+
220+
function _vars_nt!(vars, nt::Union{AbstractArray, Tuple}, op)
221+
for v in nt
222+
if symbolic_type(v) == NotSymbolic()
223+
_vars_nt!(vars, v, op)
224+
continue
286225
end
226+
vars!(vars, v; op)
287227
end
228+
end
229+
230+
function vars!(vars, aff::ImperativeAffect; op = Differential)
231+
_vars_nt!(vars, aff.observed, op)
232+
_vars_nt!(vars, aff.modified, op)
288233
return vars
289234
end

0 commit comments

Comments
 (0)