Skip to content

Commit 07b477f

Browse files
YingboMabaggepinnen
andcommitted
Propagate input domain back to the variable and add multirate tests
Co-authored-by: Fredrik Bagge Carlson <[email protected]>
1 parent cd22d89 commit 07b477f

File tree

4 files changed

+68
-6
lines changed

4 files changed

+68
-6
lines changed

src/clock.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ See also [`has_continuous_domain`](@ref)
8989
"""
9090
is_continuous_domain(x) = !has_discrete_domain(x) && has_continuous_domain(x)
9191

92+
struct ClockInferenceException <: Exception
93+
msg::Any
94+
end
95+
96+
function Base.showerror(io::IO, cie::ClockInferenceException)
97+
print(io, "ClockInferenceException: ", cie.msg)
98+
end
99+
92100
abstract type AbstractClock <: AbstractDiscrete end
93101

94102
"""

src/systems/clock_inference.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,12 @@ function infer_clocks!(ci::ClockInference)
4646
c = BitSet(c′)
4747
idxs = intersect(c, inferred)
4848
isempty(idxs) && continue
49+
for i in idxs
50+
@show var_domain[i]
51+
end
4952
if !allequal(var_domain[i] for i in idxs)
50-
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c])"))
53+
display(fullvars[c′])
54+
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
5155
end
5256
vd = var_domain[first(idxs)]
5357
for v in c′

src/systems/systemstructure.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,13 +236,15 @@ function TearingState(sys; quick_cancel = false, check = true)
236236
isalgeq = true
237237
statevars = []
238238
for var in vars
239+
set_incidence = true
240+
@label ANOTHER_VAR
239241
_var, _ = var_from_nested_derivative(var)
240242
any(isequal(_var), ivs) && continue
241243
if isparameter(_var) || (istree(_var) && isparameter(operation(_var)))
242244
continue
243245
end
244246
varidx = addvar!(var)
245-
push!(statevars, var)
247+
set_incidence && push!(statevars, var)
246248

247249
dvar = var
248250
idx = varidx
@@ -254,6 +256,14 @@ function TearingState(sys; quick_cancel = false, check = true)
254256
dvar = arguments(dvar)[1]
255257
idx = addvar!(dvar)
256258
end
259+
260+
if istree(var) && operation(var) isa Symbolics.Operator &&
261+
!isdifferential(var) && (it = input_timedomain(var)) !== nothing
262+
set_incidence = false
263+
var = only(arguments(var))
264+
var = setmetadata(var, ModelingToolkit.TimeDomain, it)
265+
@goto ANOTHER_VAR
266+
end
257267
end
258268
push!(symbolic_incidence, copy(statevars))
259269
empty!(statevars)

test/clock.jl

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11
using ModelingToolkit, Test
2+
3+
function infer_clocks(sys)
4+
ts = TearingState(sys)
5+
ci = ModelingToolkit.ClockInference(ts)
6+
ModelingToolkit.infer_clocks!(ci), Dict(ci.ts.fullvars .=> ci.var_domain)
7+
end
8+
9+
@info "Testing hybrid system"
210
dt = 0.1
311
@variables t x(t) y(t) u(t) yd(t) ud(t) r(t)
412
@parameters kp
@@ -14,10 +22,7 @@ eqs = [yd ~ Sample(t, dt)(y)
1422
@named sys = ODESystem(eqs)
1523
# compute equation and variables' time domains
1624

17-
ts = TearingState(sys)
18-
ci = ModelingToolkit.ClockInference(ts)
19-
ModelingToolkit.infer_clocks!(ci)
20-
varmap = Dict(ci.ts.fullvars .=> ci.var_domain)
25+
ci, varmap = infer_clocks(sys)
2126
eqmap = ci.eq_domain
2227

2328
d = Clock(t, dt)
@@ -34,3 +39,38 @@ d = Clock(t, dt)
3439
@test varmap[x] == Continuous()
3540
@test varmap[y] == Continuous()
3641
@test varmap[u] == Continuous()
42+
43+
@info "Testing multi-rate hybrid system"
44+
dt = 0.1
45+
dt2 = 0.2
46+
@variables t x(t) y(t) u(t) r(t) yd1(t) ud1(t) yd2(t) ud2(t)
47+
@parameters kp
48+
D = Differential(t)
49+
50+
eqs = [
51+
# controller (time discrete part `dt=0.1`)
52+
yd1 ~ Sample(t, dt)(y)
53+
ud1 ~ kp * (Sample(t, dt)(r) - yd1)
54+
yd2 ~ Sample(t, dt2)(y)
55+
ud2 ~ kp * (Sample(t, dt2)(r) - yd2)
56+
57+
# plant (time continuous part)
58+
u ~ Hold(ud1) + Hold(ud2)
59+
D(x) ~ -x + u
60+
y ~ x]
61+
@named sys = ODESystem(eqs)
62+
ci, varmap = infer_clocks(sys)
63+
64+
d = Clock(t, dt)
65+
d2 = Clock(t, dt2)
66+
#@test get_eq_domain(eqs[1]) == d
67+
#@test get_eq_domain(eqs[3]) == d2
68+
69+
@test varmap[yd1] == d
70+
@test varmap[ud1] == d
71+
@test varmap[yd2] == d2
72+
@test varmap[ud2] == d2
73+
@test varmap[r] == Continuous()
74+
@test varmap[x] == Continuous()
75+
@test varmap[y] == Continuous()
76+
@test varmap[u] == Continuous()

0 commit comments

Comments
 (0)