Skip to content

Commit 9102a79

Browse files
refactor: use SII for ImperativeAffect
1 parent 4d65458 commit 9102a79

File tree

1 file changed

+88
-142
lines changed

1 file changed

+88
-142
lines changed

src/systems/imperative_affect.jl

Lines changed: 88 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,8 @@ in the returned tuple, in which case the associated field will not be updated.
3030
"""
3131
struct ImperativeAffect
3232
f::Any
33-
obs::Vector
34-
obs_syms::Vector{Symbol}
35-
modified::Vector
36-
mod_syms::Vector{Symbol}
33+
observed::NamedTuple
34+
modified::NamedTuple
3735
ctx::Any
3836
skip_checks::Bool
3937
end
@@ -43,10 +41,7 @@ function ImperativeAffect(f;
4341
modified::NamedTuple = NamedTuple{()}(()),
4442
ctx = nothing,
4543
skip_checks = false)
46-
ImperativeAffect(f,
47-
collect(values(observed)), collect(keys(observed)),
48-
collect(values(modified)), collect(keys(modified)),
49-
ctx, skip_checks)
44+
ImperativeAffect(f, observed, modified, ctx, skip_checks)
5045
end
5146
function ImperativeAffect(f, modified::NamedTuple;
5247
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false)
@@ -68,61 +63,54 @@ function ImperativeAffect(; f, kwargs...)
6863
end
6964

7065
function Base.show(io::IO, mfa::ImperativeAffect)
71-
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
72-
mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")
66+
obs = mfa.observed
67+
mod = mfa.modified
7368
affect = mfa.f
7469
print(io,
75-
"ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)")
70+
"ImperativeAffect(observed: [$(obs)], modified: [$(mod)], affect:$affect)")
7671
end
7772
func(f::ImperativeAffect) = f.f
7873
context(a::ImperativeAffect) = a.ctx
79-
observed(a::ImperativeAffect) = a.obs
80-
observed_syms(a::ImperativeAffect) = a.obs_syms
8174
function discretes(a::ImperativeAffect)
8275
Iterators.filter(ModelingToolkit.isparameter,
8376
Iterators.flatten(Iterators.map(
8477
x -> symbolic_type(x) == NotSymbolic() && x isa AbstractArray ? x : [x],
8578
a.modified)))
8679
end
87-
modified(a::ImperativeAffect) = a.modified
88-
modified_syms(a::ImperativeAffect) = a.mod_syms
8980

9081
function Base.:(==)(a1::ImperativeAffect, a2::ImperativeAffect)
91-
isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) &&
92-
isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) &&
82+
isequal(a1.f, a2.f) && isequal(a1.observed, a2.observed) &&
83+
isequal(a1.modified, a2.modified) &&
9384
isequal(a1.ctx, a2.ctx)
9485
end
9586

9687
function Base.hash(a::ImperativeAffect, s::UInt)
9788
s = hash(a.f, s)
98-
s = hash(a.obs, s)
99-
s = hash(a.obs_syms, s)
89+
s = hash(a.observed, s)
10090
s = hash(a.modified, s)
101-
s = hash(a.mod_syms, s)
10291
hash(a.ctx, s)
10392
end
10493

10594
namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s)
106-
function namespace_affect(affect::ImperativeAffect, s)
107-
rmn = []
108-
for modded in modified(affect)
109-
if symbolic_type(modded) == NotSymbolic() && modded isa AbstractArray
110-
res = []
111-
for m in modded
112-
push!(res, renamespace(s, m))
113-
end
114-
push!(rmn, res)
95+
96+
function _namespace_nt(nt::NamedTuple, s::AbstractSystem)
97+
return NamedTuple{keys(nt)}(_namespace_nt(values(nt), s))
98+
end
99+
100+
function _namespace_nt(nt::Union{AbstractArray, Tuple}, s::AbstractSystem)
101+
return map(nt) do v
102+
if symbolic_type(v) == NotSymbolic()
103+
_namespace_nt(v, s)
115104
else
116-
push!(rmn, renamespace(s, modded))
105+
renamespace(s, v)
117106
end
118107
end
119-
ImperativeAffect(func(affect),
120-
namespace_expr.(observed(affect), (s,)),
121-
observed_syms(affect),
122-
rmn,
123-
modified_syms(affect),
124-
context(affect),
125-
affect.skip_checks)
108+
end
109+
110+
function namespace_affect(affect::ImperativeAffect, s)
111+
obs = _namespace_nt(affect.observed, s)
112+
mod = _namespace_nt(affect.modified, s)
113+
ImperativeAffect(affect.f, obs, mod, affect.ctx, affect.skip_checks)
126114
end
127115

128116
function invalid_variables(sys, expr)
@@ -139,21 +127,6 @@ function unassignable_variables(sys, expr)
139127
x -> !any(isequal(x), assignable_syms), written)
140128
end
141129

142-
@generated function _generated_writeback(integ, setters::NamedTuple{NS1, <:Tuple},
143-
values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2}
144-
setter_exprs = []
145-
for name in NS2
146-
if !(name in NS1)
147-
missing_name = "Tried to write back to $name from affect; only declared states ($NS1) may be written to."
148-
error(missing_name)
149-
end
150-
push!(setter_exprs, :(setters.$name(integ, values.$name)))
151-
end
152-
return :(begin
153-
$(setter_exprs...)
154-
end)
155-
end
156-
157130
function check_assignable(sys, sym)
158131
if symbolic_type(sym) == ScalarSymbolic()
159132
is_variable(sys, sym) || is_parameter(sys, sym)
@@ -167,6 +140,42 @@ function check_assignable(sys, sym)
167140
end
168141
end
169142

143+
function _nt_check_valid(nt::NamedTuple, s::AbstractSystem, isobserved::Bool)
144+
_nt_check_valid(values(nt), s, isobserved)
145+
end
146+
147+
function _nt_check_valid(
148+
nt::Union{Tuple, AbstractArray}, s::AbstractSystem, isobserved::Bool)
149+
for v in nt
150+
if symbolic_type(v) == NotSymbolic()
151+
_nt_check_valid(v, s, isobserved)
152+
continue
153+
end
154+
if !isobserved && !check_assignable(s, v)
155+
error("""
156+
Expression $v cannot be assigned to; currently only unknowns and parameters may \
157+
be updated by an affect.
158+
""")
159+
end
160+
invalid = invalid_variables(s, v)
161+
isempty(invalid) && continue
162+
name = isobserved ? "Observed" : "Modified"
163+
error("""
164+
$name expression $(v) in affect refers to missing variable(s) $(invalid); \
165+
the variables may not have been added (e.g. if a component is missing).
166+
""")
167+
end
168+
end
169+
170+
function _nt_check_overlap(nta::NamedTuple, ntb::NamedTuple)
171+
common = intersect(keys(nta), keys(ntb))
172+
isempty(common) && return
173+
@warn """
174+
The symbols $common are declared as both observed and modified; this is a code smell \
175+
because it becomes easy to confuse them and assign/not assign a value.
176+
"""
177+
end
178+
170179
function compile_functional_affect(
171180
affect::ImperativeAffect, sys; reset_jumps = false, kwargs...)
172181
#=
@@ -176,93 +185,27 @@ function compile_functional_affect(
176185
call the affect method
177186
unpack and apply the resulting values
178187
=#
179-
function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup)
180-
seen = Set{Symbol}()
181-
syms_dedup = []
182-
exprs_dedup = []
183-
for (sym, exp) in Iterators.zip(syms, exprs)
184-
if !in(sym, seen)
185-
push!(syms_dedup, sym)
186-
push!(exprs_dedup, exp)
187-
push!(seen, sym)
188-
elseif !affect.skip_checks
189-
@warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used."
190-
end
191-
end
192-
return (syms_dedup, exprs_dedup)
193-
end
194188

195-
dvs = unknowns(sys)
196-
ps = parameters(sys)
197-
198-
obs_exprs = observed(affect)
199-
if !affect.skip_checks
200-
for oexpr in obs_exprs
201-
invalid_vars = invalid_variables(sys, oexpr)
202-
if length(invalid_vars) > 0
203-
error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
204-
end
205-
end
206-
end
207-
obs_syms = observed_syms(affect)
208-
obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs)
209-
210-
mod_exprs = modified(affect)
211189
if !affect.skip_checks
212-
for mexpr in mod_exprs
213-
if !check_assignable(sys, mexpr)
214-
@warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.")
215-
end
216-
invalid_vars = unassignable_variables(sys, mexpr)
217-
if length(invalid_vars) > 0
218-
error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
219-
end
220-
end
221-
end
222-
mod_syms = modified_syms(affect)
223-
mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs)
224-
225-
overlapping_syms = intersect(mod_syms, obs_syms)
226-
if length(overlapping_syms) > 0 && !affect.skip_checks
227-
@warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
190+
_nt_check_valid(affect.observed, sys, true)
191+
_nt_check_valid(affect.modified, sys, false)
192+
_nt_check_overlap(affect.observed, affect.modified)
228193
end
229194

230195
# sanity checks done! now build the data and update function for observed values
231-
mkzero(sz) =
232-
if sz === ()
233-
0.0
234-
else
235-
zeros(sz)
236-
end
237-
obs_fun = build_explicit_observed_function(
238-
sys, Symbolics.scalarize.(obs_exprs);
239-
mkarray = (es, _) -> MakeTuple(es))
240-
obs_sym_tuple = (obs_syms...,)
241-
242-
# okay so now to generate the stuff to assign it back into the system
243-
mod_pairs = mod_exprs .=> mod_syms
244-
mod_names = (mod_syms...,)
245-
mod_og_val_fun = build_explicit_observed_function(
246-
sys, Symbolics.scalarize.(first.(mod_pairs));
247-
mkarray = (es, _) -> MakeTuple(es))
196+
let user_affect = func(affect), ctx = context(affect),
197+
obs_getter = isempty(affect.observed) ? Returns((;)) : getsym(sys, affect.observed),
198+
mod_getter = isempty(affect.modified) ? Returns((;)) : getsym(sys, affect.modified),
199+
mod_setter = isempty(affect.modified) ? Returns((;)) : setsym(sys, affect.modified),
200+
reset_jumps = reset_jumps
248201

249-
upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))
250-
251-
let user_affect = func(affect), ctx = context(affect), reset_jumps = reset_jumps
252202
@inline function (integ)
253-
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
254-
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
255-
upd_component_array = NamedTuple{mod_names}(modvals)
256-
257-
# update the observed values
258-
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(
259-
integ.u, integ.p, integ.t))
203+
mod = mod_getter(integ)
204+
obs = obs_getter(integ)
260205

261206
# let the user do their thing
262-
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)
263-
264-
# write the new values back to the integrator
265-
_generated_writeback(integ, upd_funs, upd_vals)
207+
upd_vals = user_affect(mod, obs, ctx, integ)
208+
mod_setter(integ, upd_vals)
266209

267210
reset_jumps && reset_aggregated_jumps!(integ)
268211
end
@@ -271,19 +214,22 @@ end
271214

272215
scalarize_affects(affects::ImperativeAffect) = affects
273216

274-
function vars!(vars, aff::ImperativeAffect; op = Differential)
275-
for var in Iterators.flatten((observed(aff), modified(aff)))
276-
if symbolic_type(var) == NotSymbolic()
277-
if var isa AbstractArray
278-
for v in var
279-
v = unwrap(v)
280-
vars!(vars, v)
281-
end
282-
end
283-
else
284-
var = unwrap(var)
285-
vars!(vars, var)
217+
function _vars_nt!(vars, nt::NamedTuple, op)
218+
_vars_nt!(vars, values(nt), op)
219+
end
220+
221+
function _vars_nt!(vars, nt::Union{AbstractArray, Tuple}, op)
222+
for v in nt
223+
if symbolic_type(v) == NotSymbolic()
224+
_vars_nt!(vars, v, op)
225+
continue
286226
end
227+
vars!(vars, v; op)
287228
end
229+
end
230+
231+
function vars!(vars, aff::ImperativeAffect; op = Differential)
232+
_vars_nt!(vars, aff.observed, op)
233+
_vars_nt!(vars, aff.modified, op)
288234
return vars
289235
end

0 commit comments

Comments
 (0)