Skip to content

Commit e82f9e2

Browse files
authored
Merge pull request #1959 from SciML/multirate
add test for multi-rate system
2 parents d617c56 + 7b0d100 commit e82f9e2

File tree

2 files changed

+86
-3
lines changed

2 files changed

+86
-3
lines changed

src/systems/clock_inference.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,20 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
169169
let_block = Let(assignments, let_body, false)
170170
needed_cont_to_disc_obs = map(v -> arguments(v)[1], input)
171171
# TODO: filter the needed ones
172-
needed_disc_to_cont_obs = map(v -> arguments(v)[1], inputs[continuous_id])
172+
fullvars = Set{Any}(eq.lhs for eq in observed(sys))
173+
for s in states(sys)
174+
push!(fullvars, s)
175+
end
176+
needed_disc_to_cont_obs = []
177+
disc_to_cont_idxs = Int[]
178+
for v in inputs[continuous_id]
179+
vv = arguments(v)[1]
180+
if vv in fullvars
181+
push!(needed_disc_to_cont_obs, vv)
182+
push!(disc_to_cont_idxs, param_to_idx[v])
183+
end
184+
end
173185
append!(appended_parameters, input, states(sys))
174-
disc_to_cont_idxs = map(Base.Fix1(getindex, param_to_idx), inputs[continuous_id])
175186
cont_to_disc_obs = build_explicit_observed_function(syss[continuous_id],
176187
needed_cont_to_disc_obs,
177188
throw = false,
@@ -198,6 +209,7 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
198209
for i in 1:ns
199210
push!(save_vec.args, :(p[$(input_offset + i)]))
200211
end
212+
empty_disc = isempty(disc_range)
201213
affect! = :(function (integrator, saved_values)
202214
@unpack u, p, t = integrator
203215
c2d_obs = $cont_to_disc_obs
@@ -212,7 +224,7 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
212224
copyto!(d2c_view, d2c_obs(disc_state, p, t))
213225
push!(saved_values.t, t)
214226
push!(saved_values.saveval, $save_vec)
215-
disc(disc_state, disc_state, p, t)
227+
$empty_disc || disc(disc_state, disc_state, p, t)
216228
end)
217229
sv = SavedValues(Float64, Vector{Float64})
218230
push!(affect_funs, affect!)

test/clock.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,74 @@ ci, varmap = infer_clocks(cl)
240240
@test varmap[f.u] == Clock(t, 0.5)
241241
@test varmap[p.u] == Continuous()
242242
@test varmap[c.r] == Clock(t, 0.5)
243+
244+
## Multiple clock rates
245+
@info "Testing multi-rate hybrid system"
246+
dt = 0.1
247+
dt2 = 0.2
248+
@variables t x(t)=0 y(t)=0 u(t)=0 yd1(t)=0 ud1(t)=0 yd2(t)=0 ud2(t)=0
249+
@parameters kp=1 r=1
250+
D = Differential(t)
251+
252+
eqs = [
253+
# controller (time discrete part `dt=0.1`)
254+
yd1 ~ Sample(t, dt)(y)
255+
ud1 ~ kp * (r - yd1)
256+
# controller (time discrete part `dt=0.2`)
257+
yd2 ~ Sample(t, dt2)(y)
258+
ud2 ~ kp * (r - yd2)
259+
260+
# plant (time continuous part)
261+
u ~ Hold(ud1) + Hold(ud2)
262+
D(x) ~ -x + u
263+
y ~ x]
264+
265+
@named cl = ODESystem(eqs, t)
266+
267+
d = Clock(t, dt)
268+
d2 = Clock(t, dt2)
269+
270+
ci, varmap = infer_clocks(cl)
271+
@test varmap[yd1] == d
272+
@test varmap[ud1] == d
273+
@test varmap[yd2] == d2
274+
@test varmap[ud2] == d2
275+
@test varmap[x] == Continuous()
276+
@test varmap[y] == Continuous()
277+
@test varmap[u] == Continuous()
278+
279+
ss = structural_simplify(cl)
280+
281+
if VERSION >= v"1.7"
282+
prob = ODEProblem(ss, [x => 0.0], (0.0, 1.0), [kp => 1.0])
283+
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
284+
285+
function foo!(dx, x, p, t)
286+
kp, ud1, ud2 = p
287+
dx[1] = -x[1] + ud1 + ud2
288+
end
289+
290+
function affect1!(integrator)
291+
kp = integrator.p[1]
292+
y = integrator.u[1]
293+
r = 1.0
294+
ud1 = kp * (r - y)
295+
integrator.p[2] = ud1
296+
nothing
297+
end
298+
function affect2!(integrator)
299+
kp = integrator.p[1]
300+
y = integrator.u[1]
301+
r = 1.0
302+
ud2 = kp * (r - y)
303+
integrator.p[3] = ud2
304+
nothing
305+
end
306+
cb1 = PeriodicCallback(affect1!, dt)
307+
cb2 = PeriodicCallback(affect2!, dt2)
308+
cb = CallbackSet(cb1, cb2)
309+
prob = ODEProblem(foo!, [0.0], (0.0, 1.0), [1.0, 0.0, 0.0], callback = cb)
310+
sol2 = solve(prob, Tsit5())
311+
312+
@test sol.u sol2.u
313+
end

0 commit comments

Comments
 (0)