Skip to content

Commit 53fd3f4

Browse files
feat: use and split initialization equations in clock inference
1 parent 087db08 commit 53fd3f4

File tree

3 files changed

+69
-8
lines changed

3 files changed

+69
-8
lines changed

src/systems/clock_inference.jl

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@data ClockVertex begin
22
Variable(Int)
33
Equation(Int)
4+
InitEquation(Int)
45
Clock(SciMLBase.AbstractClock)
56
end
67

@@ -9,6 +10,8 @@ struct ClockInference{S}
910
ts::S
1011
"""The time domain (discrete clock, continuous) of each equation."""
1112
eq_domain::Vector{TimeDomain}
13+
"""The time domain (discrete clock, continuous) of each initialization equation."""
14+
init_eq_domain::Vector{TimeDomain}
1215
"""The output time domain (discrete clock, continuous) of each variable."""
1316
var_domain::Vector{TimeDomain}
1417
inference_graph::HyperGraph{ClockVertex.Type}
@@ -20,6 +23,8 @@ function ClockInference(ts::TransformationState)
2023
@unpack structure = ts
2124
@unpack graph = structure
2225
eq_domain = TimeDomain[ContinuousClock() for _ in 1:nsrcs(graph)]
26+
init_eq_domain = TimeDomain[ContinuousClock()
27+
for _ in 1:length(initialization_equations(ts.sys))]
2328
var_domain = TimeDomain[ContinuousClock() for _ in 1:ndsts(graph)]
2429
inferred = BitSet()
2530
for (i, v) in enumerate(get_fullvars(ts))
@@ -33,6 +38,9 @@ function ClockInference(ts::TransformationState)
3338
for i in 1:nsrcs(graph)
3439
add_vertex!(inference_graph, ClockVertex.Equation(i))
3540
end
41+
for i in eachindex(initialization_equations(ts.sys))
42+
add_vertex!(inference_graph, ClockVertex.InitEquation(i))
43+
end
3644
for i in 1:ndsts(graph)
3745
varvert = ClockVertex.Variable(i)
3846
add_vertex!(inference_graph, varvert)
@@ -43,7 +51,7 @@ function ClockInference(ts::TransformationState)
4351
add_vertex!(inference_graph, dvert)
4452
add_edge!(inference_graph, (varvert, dvert))
4553
end
46-
ClockInference(ts, eq_domain, var_domain, inference_graph, inferred)
54+
ClockInference(ts, eq_domain, init_eq_domain, var_domain, inference_graph, inferred)
4755
end
4856

4957
struct NotInferredTimeDomain end
@@ -96,7 +104,7 @@ end
96104
Update the equation-to-time domain mapping by inferring the time domain from the variables.
97105
"""
98106
function infer_clocks!(ci::ClockInference)
99-
@unpack ts, eq_domain, var_domain, inferred, inference_graph = ci
107+
@unpack ts, eq_domain, init_eq_domain, var_domain, inferred, inference_graph = ci
100108
@unpack var_to_diff, graph = ts.structure
101109
fullvars = get_fullvars(ts)
102110
isempty(inferred) && return ci
@@ -124,13 +132,18 @@ function infer_clocks!(ci::ClockInference)
124132
# mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
125133
relative_hyperedges = Dict{Int, Set{ClockVertex.Type}}()
126134

127-
for (ieq, eq) in enumerate(equations(ts))
135+
function infer_equation(ieq, eq, is_initialization_equation)
128136
empty!(varsbuf)
129137
empty!(hyperedge)
130138
# get variables in equation
131139
vars!(varsbuf, eq; op = Symbolics.Operator)
132140
# add the equation to the hyperedge
133-
push!(hyperedge, ClockVertex.Equation(ieq))
141+
eq_node = if is_initialization_equation
142+
ClockVertex.InitEquation(ieq)
143+
else
144+
ClockVertex.Equation(ieq)
145+
end
146+
push!(hyperedge, eq_node)
134147
for var in varsbuf
135148
idx = get(var_to_idx, var, nothing)
136149
# if this is just a single variable, add it to the hyperedge
@@ -215,6 +228,12 @@ function infer_clocks!(ci::ClockInference)
215228

216229
add_edge!(inference_graph, hyperedge)
217230
end
231+
for (ieq, eq) in enumerate(equations(ts))
232+
infer_equation(ieq, eq, false)
233+
end
234+
for (ieq, eq) in enumerate(initialization_equations(ts.sys))
235+
infer_equation(ieq, eq, true)
236+
end
218237

219238
clock_partitions = connectionsets(inference_graph)
220239
for partition in clock_partitions
@@ -243,6 +262,7 @@ function infer_clocks!(ci::ClockInference)
243262
Moshi.Match.@match vert begin
244263
ClockVertex.Variable(i) => (var_domain[i] = clock)
245264
ClockVertex.Equation(i) => (eq_domain[i] = clock)
265+
ClockVertex.InitEquation(i) => (init_eq_domain[i] = clock)
246266
ClockVertex.Clock(_) => nothing
247267
end
248268
end
@@ -278,14 +298,16 @@ end
278298
For multi-clock systems, create a separate system for each clock in the system, along with associated equations. Return the updated tearing state, and the sets of clocked variables associated with each time domain.
279299
"""
280300
function split_system(ci::ClockInference{S}) where {S}
281-
@unpack ts, eq_domain, var_domain, inferred = ci
301+
@unpack ts, eq_domain, init_eq_domain, var_domain, inferred = ci
282302
fullvars = get_fullvars(ts)
283303
@unpack graph = ts.structure
284304
continuous_id = Ref(0)
285305
clock_to_id = Dict{TimeDomain, Int}()
286306
id_to_clock = TimeDomain[]
287307
eq_to_cid = Vector{Int}(undef, nsrcs(graph))
288308
cid_to_eq = Vector{Int}[]
309+
init_eq_to_cid = Vector{Int}(undef, length(initialization_equations(ts.sys)))
310+
cid_to_init_eq = Vector{Int}[]
289311
var_to_cid = Vector{Int}(undef, ndsts(graph))
290312
cid_to_var = Vector{Int}[]
291313
# cid_counter = number of clocks
@@ -311,6 +333,15 @@ function split_system(ci::ClockInference{S}) where {S}
311333
eq_to_cid[i] = cid
312334
resize_or_push!(cid_to_eq, i, cid)
313335
end
336+
# NOTE: This assumes that there is at least one equation for each clock
337+
for _ in 1:length(cid_to_eq)
338+
push!(cid_to_init_eq, Int[])
339+
end
340+
for (i, d) in enumerate(init_eq_domain)
341+
cid = clock_to_id[d]
342+
init_eq_to_cid[i] = cid
343+
push!(cid_to_init_eq[cid], i)
344+
end
314345
continuous_id = continuous_id[]
315346
# for each clock partition what are the input (indexes/vars)
316347
input_idxs = map(_ -> Int[], 1:cid_counter[])
@@ -334,8 +365,8 @@ function split_system(ci::ClockInference{S}) where {S}
334365

335366
# breaks the system up into a continous and 0 or more discrete systems
336367
tss = similar(cid_to_eq, S)
337-
for (id, (ieqs, ivars)) in enumerate(zip(cid_to_eq, cid_to_var))
338-
ts_i = system_subset(ts, ieqs, ivars)
368+
for (id, (ieqs, iieqs, ivars)) in enumerate(zip(cid_to_eq, cid_to_init_eq, cid_to_var))
369+
ts_i = system_subset(ts, ieqs, iieqs, ivars)
339370
if id != continuous_id
340371
ts_i = shift_discrete_system(ts_i)
341372
@set! ts_i.structure.only_discrete = true

src/systems/systemstructure.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,11 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
219219
end
220220

221221
TransformationState(sys::AbstractSystem) = TearingState(sys)
222-
function system_subset(ts::TearingState, ieqs::Vector{Int}, ivars::Vector{Int})
222+
function system_subset(ts::TearingState, ieqs::Vector{Int}, iieqs::Vector{Int}, ivars::Vector{Int})
223223
eqs = equations(ts)
224+
initeqs = initialization_equations(ts.sys)
224225
@set! ts.sys.eqs = eqs[ieqs]
226+
@set! ts.sys.initialization_eqs = initeqs[iieqs]
225227
@set! ts.original_eqs = ts.original_eqs[ieqs]
226228
@set! ts.structure = system_subset(ts.structure, ieqs, ivars)
227229
if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys))

test/clock.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,34 @@ eqs = [yd ~ Sample(dt)(y)
118118
@named sys = System(eqs, t)
119119
@test_throws ModelingToolkit.HybridSystemNotSupportedException ss=mtkcompile(sys);
120120

121+
@testset "Clock inference uses and splits initialization equations" begin
122+
@variables x(t) y(t) z(t)
123+
k = ShiftIndex()
124+
clk = Clock(0.1)
125+
eqs = [D(x) ~ x, y ~ Sample(clk)(x), z ~ z(k-1) + 1]
126+
initialization_eqs = [y ~ z, x ~ 1]
127+
@named sys = System(eqs, t; initialization_eqs)
128+
ts = TearingState(sys)
129+
ci = ModelingToolkit.ClockInference(ts)
130+
@test length(ci.init_eq_domain) == 2
131+
ModelingToolkit.infer_clocks!(ci)
132+
canonical_eqs = map(eqs) do eq
133+
if iscall(eq.lhs) && operation(eq.lhs) isa Differential
134+
return eq
135+
else
136+
return 0 ~ eq.rhs - eq.lhs
137+
end
138+
end
139+
eqs_idxs = findfirst.(isequal.(canonical_eqs), (equations(ci.ts),))
140+
@test ci.eq_domain[eqs_idxs[1]] == ContinuousClock()
141+
@test ci.eq_domain[eqs_idxs[2]] == clk
142+
@test ci.eq_domain[eqs_idxs[3]] == clk
143+
varmap = Dict(ci.ts.fullvars .=> ci.var_domain)
144+
@test varmap[x] == ContinuousClock()
145+
@test varmap[y] == clk
146+
@test varmap[z] == clk
147+
end
148+
121149
@test_skip begin
122150
Tf = 1.0
123151
prob = ODEProblem(

0 commit comments

Comments
 (0)