1
1
@data ClockVertex begin
2
2
Variable (Int)
3
3
Equation (Int)
4
+ InitEquation (Int)
4
5
Clock (SciMLBase. AbstractClock)
5
6
end
6
7
@@ -9,6 +10,8 @@ struct ClockInference{S}
9
10
ts:: S
10
11
""" The time domain (discrete clock, continuous) of each equation."""
11
12
eq_domain:: Vector{TimeDomain}
13
+ """ The time domain (discrete clock, continuous) of each initialization equation."""
14
+ init_eq_domain:: Vector{TimeDomain}
12
15
""" The output time domain (discrete clock, continuous) of each variable."""
13
16
var_domain:: Vector{TimeDomain}
14
17
inference_graph:: HyperGraph{ClockVertex.Type}
@@ -20,6 +23,8 @@ function ClockInference(ts::TransformationState)
20
23
@unpack structure = ts
21
24
@unpack graph = structure
22
25
eq_domain = TimeDomain[ContinuousClock () for _ in 1 : nsrcs (graph)]
26
+ init_eq_domain = TimeDomain[ContinuousClock ()
27
+ for _ in 1 : length (initialization_equations (ts. sys))]
23
28
var_domain = TimeDomain[ContinuousClock () for _ in 1 : ndsts (graph)]
24
29
inferred = BitSet ()
25
30
for (i, v) in enumerate (get_fullvars (ts))
@@ -33,6 +38,9 @@ function ClockInference(ts::TransformationState)
33
38
for i in 1 : nsrcs (graph)
34
39
add_vertex! (inference_graph, ClockVertex. Equation (i))
35
40
end
41
+ for i in eachindex (initialization_equations (ts. sys))
42
+ add_vertex! (inference_graph, ClockVertex. InitEquation (i))
43
+ end
36
44
for i in 1 : ndsts (graph)
37
45
varvert = ClockVertex. Variable (i)
38
46
add_vertex! (inference_graph, varvert)
@@ -43,7 +51,7 @@ function ClockInference(ts::TransformationState)
43
51
add_vertex! (inference_graph, dvert)
44
52
add_edge! (inference_graph, (varvert, dvert))
45
53
end
46
- ClockInference (ts, eq_domain, var_domain, inference_graph, inferred)
54
+ ClockInference (ts, eq_domain, init_eq_domain, var_domain, inference_graph, inferred)
47
55
end
48
56
49
57
struct NotInferredTimeDomain end
96
104
Update the equation-to-time domain mapping by inferring the time domain from the variables.
97
105
"""
98
106
function infer_clocks! (ci:: ClockInference )
99
- @unpack ts, eq_domain, var_domain, inferred, inference_graph = ci
107
+ @unpack ts, eq_domain, init_eq_domain, var_domain, inferred, inference_graph = ci
100
108
@unpack var_to_diff, graph = ts. structure
101
109
fullvars = get_fullvars (ts)
102
110
isempty (inferred) && return ci
@@ -124,13 +132,18 @@ function infer_clocks!(ci::ClockInference)
124
132
# mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
125
133
relative_hyperedges = Dict {Int, Set{ClockVertex.Type}} ()
126
134
127
- for (ieq, eq) in enumerate ( equations (ts) )
135
+ function infer_equation (ieq, eq, is_initialization_equation )
128
136
empty! (varsbuf)
129
137
empty! (hyperedge)
130
138
# get variables in equation
131
139
vars! (varsbuf, eq; op = Symbolics. Operator)
132
140
# add the equation to the hyperedge
133
- push! (hyperedge, ClockVertex. Equation (ieq))
141
+ eq_node = if is_initialization_equation
142
+ ClockVertex. InitEquation (ieq)
143
+ else
144
+ ClockVertex. Equation (ieq)
145
+ end
146
+ push! (hyperedge, eq_node)
134
147
for var in varsbuf
135
148
idx = get (var_to_idx, var, nothing )
136
149
# if this is just a single variable, add it to the hyperedge
@@ -215,6 +228,12 @@ function infer_clocks!(ci::ClockInference)
215
228
216
229
add_edge! (inference_graph, hyperedge)
217
230
end
231
+ for (ieq, eq) in enumerate (equations (ts))
232
+ infer_equation (ieq, eq, false )
233
+ end
234
+ for (ieq, eq) in enumerate (initialization_equations (ts. sys))
235
+ infer_equation (ieq, eq, true )
236
+ end
218
237
219
238
clock_partitions = connectionsets (inference_graph)
220
239
for partition in clock_partitions
@@ -243,6 +262,7 @@ function infer_clocks!(ci::ClockInference)
243
262
Moshi. Match. @match vert begin
244
263
ClockVertex. Variable (i) => (var_domain[i] = clock)
245
264
ClockVertex. Equation (i) => (eq_domain[i] = clock)
265
+ ClockVertex. InitEquation (i) => (init_eq_domain[i] = clock)
246
266
ClockVertex. Clock (_) => nothing
247
267
end
248
268
end
@@ -278,14 +298,16 @@ end
278
298
For multi-clock systems, create a separate system for each clock in the system, along with associated equations. Return the updated tearing state, and the sets of clocked variables associated with each time domain.
279
299
"""
280
300
function split_system (ci:: ClockInference{S} ) where {S}
281
- @unpack ts, eq_domain, var_domain, inferred = ci
301
+ @unpack ts, eq_domain, init_eq_domain, var_domain, inferred = ci
282
302
fullvars = get_fullvars (ts)
283
303
@unpack graph = ts. structure
284
304
continuous_id = Ref (0 )
285
305
clock_to_id = Dict {TimeDomain, Int} ()
286
306
id_to_clock = TimeDomain[]
287
307
eq_to_cid = Vector {Int} (undef, nsrcs (graph))
288
308
cid_to_eq = Vector{Int}[]
309
+ init_eq_to_cid = Vector {Int} (undef, length (initialization_equations (ts. sys)))
310
+ cid_to_init_eq = Vector{Int}[]
289
311
var_to_cid = Vector {Int} (undef, ndsts (graph))
290
312
cid_to_var = Vector{Int}[]
291
313
# cid_counter = number of clocks
@@ -311,6 +333,15 @@ function split_system(ci::ClockInference{S}) where {S}
311
333
eq_to_cid[i] = cid
312
334
resize_or_push! (cid_to_eq, i, cid)
313
335
end
336
+ # NOTE: This assumes that there is at least one equation for each clock
337
+ for _ in 1 : length (cid_to_eq)
338
+ push! (cid_to_init_eq, Int[])
339
+ end
340
+ for (i, d) in enumerate (init_eq_domain)
341
+ cid = clock_to_id[d]
342
+ init_eq_to_cid[i] = cid
343
+ push! (cid_to_init_eq[cid], i)
344
+ end
314
345
continuous_id = continuous_id[]
315
346
# for each clock partition what are the input (indexes/vars)
316
347
input_idxs = map (_ -> Int[], 1 : cid_counter[])
@@ -334,8 +365,8 @@ function split_system(ci::ClockInference{S}) where {S}
334
365
335
366
# breaks the system up into a continous and 0 or more discrete systems
336
367
tss = similar (cid_to_eq, S)
337
- for (id, (ieqs, ivars)) in enumerate (zip (cid_to_eq, cid_to_var))
338
- ts_i = system_subset (ts, ieqs, ivars)
368
+ for (id, (ieqs, iieqs, ivars)) in enumerate (zip (cid_to_eq, cid_to_init_eq , cid_to_var))
369
+ ts_i = system_subset (ts, ieqs, iieqs, ivars)
339
370
if id != continuous_id
340
371
ts_i = shift_discrete_system (ts_i)
341
372
@set! ts_i. structure. only_discrete = true
0 commit comments