Skip to content

Commit d79d49d

Browse files
committed
Switch MutatingFunctionalAffect from using ComponentArrays to using NamedTuples for heterotyped operation support.
1 parent f57215a commit d79d49d

File tree

3 files changed

+94
-75
lines changed

3 files changed

+94
-75
lines changed

src/systems/callbacks.jl

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -77,25 +77,29 @@ end
7777
`MutatingFunctionalAffect` is a helper for writing affect functions that will compute observed values and
7878
ensure that modified values are correctly written back into the system. The affect function `f` needs to have
7979
one of four signatures:
80-
* `f(modified::ComponentArray)` if the function only writes values (unknowns or parameters) to the system,
81-
* `f(modified::ComponentArray, observed::ComponentArray)` if the function also reads observed values from the system,
82-
* `f(modified::ComponentArray, observed::ComponentArray, ctx)` if the function needs the user-defined context,
83-
* `f(modified::ComponentArray, observed::ComponentArray, ctx, integrator)` if the function needs the low-level integrator.
80+
* `f(modified::NamedTuple)::NamedTuple` if the function only writes values (unknowns or parameters) to the system,
81+
* `f(modified::NamedTuple, observed::NamedTuple)::NamedTuple` if the function also reads observed values from the system,
82+
* `f(modified::NamedTuple, observed::NamedTuple, ctx)::NamedTuple` if the function needs the user-defined context,
83+
* `f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple` if the function needs the low-level integrator.
8484
These will be checked in reverse order (that is, the four-argument version first, than the 3, etc).
8585
86-
The function `f` will be called with `observed` and `modified` `ComponentArray`s that are derived from their respective `NamedTuple` definitions.
87-
Each `NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x`
86+
The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions.
87+
Each declaration`NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x`
8888
so the value of `a + b` will be accessible as `observed.x` in `f`. `modified` currently restricts symbolic expressions to only bare variables, so only tuples of the form
8989
`(; x = y)` or `(; x)` (which aliases `x` as itself) are allowed.
9090
91-
Both `observed` and `modified` will be automatically populated with the current values of their corresponding expressions on function entry.
92-
The values in `modified` will be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write
91+
The argument NamedTuples (for instance `(;x=y)`) will be populated with the declared values on function entry; if we require `(;x=y)` in `observed` and `y=2`, for example,
92+
then the NamedTuple `(;x=2)` will be passed as `observed` to the affect function `f`.
93+
94+
The NamedTuple returned from `f` includes the values to be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write
9395
9496
MutatingFunctionalAffect(observed=(; x_plus_y = x + y), modified=(; x)) do m, o
95-
m.x = o.x_plus_y
97+
@set! m.x = o.x_plus_y
9698
end
9799
98-
The affect function updates the value at `x` in `modified` to be the result of evaluating `x + y` as passed in the observed values.
100+
Where we use Setfield to copy the tuple `m` with a new value for `x`, then return the modified value of `m`. All values updated by the tuple must have names originally declared in
101+
`modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it
102+
in the returned tuple, in which case the associated field will not be updated.
99103
"""
100104
@kwdef struct MutatingFunctionalAffect
101105
f::Any
@@ -983,6 +987,18 @@ function unassignable_variables(sys, expr)
983987
x -> !any(isequal(x), assignable_syms), written)
984988
end
985989

990+
@generated function _generated_writeback(integ, setters::NamedTuple{NS1,<:Tuple}, values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2}
991+
setter_exprs = []
992+
for name in NS2
993+
if !(name in NS1)
994+
missing_name = "Tried to write back to $name from affect; only declared states ($NS1) may be written to."
995+
error(missing_name)
996+
end
997+
push!(setter_exprs, :(setters.$name(integ, values.$name)))
998+
end
999+
return :(begin $(setter_exprs...) end)
1000+
end
1001+
9861002
function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; kwargs...)
9871003
#=
9881004
Implementation sketch:
@@ -1016,7 +1032,6 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps;
10161032
end
10171033
obs_syms = observed_syms(affect)
10181034
obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs)
1019-
obs_size = size.(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size
10201035

10211036
mod_exprs = modified(affect)
10221037
for mexpr in mod_exprs
@@ -1033,8 +1048,6 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps;
10331048
end
10341049
mod_syms = modified_syms(affect)
10351050
mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs)
1036-
_, mod_og_val_fun = build_explicit_observed_function(
1037-
sys, mod_exprs; return_inplace = true)
10381051

10391052
overlapping_syms = intersect(mod_syms, obs_syms)
10401053
if length(overlapping_syms) > 0 && !affect.skip_checks
@@ -1048,31 +1061,20 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps;
10481061
else
10491062
zeros(sz)
10501063
end
1051-
_, obs_fun = build_explicit_observed_function(
1064+
obs_fun = build_explicit_observed_function(
10521065
sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []);
1053-
return_inplace = true)
1054-
obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms...,)}(mkzero.(obs_size)))
1066+
array_type = :tuple)
1067+
obs_sym_tuple = (obs_syms...,)
10551068

10561069
# okay so now to generate the stuff to assign it back into the system
1057-
# note that we reorder the componentarray to make the views coherent wrt the base array
10581070
mod_pairs = mod_exprs .=> mod_syms
1059-
mod_param_pairs = filter(v -> is_parameter(sys, v[1]), mod_pairs)
1060-
mod_unk_pairs = filter(v -> !is_parameter(sys, v[1]), mod_pairs)
1061-
_, mod_og_val_fun = build_explicit_observed_function(
1062-
sys, reduce(vcat, Symbolics.scalarize.([first.(mod_param_pairs); first.(mod_unk_pairs)]); init = []);
1063-
return_inplace = true)
1064-
upd_params_fun = setu(
1065-
sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = []))
1066-
upd_unk_fun = setu(
1067-
sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = []))
1068-
1069-
upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs);
1070-
last.(mod_unk_pairs)]...,)}(
1071-
[collect(mkzero(size(e)) for e in first.(mod_param_pairs));
1072-
collect(mkzero(size(e)) for e in first.(mod_unk_pairs))]))
1073-
upd_params_view = view(upd_component_array, last.(mod_param_pairs))
1074-
upd_unks_view = view(upd_component_array, last.(mod_unk_pairs))
1071+
mod_names = (mod_syms..., )
1072+
mod_og_val_fun = build_explicit_observed_function(
1073+
sys, reduce(vcat, Symbolics.scalarize.(first.(mod_pairs)); init = []);
1074+
array_type = :tuple)
10751075

1076+
upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))
1077+
10761078
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
10771079
save_idxs = get(ic.callback_to_clocks, cb, Int[])
10781080
else
@@ -1082,13 +1084,13 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps;
10821084
let user_affect = func(affect), ctx = context(affect)
10831085
function (integ)
10841086
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
1085-
mod_og_val_fun(upd_component_array, integ.u, integ.p..., integ.t)
1087+
upd_component_array = NamedTuple{mod_names}(mod_og_val_fun(integ.u, integ.p..., integ.t))
10861088

10871089
# update the observed values
1088-
obs_fun(obs_component_array, integ.u, integ.p..., integ.t)
1090+
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(integ.u, integ.p..., integ.t))
10891091

10901092
# let the user do their thing
1091-
if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ)
1093+
modvals = if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ)
10921094
user_affect(upd_component_array, obs_component_array, ctx, integ)
10931095
elseif applicable(user_affect, upd_component_array, obs_component_array, ctx)
10941096
user_affect(upd_component_array, obs_component_array, ctx)
@@ -1102,9 +1104,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps;
11021104
end
11031105

11041106
# write the new values back to the integrator
1105-
upd_params_fun(integ, upd_params_view)
1106-
upd_unk_fun(integ, upd_unks_view)
1107-
1107+
_generated_writeback(integ, upd_funs, modvals)
11081108

11091109
for idx in save_idxs
11101110
SciMLBase.save_discretes!(integ, idx)

src/systems/diffeqs/odesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ Options not otherwise specified are:
429429
* `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail.
430430
* `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist
431431
* `drop_expr` is deprecated.
432+
* `array_type`; only used if the output is an array (that is, `!isscalar(ts)`). If `:array`, then it will generate an array, if `:tuple` then it will generate a tuple.
432433
"""
433434
function build_explicit_observed_function(sys, ts;
434435
inputs = nothing,
@@ -442,7 +443,8 @@ function build_explicit_observed_function(sys, ts;
442443
return_inplace = false,
443444
param_only = false,
444445
op = Operator,
445-
throw = true)
446+
throw = true,
447+
array_type=:array)
446448
if (isscalar = symbolic_type(ts) !== NotSymbolic())
447449
ts = [ts]
448450
end
@@ -587,12 +589,10 @@ function build_explicit_observed_function(sys, ts;
587589
oop_mtkp_wrapper = mtkparams_wrapper
588590
end
589591

592+
output_expr = isscalar ? ts[1] : (array_type == :array ? MakeArray(ts, output_type) : MakeTuple(ts))
590593
# Need to keep old method of building the function since it uses `output_type`,
591594
# which can't be provided to `build_function`
592-
oop_fn = Func(args, [],
593-
pre(Let(obsexprs,
594-
isscalar ? ts[1] : MakeArray(ts, output_type),
595-
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
595+
oop_fn = Func(args, [], pre(Let(obsexprs, output_expr, false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
596596
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
597597

598598
if !isscalar

test/symbolic_events.jl

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using ModelingToolkit: SymbolicContinuousCallback,
88
using StableRNGs
99
import SciMLBase
1010
using SymbolicIndexingInterface
11+
using Setfield
1112
rng = StableRNG(12345)
1213

1314
@variables x(t) = 0
@@ -1010,12 +1011,12 @@ end
10101011
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
10111012
[temp ~ furnace_off_threshold],
10121013
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i, c
1013-
x.furnace_on = false
1014+
@set! x.furnace_on = false
10141015
end)
10151016
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
10161017
[temp ~ furnace_on_threshold],
10171018
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i, c
1018-
x.furnace_on = true
1019+
@set! x.furnace_on = true
10191020
end)
10201021
@named sys = ODESystem(
10211022
eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
@@ -1027,12 +1028,12 @@ end
10271028
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
10281029
[temp ~ furnace_off_threshold],
10291030
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i
1030-
x.furnace_on = false
1031+
@set! x.furnace_on = false
10311032
end)
10321033
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
10331034
[temp ~ furnace_on_threshold],
10341035
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i
1035-
x.furnace_on = true
1036+
@set! x.furnace_on = true
10361037
end)
10371038
@named sys = ODESystem(
10381039
eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
@@ -1044,12 +1045,12 @@ end
10441045
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
10451046
[temp ~ furnace_off_threshold],
10461047
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o
1047-
x.furnace_on = false
1048+
@set! x.furnace_on = false
10481049
end)
10491050
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
10501051
[temp ~ furnace_on_threshold],
10511052
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o
1052-
x.furnace_on = true
1053+
@set! x.furnace_on = true
10531054
end)
10541055
@named sys = ODESystem(
10551056
eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
@@ -1061,12 +1062,12 @@ end
10611062
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
10621063
[temp ~ furnace_off_threshold],
10631064
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x
1064-
x.furnace_on = false
1065+
@set! x.furnace_on = false
10651066
end)
10661067
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
10671068
[temp ~ furnace_on_threshold],
10681069
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x
1069-
x.furnace_on = true
1070+
@set! x.furnace_on = true
10701071
end)
10711072
@named sys = ODESystem(
10721073
eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
@@ -1078,14 +1079,14 @@ end
10781079
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
10791080
[temp ~ furnace_off_threshold],
10801081
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x
1081-
x.furnace_on = false
1082+
@set! x.furnace_on = false
10821083
end; initialize = ModelingToolkit.MutatingFunctionalAffect(modified = (; temp)) do x
1083-
x.temp = 0.2
1084+
@set! x.temp = 0.2
10841085
end)
10851086
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
10861087
[temp ~ furnace_on_threshold],
10871088
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, c, i
1088-
x.furnace_on = true
1089+
@set! x.furnace_on = true
10891090
end)
10901091
@named sys = ODESystem(
10911092
eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
@@ -1107,7 +1108,7 @@ end
11071108
[temp ~ furnace_off_threshold],
11081109
ModelingToolkit.MutatingFunctionalAffect(
11091110
modified = (; furnace_on), observed = (; furnace_on)) do x, o, c, i
1110-
x.furnace_on = false
1111+
@set! x.furnace_on = false
11111112
end)
11121113
@named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off])
11131114
ss = structural_simplify(sys)
@@ -1123,7 +1124,7 @@ end
11231124
[temp ~ furnace_off_threshold],
11241125
ModelingToolkit.MutatingFunctionalAffect(
11251126
modified = (; furnace_on, tempsq), observed = (; furnace_on)) do x, o, c, i
1126-
x.furnace_on = false
1127+
@set! x.furnace_on = false
11271128
end)
11281129
@named sys = ODESystem(
11291130
eqs, t, [temp, tempsq], params; continuous_events = [furnace_off])
@@ -1136,18 +1137,32 @@ end
11361137
[temp ~ furnace_off_threshold],
11371138
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on),
11381139
observed = (; furnace_on, not_actually_here)) do x, o, c, i
1139-
x.furnace_on = false
1140+
@set! x.furnace_on = false
11401141
end)
11411142
@named sys = ODESystem(
11421143
eqs, t, [temp, tempsq], params; continuous_events = [furnace_off])
11431144
ss = structural_simplify(sys)
11441145
@test_throws "refers to missing variable(s)" prob=ODEProblem(
11451146
ss, [temp => 0.0, furnace_on => true], (0.0, 100.0))
1147+
1148+
1149+
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
1150+
[temp ~ furnace_off_threshold],
1151+
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on),
1152+
observed = (; furnace_on)) do x, o, c, i
1153+
return (;fictional2 = false)
1154+
end)
1155+
@named sys = ODESystem(
1156+
eqs, t, [temp, tempsq], params; continuous_events = [furnace_off])
1157+
ss = structural_simplify(sys)
1158+
prob=ODEProblem(
1159+
ss, [temp => 0.0, furnace_on => true], (0.0, 100.0))
1160+
@test_throws "Tried to write back to" solve(prob, Tsit5())
11461161
end
11471162

11481163
@testset "Quadrature" begin
11491164
@variables theta(t) omega(t)
1150-
params = @parameters qA=0 qB=0 hA=0 hB=0 cnt=0
1165+
params = @parameters qA=0 qB=0 hA=0 hB=0 cnt::Int=0
11511166
eqs = [D(theta) ~ omega
11521167
omega ~ 1.0]
11531168
function decoder(oldA, oldB, newA, newB)
@@ -1167,31 +1182,35 @@ end
11671182
end
11681183
qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0],
11691184
ModelingToolkit.MutatingFunctionalAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c
1170-
x.hA = x.qA
1171-
x.hB = o.qB
1172-
x.qA = 1
1173-
x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
1185+
@set! x.hA = x.qA
1186+
@set! x.hB = o.qB
1187+
@set! x.qA = 1
1188+
@set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
1189+
x
11741190
end,
11751191
affect_neg = ModelingToolkit.MutatingFunctionalAffect(
11761192
(; qA, hA, hB, cnt), (; qB)) do x, o, c, i
1177-
x.hA = x.qA
1178-
x.hB = o.qB
1179-
x.qA = 0
1180-
x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
1193+
@set! x.hA = x.qA
1194+
@set! x.hB = o.qB
1195+
@set! x.qA = 0
1196+
@set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
1197+
x
11811198
end; rootfind = SciMLBase.RightRootFind)
11821199
qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0],
11831200
ModelingToolkit.MutatingFunctionalAffect((; qB, hA, hB, cnt), (; qA)) do x, o, i, c
1184-
x.hA = o.qA
1185-
x.hB = x.qB
1186-
x.qB = 1
1187-
x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
1201+
@set! x.hA = o.qA
1202+
@set! x.hB = x.qB
1203+
@set! x.qB = 1
1204+
@set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
1205+
x
11881206
end,
11891207
affect_neg = ModelingToolkit.MutatingFunctionalAffect(
11901208
(; qB, hA, hB, cnt), (; qA)) do x, o, c, i
1191-
x.hA = o.qA
1192-
x.hB = x.qB
1193-
x.qB = 0
1194-
x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
1209+
@set! x.hA = o.qA
1210+
@set! x.hB = x.qB
1211+
@set! x.qB = 0
1212+
@set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
1213+
x
11951214
end; rootfind = SciMLBase.RightRootFind)
11961215
@named sys = ODESystem(
11971216
eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt])

0 commit comments

Comments
 (0)