1
+ @data ClockVertex begin
2
+ Variable (Int)
3
+ Equation (Int)
4
+ Clock (SciMLBase. AbstractClock)
5
+ end
6
+
1
7
struct ClockInference{S}
2
8
""" Tearing state."""
3
9
ts:: S
4
10
""" The time domain (discrete clock, continuous) of each equation."""
5
11
eq_domain:: Vector{TimeDomain}
6
12
""" The output time domain (discrete clock, continuous) of each variable."""
7
13
var_domain:: Vector{TimeDomain}
14
+ inference_graph:: HyperGraph{ClockVertex.Type}
8
15
""" The set of variables with concrete domains."""
9
16
inferred:: BitSet
10
17
end
@@ -22,7 +29,21 @@ function ClockInference(ts::TransformationState)
22
29
var_domain[i] = d
23
30
end
24
31
end
25
- ClockInference (ts, eq_domain, var_domain, inferred)
32
+ inference_graph = HyperGraph {ClockVertex.Type} ()
33
+ for i in 1 : nsrcs (graph)
34
+ add_vertex! (inference_graph, ClockVertex. Equation (i))
35
+ end
36
+ for i in 1 : ndsts (graph)
37
+ varvert = ClockVertex. Variable (i)
38
+ add_vertex! (inference_graph, varvert)
39
+ v = ts. fullvars[i]
40
+ d = get_time_domain (v)
41
+ is_concrete_time_domain (d) || continue
42
+ dvert = ClockVertex. Clock (d)
43
+ add_vertex! (inference_graph, dvert)
44
+ add_edge! (inference_graph, (varvert, dvert))
45
+ end
46
+ ClockInference (ts, eq_domain, var_domain, inference_graph, inferred)
26
47
end
27
48
28
49
struct NotInferredTimeDomain end
75
96
Update the equation-to-time domain mapping by inferring the time domain from the variables.
76
97
"""
77
98
function infer_clocks! (ci:: ClockInference )
78
- @unpack ts, eq_domain, var_domain, inferred = ci
99
+ @unpack ts, eq_domain, var_domain, inferred, inference_graph = ci
79
100
@unpack var_to_diff, graph = ts. structure
80
101
fullvars = get_fullvars (ts)
81
102
isempty (inferred) && return ci
82
- # TODO : add a graph type to do this lazily
83
- var_graph = SimpleGraph ( ndsts (graph ))
84
- for eq in 𝑠vertices (graph)
85
- vvs = 𝑠neighbors (graph, eq)
86
- if ! isempty (vvs )
87
- fv, vs = Iterators . peel (vvs)
88
- for v in vs
89
- add_edge! (var_graph, fv, v )
90
- end
91
- end
103
+
104
+ var_to_idx = Dict (fullvars .=> eachindex (fullvars ))
105
+
106
+ # all shifted variables have the same clock as the unshifted variant
107
+ for (i, v) in enumerate (fullvars )
108
+ iscall (v) || continue
109
+ operation (v) isa Shift || continue
110
+ unshifted = only ( arguments (v) )
111
+ add_edge! (inference_graph, (
112
+ ClockVertex . Variable (i), ClockVertex . Variable (var_to_idx[unshifted])))
92
113
end
93
- for v in vertices (var_to_diff)
94
- if (v′ = var_to_diff[v]) != = nothing
95
- add_edge! (var_graph, v, v′)
114
+
115
+ # preallocated buffers:
116
+ # variables in each equation
117
+ varsbuf = Set ()
118
+ # variables in each argument to an operator
119
+ arg_varsbuf = Set ()
120
+ # hyperedge for each equation
121
+ hyperedge = Set {ClockVertex.Type} ()
122
+ # hyperedge for each argument to an operator
123
+ arg_hyperedge = Set {ClockVertex.Type} ()
124
+ # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
125
+ relative_hyperedges = Dict {Int, Set{ClockVertex.Type}} ()
126
+
127
+ for (ieq, eq) in enumerate (equations (ts))
128
+ empty! (varsbuf)
129
+ empty! (hyperedge)
130
+ # get variables in equation
131
+ vars! (varsbuf, eq; op = Symbolics. Operator)
132
+ # add the equation to the hyperedge
133
+ push! (hyperedge, ClockVertex. Equation (ieq))
134
+ for var in varsbuf
135
+ idx = get (var_to_idx, var, nothing )
136
+ # if this is just a single variable, add it to the hyperedge
137
+ if idx isa Int
138
+ push! (hyperedge, ClockVertex. Variable (idx))
139
+ # we don't immediately `continue` here because this variable might be a
140
+ # `Sample` or similar and we want the clock information from it if it is.
141
+ end
142
+ # now we only care about synchronous operators
143
+ iscall (var) || continue
144
+ op = operation (var)
145
+ is_synchronous_operator (op) || continue
146
+
147
+ # arguments and corresponding time domains
148
+ args = arguments (var)
149
+ tdomains = input_timedomain (op)
150
+ nargs = length (args)
151
+ ndoms = length (tdomains)
152
+ if nargs != ndoms
153
+ throw (ArgumentError ("""
154
+ Operator $op applied to $nargs arguments $args but only returns $ndoms \
155
+ domains $tdomains from `input_timedomain`.
156
+ """ ))
157
+ end
158
+
159
+ # each relative clock mapping is only valid per operator application
160
+ empty! (relative_hyperedges)
161
+ for (arg, domain) in zip (args, tdomains)
162
+ empty! (arg_varsbuf)
163
+ empty! (arg_hyperedge)
164
+ # get variables in argument
165
+ vars! (arg_varsbuf, arg; op = Union{Differential, Shift})
166
+ # get hyperedge for involved variables
167
+ for v in arg_varsbuf
168
+ vidx = get (var_to_idx, v, nothing )
169
+ vidx === nothing && continue
170
+ push! (arg_hyperedge, ClockVertex. Variable (vidx))
171
+ end
172
+
173
+ Moshi. Match. @match domain begin
174
+ # If the time domain for this argument is a clock, then all variables in this edge have that clock.
175
+ x:: SciMLBase.AbstractClock => begin
176
+ # add the clock to the edge
177
+ push! (arg_hyperedge, ClockVertex. Clock (x))
178
+ # add the edge to the graph
179
+ add_edge! (inference_graph, arg_hyperedge)
180
+ end
181
+ # We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
182
+ # involved variables have the same clock.
183
+ InferredClock. Inferred () => add_edge! (inference_graph, arg_hyperedge)
184
+ # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
185
+ # add the edge, and instead add this to the `relative_hyperedges` mapping.
186
+ InferredClock. InferredDiscrete (i) => begin
187
+ relative_edge = get! (() -> Set {ClockVertex.Type} (), relative_hyperedges, i)
188
+ union! (relative_edge, arg_hyperedge)
189
+ end
190
+ end
191
+ end
192
+
193
+ outdomain = output_timedomain (op)
194
+ Moshi. Match. @match outdomain begin
195
+ x:: SciMLBase.AbstractClock => begin
196
+ push! (hyperedge, ClockVertex. Clock (x))
197
+ end
198
+ InferredClock. Inferred () => nothing
199
+ InferredClock. InferredDiscrete (i) => begin
200
+ buffer = get (relative_hyperedges, i, nothing )
201
+ if buffer != = nothing
202
+ union! (hyperedge, buffer)
203
+ delete! (relative_hyperedges, i)
204
+ end
205
+ end
206
+ end
207
+
208
+ for (_, relative_edge) in relative_hyperedges
209
+ add_edge! (inference_graph, relative_edge)
210
+ end
96
211
end
212
+
213
+ add_edge! (inference_graph, hyperedge)
97
214
end
98
- cc = connected_components (var_graph)
99
- for c′ in cc
100
- c = BitSet (c′)
101
- idxs = intersect (c, inferred)
102
- isempty (idxs) && continue
103
- if ! allequal (iscontinuous (var_domain[i]) for i in idxs)
104
- display (fullvars[c′])
105
- throw (ClockInferenceException (" Clocks are not consistent in connected component $(fullvars[c′]) " ))
215
+
216
+ clock_partitions = connectionsets (inference_graph)
217
+ for partition in clock_partitions
218
+ clockidxs = findall (vert -> Moshi. Data. isa_variant (vert, ClockVertex. Clock), partition)
219
+ if isempty (clockidxs)
220
+ vidxs = Int[vert.:1
221
+ for vert in partition
222
+ if Moshi. Data. isa_variant (vert, ClockVertex. Variable)]
223
+ throw (ArgumentError ("""
224
+ Found clock partion with no associated clock. Involved variables: $(fullvars[vidxs]) .
225
+ """ ))
106
226
end
107
- vd = var_domain[first (idxs)]
108
- for v in c′
109
- var_domain[v] = vd
227
+ if length (clockidxs) > 1
228
+ vidxs = Int[vert.:1
229
+ for vert in partition
230
+ if Moshi. Data. isa_variant (vert, ClockVertex. Variable)]
231
+ clks = [vert.:1 for vert in view (partition, clockidxs)]
232
+ throw (ArgumentError ("""
233
+ Found clock partition with multiple associated clocks. Involved variables: \
234
+ $(fullvars[vidxs]) . Involved clocks: $(clks) .
235
+ """ ))
110
236
end
111
- end
112
237
113
- for v in 𝑑vertices (graph)
114
- vd = var_domain[v]
115
- eqs = 𝑑neighbors (graph, v)
116
- isempty (eqs) && continue
117
- for eq in eqs
118
- eq_domain[eq] = vd
238
+ clock = partition[only (clockidxs)]. :1
239
+ for vert in partition
240
+ Moshi. Match. @match vert begin
241
+ ClockVertex. Variable (i) => (var_domain[i] = clock)
242
+ ClockVertex. Equation (i) => (eq_domain[i] = clock)
243
+ ClockVertex. Clock (_) => nothing
244
+ end
119
245
end
120
246
end
121
247
0 commit comments