Skip to content

Commit 16d3f5c

Browse files
committed
Refactor ImperativeAffect into its own file
1 parent 5fcf864 commit 16d3f5c

File tree

4 files changed

+224
-220
lines changed

4 files changed

+224
-220
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ include("systems/parameter_buffer.jl")
145145
include("systems/abstractsystem.jl")
146146
include("systems/model_parsing.jl")
147147
include("systems/connectors.jl")
148+
include("systems/imperative_affect.jl")
148149
include("systems/callbacks.jl")
149150
include("systems/problem_utils.jl")
150151

src/systems/callbacks.jl

Lines changed: 3 additions & 218 deletions
Original file line numberDiff line numberDiff line change
@@ -73,111 +73,6 @@ function namespace_affect(affect::FunctionalAffect, s)
7373
context(affect))
7474
end
7575

76-
"""
77-
ImperativeAffect(f::Function; modified::NamedTuple, observed::NamedTuple, ctx)
78-
79-
`ImperativeAffect` is a helper for writing affect functions that will compute observed values and
80-
ensure that modified values are correctly written back into the system. The affect function `f` needs to have
81-
the signature
82-
83-
```
84-
f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple
85-
```
86-
87-
The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions.
88-
Each declaration`NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x`
89-
so the value of `a + b` will be accessible as `observed.x` in `f`. `modified` currently restricts symbolic expressions to only bare variables, so only tuples of the form
90-
`(; x = y)` or `(; x)` (which aliases `x` as itself) are allowed.
91-
92-
The argument NamedTuples (for instance `(;x=y)`) will be populated with the declared values on function entry; if we require `(;x=y)` in `observed` and `y=2`, for example,
93-
then the NamedTuple `(;x=2)` will be passed as `observed` to the affect function `f`.
94-
95-
The NamedTuple returned from `f` includes the values to be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write
96-
97-
ImperativeAffect(observed=(; x_plus_y = x + y), modified=(; x)) do m, o
98-
@set! m.x = o.x_plus_y
99-
end
100-
101-
Where we use Setfield to copy the tuple `m` with a new value for `x`, then return the modified value of `m`. All values updated by the tuple must have names originally declared in
102-
`modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it
103-
in the returned tuple, in which case the associated field will not be updated.
104-
"""
105-
@kwdef struct ImperativeAffect
106-
f::Any
107-
obs::Vector
108-
obs_syms::Vector{Symbol}
109-
modified::Vector
110-
mod_syms::Vector{Symbol}
111-
ctx::Any
112-
skip_checks::Bool
113-
end
114-
115-
function ImperativeAffect(f::Function;
116-
observed::NamedTuple = NamedTuple{()}(()),
117-
modified::NamedTuple = NamedTuple{()}(()),
118-
ctx = nothing,
119-
skip_checks = false)
120-
ImperativeAffect(f,
121-
collect(values(observed)), collect(keys(observed)),
122-
collect(values(modified)), collect(keys(modified)),
123-
ctx, skip_checks)
124-
end
125-
function ImperativeAffect(f::Function, modified::NamedTuple;
126-
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false)
127-
ImperativeAffect(
128-
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
129-
end
130-
function ImperativeAffect(
131-
f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks = false)
132-
ImperativeAffect(
133-
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
134-
end
135-
function ImperativeAffect(
136-
f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks = false)
137-
ImperativeAffect(
138-
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
139-
end
140-
141-
function Base.show(io::IO, mfa::ImperativeAffect)
142-
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
143-
mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")
144-
affect = mfa.f
145-
print(io,
146-
"ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)")
147-
end
148-
func(f::ImperativeAffect) = f.f
149-
context(a::ImperativeAffect) = a.ctx
150-
observed(a::ImperativeAffect) = a.obs
151-
observed_syms(a::ImperativeAffect) = a.obs_syms
152-
discretes(a::ImperativeAffect) = filter(ModelingToolkit.isparameter, a.modified)
153-
modified(a::ImperativeAffect) = a.modified
154-
modified_syms(a::ImperativeAffect) = a.mod_syms
155-
156-
function Base.:(==)(a1::ImperativeAffect, a2::ImperativeAffect)
157-
isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) &&
158-
isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) &&
159-
isequal(a1.ctx, a2.ctx)
160-
end
161-
162-
function Base.hash(a::ImperativeAffect, s::UInt)
163-
s = hash(a.f, s)
164-
s = hash(a.obs, s)
165-
s = hash(a.obs_syms, s)
166-
s = hash(a.modified, s)
167-
s = hash(a.mod_syms, s)
168-
hash(a.ctx, s)
169-
end
170-
171-
function namespace_affect(affect::ImperativeAffect, s)
172-
ImperativeAffect(func(affect),
173-
namespace_expr.(observed(affect), (s,)),
174-
observed_syms(affect),
175-
renamespace.((s,), modified(affect)),
176-
modified_syms(affect),
177-
context(affect),
178-
affect.skip_checks)
179-
end
180-
18176
function has_functional_affect(cb)
18277
(affects(cb) isa FunctionalAffect || affects(cb) isa ImperativeAffect)
18378
end
@@ -203,13 +98,13 @@ sharp discontinuity between integrator steps (which in this example would not no
20398
guaranteed to be triggered.
20499
205100
Once detected the integrator will "wind back" through a root-finding process to identify the point when the condition became active; the method used
206-
is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). If we denote the time when the condition becomes active at tc,
101+
is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). If we denote the time when the condition becomes active as `tc``,
207102
the value in the integrator after windback will be:
208103
* `u[tc-epsilon], p[tc-epsilon], tc` if `LeftRootFind` is used,
209104
* `u[tc+epsilon], p[tc+epsilon], tc` if `RightRootFind` is used,
210105
* or `u[t], p[t], t` if `NoRootFind` is used.
211106
For example, if we want to detect when an unknown variable `x` satisfies `x > 0` using the condition `x ~ 0` on a positive edge (that is, `D(x) > 0`),
212-
then left root finding will get us `x=-epsilon`, right root finding `x=epsilon` and no root finding whatever the next step of the integrator was after
107+
then left root finding will get us `x=-epsilon`, right root finding `x=epsilon` and no root finding will produce whatever the next step of the integrator was after
213108
it passed through 0.
214109
215110
Multiple callbacks in the same system with different `rootfind` operations will be grouped
@@ -405,7 +300,6 @@ end
405300

406301
namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
407302
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
408-
namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s)
409303
namespace_affects(::Nothing, s) = nothing
410304

411305
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
@@ -480,7 +374,6 @@ scalarize_affects(affects) = scalarize(affects)
480374
scalarize_affects(affects::Tuple) = FunctionalAffect(affects...)
481375
scalarize_affects(affects::NamedTuple) = FunctionalAffect(; affects...)
482376
scalarize_affects(affects::FunctionalAffect) = affects
483-
scalarize_affects(affects::ImperativeAffect) = affects
484377

485378
SymbolicDiscreteCallback(p::Pair) = SymbolicDiscreteCallback(p[1], p[2])
486379
SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback) = cb # passthrough
@@ -1099,117 +992,9 @@ function check_assignable(sys, sym)
1099992
end
1100993
end
1101994

1102-
function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs...)
1103-
#=
1104-
Implementation sketch:
1105-
generate observed function (oop), should save to a component array under obs_syms
1106-
do the same stuff as the normal FA for pars_syms
1107-
call the affect method
1108-
unpack and apply the resulting values
1109-
=#
1110-
function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup)
1111-
seen = Set{Symbol}()
1112-
syms_dedup = []
1113-
exprs_dedup = []
1114-
for (sym, exp) in Iterators.zip(syms, exprs)
1115-
if !in(sym, seen)
1116-
push!(syms_dedup, sym)
1117-
push!(exprs_dedup, exp)
1118-
push!(seen, sym)
1119-
elseif !affect.skip_checks
1120-
@warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used."
1121-
end
1122-
end
1123-
return (syms_dedup, exprs_dedup)
1124-
end
1125-
1126-
obs_exprs = observed(affect)
1127-
if !affect.skip_checks
1128-
for oexpr in obs_exprs
1129-
invalid_vars = invalid_variables(sys, oexpr)
1130-
if length(invalid_vars) > 0
1131-
error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
1132-
end
1133-
end
1134-
end
1135-
obs_syms = observed_syms(affect)
1136-
obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs)
1137-
1138-
mod_exprs = modified(affect)
1139-
if !affect.skip_checks
1140-
for mexpr in mod_exprs
1141-
if !check_assignable(sys, mexpr)
1142-
@warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.")
1143-
end
1144-
invalid_vars = unassignable_variables(sys, mexpr)
1145-
if length(invalid_vars) > 0
1146-
error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
1147-
end
1148-
end
1149-
end
1150-
mod_syms = modified_syms(affect)
1151-
mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs)
1152-
1153-
overlapping_syms = intersect(mod_syms, obs_syms)
1154-
if length(overlapping_syms) > 0 && !affect.skip_checks
1155-
@warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
1156-
end
1157-
1158-
# sanity checks done! now build the data and update function for observed values
1159-
mkzero(sz) =
1160-
if sz === ()
1161-
0.0
1162-
else
1163-
zeros(sz)
1164-
end
1165-
obs_fun = build_explicit_observed_function(
1166-
sys, Symbolics.scalarize.(obs_exprs);
1167-
array_type = Tuple)
1168-
obs_sym_tuple = (obs_syms...,)
1169-
1170-
# okay so now to generate the stuff to assign it back into the system
1171-
mod_pairs = mod_exprs .=> mod_syms
1172-
mod_names = (mod_syms...,)
1173-
mod_og_val_fun = build_explicit_observed_function(
1174-
sys, Symbolics.scalarize.(first.(mod_pairs));
1175-
array_type = Tuple)
1176-
1177-
upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))
1178-
1179-
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
1180-
save_idxs = get(ic.callback_to_clocks, cb, Int[])
1181-
else
1182-
save_idxs = Int[]
1183-
end
1184-
1185-
let user_affect = func(affect), ctx = context(affect)
1186-
function (integ)
1187-
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
1188-
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
1189-
upd_component_array = NamedTuple{mod_names}(modvals)
1190-
1191-
# update the observed values
1192-
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(
1193-
integ.u, integ.p, integ.t))
1194-
1195-
# let the user do their thing
1196-
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)
1197-
1198-
# write the new values back to the integrator
1199-
_generated_writeback(integ, upd_funs, upd_vals)
1200-
1201-
for idx in save_idxs
1202-
SciMLBase.save_discretes!(integ, idx)
1203-
end
1204-
end
1205-
end
1206-
end
1207-
1208-
function compile_affect(
1209-
affect::Union{FunctionalAffect, ImperativeAffect}, cb, sys, dvs, ps; kwargs...)
995+
function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
1210996
compile_user_affect(affect, cb, sys, dvs, ps; kwargs...)
1211997
end
1212-
1213998
function _compile_optional_affect(default, aff, cb, sys, dvs, ps; kwargs...)
1214999
if isnothing(aff) || aff == default
12151000
return nothing

src/systems/diffeqs/odesystem.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -629,8 +629,6 @@ function build_explicit_observed_function(sys, ts;
629629
oop_mtkp_wrapper = mtkparams_wrapper
630630
end
631631

632-
output_expr = isscalar ? ts[1] :
633-
(array_type <: Vector ? MakeArray(ts, output_type) : MakeTuple(ts))
634632
# Need to keep old method of building the function since it uses `output_type`,
635633
# which can't be provided to `build_function`
636634
return_value = if isscalar

0 commit comments

Comments
 (0)