Skip to content

Commit 6d744b0

Browse files
authored
Merge pull request #1938 from SciML/myb_fb/clock_codegen
Add affect codegen for hybrid systems
2 parents 6ec96a1 + 76f19d2 commit 6d744b0

File tree

4 files changed

+117
-13
lines changed

4 files changed

+117
-13
lines changed

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ function renamespace(sys, x)
373373
sys === nothing && return x
374374
x = unwrap(x)
375375
if x isa Symbolic
376-
if isdifferential(x)
376+
if istree(x) && operation(x) isa Operator
377377
return similarterm(x, operation(x), Any[renamespace(sys, only(arguments(x)))])
378378
end
379379
let scope = getmetadata(x, SymScope, LocalScope())

src/systems/clock_inference.jl

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,71 @@ function split_system(ci::ClockInference)
146146
@set! ts_i.sys.eqs = eqs_i
147147
@set! ts_i.structure.eq_to_diff = eq_to_diff
148148
tss[id] = ts_i
149-
# TODO: just mark current and sample variables as inputs
150149
end
151-
return tss, inputs
150+
return tss, inputs, continuous_id
151+
end
152152

153-
#id_to_clock, cid_to_eq, cid_to_var
153+
function generate_discrete_affect(syss, inputs, continuous_id, check_bounds = true)
154+
out = Sym{Any}(:out)
155+
appended_parameters = parameters(syss[continuous_id])
156+
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
157+
offset = length(appended_parameters)
158+
affect_funs = []
159+
svs = []
160+
for (i, (sys, input)) in enumerate(zip(syss, inputs))
161+
i == continuous_id && continue
162+
subs = get_substitutions(sys)
163+
assignments = map(s -> Assignment(s.lhs, s.rhs), subs.subs)
164+
let_body = SetArray(!check_bounds, out, rhss(equations(sys)))
165+
let_block = Let(assignments, let_body, false)
166+
needed_cont_to_disc_obs = map(v -> arguments(v)[1], input)
167+
# TODO: filter the needed ones
168+
needed_disc_to_cont_obs = map(v -> arguments(v)[1], inputs[continuous_id])
169+
append!(appended_parameters, input, states(sys))
170+
disc_to_cont_idxs = map(Base.Fix1(getindex, param_to_idx), inputs[continuous_id])
171+
cont_to_disc_obs = build_explicit_observed_function(syss[continuous_id],
172+
needed_cont_to_disc_obs,
173+
throw = false,
174+
expression = true,
175+
output_type = SVector)
176+
@set! sys.ps = appended_parameters
177+
disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs,
178+
throw = false,
179+
expression = true,
180+
output_type = SVector)
181+
ni = length(input)
182+
ns = length(states(sys))
183+
disc = Func([
184+
out,
185+
DestructuredArgs(states(sys)),
186+
DestructuredArgs(appended_parameters),
187+
get_iv(sys),
188+
], [],
189+
let_block)
190+
cont_to_disc_idxs = (offset + 1):(offset += ni)
191+
input_offset = offset
192+
disc_range = (offset + 1):(offset += ns)
193+
affect! = quote
194+
function affect!(integrator, saved_values)
195+
@unpack u, p, t = integrator
196+
c2d_obs = $cont_to_disc_obs
197+
d2c_obs = $disc_to_cont_obs
198+
c2d_view = view(p, $cont_to_disc_idxs)
199+
d2c_view = view(p, $disc_to_cont_idxs)
200+
disc_state = view(p, $disc_range)
201+
disc = $disc
202+
# Write continuous info to discrete
203+
# Write discrete info to continuous
204+
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
205+
copyto!(d2c_view, d2c_obs(disc_state, p, t))
206+
push!(saved_values.t, t)
207+
push!(saved_values.saveval, Base.@ntuple $ns i->p[$input_offset + i])
208+
disc(disc_state, disc_state, p, t)
209+
end
210+
end
211+
sv = SavedValues(Float64, NTuple{ns, Float64})
212+
push!(affect_funs, affect!)
213+
push!(svs, sv)
214+
end
215+
return map(a -> toexpr(LiteralExpr(a)), affect_funs), svs, appended_parameters
154216
end

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@ i.e. there are no cycles.
289289
function build_explicit_observed_function(sys, ts;
290290
expression = false,
291291
output_type = Array,
292-
checkbounds = true)
292+
checkbounds = true,
293+
throw = true)
293294
if (isscalar = !(ts isa AbstractVector))
294295
ts = [ts]
295296
end
@@ -336,7 +337,12 @@ function build_explicit_observed_function(sys, ts;
336337
subs[s] = s′
337338
continue
338339
end
339-
throw(ArgumentError("$s is neither an observed nor a state variable."))
340+
if throw
341+
Base.throw(ArgumentError("$s is neither an observed nor a state variable."))
342+
else
343+
# TODO: return variables that don't exist in the system.
344+
return nothing
345+
end
340346
end
341347
continue
342348
end

test/clock.jl

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ eqmap = ci.eq_domain
6666
tss, inputs = ModelingToolkit.split_system(deepcopy(ci))
6767
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[1]), (inputs[1], ()))
6868
@test equations(sss) == [D(x) ~ u - x]
69-
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[2]), (inputs[2], ()),
70-
check_consistency = false)
69+
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[2]), (inputs[2], ()))
7170
@test isempty(equations(sss))
7271
@test observed(sss) == [r ~ 1.0; yd ~ Sample(t, dt)(y); ud ~ kp * (r - yd)]
7372

@@ -96,7 +95,7 @@ d = Clock(t, dt)
9695
k = ShiftIndex(d)
9796

9897
eqs = [yd ~ Sample(t, dt)(y)
99-
ud ~ kp * (r - yd)
98+
ud ~ kp * (r - yd) + z(k)
10099
r ~ 1.0
101100

102101
# plant (time continuous part)
@@ -114,11 +113,48 @@ eqs = [yd ~ Sample(t, dt)(y)
114113
@named sys = ODESystem(eqs)
115114
ci, varmap = infer_clocks(sys)
116115
tss, inputs = ModelingToolkit.split_system(deepcopy(ci))
117-
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[2]), (inputs[2], ()))
118-
@test length(states(sss)) == 2
119-
z, z_t = states(sss)
116+
syss = map(i -> ModelingToolkit.structural_simplify!(deepcopy(tss[i]), (inputs[i], ()))[1],
117+
eachindex(tss))
118+
sys1, sys2 = syss
119+
@test length(states(sys2)) == 2
120+
z, z_t = states(sys2)
120121
S = Shift(t, 1)
121-
@test full_equations(sss) == [S(z) ~ z_t; S(z_t) ~ z + Sample(t, dt)(y)]
122+
@test full_equations(sys2) == [S(z) ~ z_t; S(z_t) ~ z + Sample(t, dt)(y)]
123+
# TODO: set Hold(ud)
124+
prob = ODEProblem(sys1, [x => 0.0, y => 0.0], (0.0, 1.0), [kp => 1.0, Hold(ud) => 0.0]);
125+
using OrdinaryDiffEq, DiffEqCallbacks
126+
exprs, svs, pp = ModelingToolkit.generate_discrete_affect(syss, inputs, 1);
127+
prob = remake(prob, p = zeros(Float64, length(pp)));
128+
prob.p[1] = 1.0;
129+
gen_affect! = Base.Fix2(eval(exprs[1]), svs[1]);
130+
cb = PeriodicCallback(gen_affect!, 0.1);
131+
sol2 = solve(prob, Tsit5(), callback = cb);
132+
133+
# kp is the only real parameter
134+
function foo!(du, u, p, t)
135+
x = u[1]
136+
ud = p[2]
137+
du[1] = -x + ud
138+
end
139+
function affect!(integrator, saved_values)
140+
kp = integrator.p[1]
141+
yd = integrator.u[1]
142+
z_t = integrator.p[3]
143+
z = integrator.p[4]
144+
r = 1.0
145+
ud = kp * (r - yd) + z
146+
push!(saved_values.t, integrator.t)
147+
push!(saved_values.saveval, (integrator.p[3], integrator.p[4]))
148+
integrator.p[2] = ud
149+
integrator.p[3] = z + yd
150+
integrator.p[4] = z_t
151+
nothing
152+
end
153+
saved_values = SavedValues(Float64, Tuple{Float64, Float64});
154+
cb = PeriodicCallback(Base.Fix2(affect!, saved_values), 0.1);
155+
prob = ODEProblem(foo!, [0.0], (0.0, 1.0), [1.0, 0.0, 0.0, 0.0], callback = cb);
156+
sol = solve(prob, Tsit5());
157+
@test sol.u sol2.u
122158

123159
@info "Testing multi-rate hybrid system"
124160
dt = 0.1

0 commit comments

Comments
 (0)