1+ @data ClockVertex begin
2+ Variable (Int)
3+ Equation (Int)
4+ Clock (SciMLBase. AbstractClock)
5+ end
6+
17struct ClockInference{S}
28 """ Tearing state."""
39 ts:: S
410 """ The time domain (discrete clock, continuous) of each equation."""
511 eq_domain:: Vector{TimeDomain}
612 """ The output time domain (discrete clock, continuous) of each variable."""
713 var_domain:: Vector{TimeDomain}
14+ inference_graph:: HyperGraph{ClockVertex.Type}
815 """ The set of variables with concrete domains."""
916 inferred:: BitSet
1017end
@@ -22,7 +29,21 @@ function ClockInference(ts::TransformationState)
2229 var_domain[i] = d
2330 end
2431 end
25- ClockInference (ts, eq_domain, var_domain, inferred)
32+ inference_graph = HyperGraph {ClockVertex.Type} ()
33+ for i in 1 : nsrcs (graph)
34+ add_vertex! (inference_graph, ClockVertex. Equation (i))
35+ end
36+ for i in 1 : ndsts (graph)
37+ varvert = ClockVertex. Variable (i)
38+ add_vertex! (inference_graph, varvert)
39+ v = ts. fullvars[i]
40+ d = get_time_domain (v)
41+ is_concrete_time_domain (d) || continue
42+ dvert = ClockVertex. Clock (d)
43+ add_vertex! (inference_graph, dvert)
44+ add_edge! (inference_graph, (varvert, dvert))
45+ end
46+ ClockInference (ts, eq_domain, var_domain, inference_graph, inferred)
2647end
2748
2849struct NotInferredTimeDomain end
7596Update the equation-to-time domain mapping by inferring the time domain from the variables.
7697"""
7798function infer_clocks! (ci:: ClockInference )
78- @unpack ts, eq_domain, var_domain, inferred = ci
99+ @unpack ts, eq_domain, var_domain, inferred, inference_graph = ci
79100 @unpack var_to_diff, graph = ts. structure
80101 fullvars = get_fullvars (ts)
81102 isempty (inferred) && return ci
82- # TODO : add a graph type to do this lazily
83- var_graph = SimpleGraph ( ndsts (graph ))
84- for eq in 𝑠vertices (graph)
85- vvs = 𝑠neighbors (graph, eq)
86- if ! isempty (vvs )
87- fv, vs = Iterators . peel (vvs)
88- for v in vs
89- add_edge! (var_graph, fv, v )
90- end
91- end
103+
104+ var_to_idx = Dict (fullvars .=> eachindex (fullvars ))
105+
106+ # all shifted variables have the same clock as the unshifted variant
107+ for (i, v) in enumerate (fullvars )
108+ iscall (v) || continue
109+ operation (v) isa Shift || continue
110+ unshifted = only ( arguments (v) )
111+ add_edge! (inference_graph, (
112+ ClockVertex . Variable (i), ClockVertex . Variable (var_to_idx[unshifted])))
92113 end
93- for v in vertices (var_to_diff)
94- if (v′ = var_to_diff[v]) != = nothing
95- add_edge! (var_graph, v, v′)
114+
115+ # preallocated buffers:
116+ # variables in each equation
117+ varsbuf = Set ()
118+ # variables in each argument to an operator
119+ arg_varsbuf = Set ()
120+ # hyperedge for each equation
121+ hyperedge = Set {ClockVertex.Type} ()
122+ # hyperedge for each argument to an operator
123+ arg_hyperedge = Set {ClockVertex.Type} ()
124+ # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
125+ relative_hyperedges = Dict {Int, Set{ClockVertex.Type}} ()
126+
127+ for (ieq, eq) in enumerate (equations (ts))
128+ empty! (varsbuf)
129+ empty! (hyperedge)
130+ # get variables in equation
131+ vars! (varsbuf, eq; op = Symbolics. Operator)
132+ # add the equation to the hyperedge
133+ push! (hyperedge, ClockVertex. Equation (ieq))
134+ for var in varsbuf
135+ idx = get (var_to_idx, var, nothing )
136+ # if this is just a single variable, add it to the hyperedge
137+ if idx isa Int
138+ push! (hyperedge, ClockVertex. Variable (idx))
139+ # we don't immediately `continue` here because this variable might be a
140+ # `Sample` or similar and we want the clock information from it if it is.
141+ end
142+ # now we only care about synchronous operators
143+ iscall (var) || continue
144+ op = operation (var)
145+ is_synchronous_operator (op) || continue
146+
147+ # arguments and corresponding time domains
148+ args = arguments (var)
149+ tdomains = input_timedomain (op)
150+ nargs = length (args)
151+ ndoms = length (tdomains)
152+ if nargs != ndoms
153+ throw (ArgumentError ("""
154+ Operator $op applied to $nargs arguments $args but only returns $ndoms \
155+ domains $tdomains from `input_timedomain`.
156+ """ ))
157+ end
158+
159+ # each relative clock mapping is only valid per operator application
160+ empty! (relative_hyperedges)
161+ for (arg, domain) in zip (args, tdomains)
162+ empty! (arg_varsbuf)
163+ empty! (arg_hyperedge)
164+ # get variables in argument
165+ vars! (arg_varsbuf, arg; op = Union{Differential, Shift})
166+ # get hyperedge for involved variables
167+ for v in arg_varsbuf
168+ vidx = get (var_to_idx, v, nothing )
169+ vidx === nothing && continue
170+ push! (arg_hyperedge, ClockVertex. Variable (vidx))
171+ end
172+
173+ Moshi. Match. @match domain begin
174+ # If the time domain for this argument is a clock, then all variables in this edge have that clock.
175+ x:: SciMLBase.AbstractClock => begin
176+ # add the clock to the edge
177+ push! (arg_hyperedge, ClockVertex. Clock (x))
178+ # add the edge to the graph
179+ add_edge! (inference_graph, arg_hyperedge)
180+ end
181+ # We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
182+ # involved variables have the same clock.
183+ InferredClock. Inferred () => add_edge! (inference_graph, arg_hyperedge)
184+ # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
185+ # add the edge, and instead add this to the `relative_hyperedges` mapping.
186+ InferredClock. InferredDiscrete (i) => begin
187+ relative_edge = get! (() -> Set {ClockVertex.Type} (), relative_hyperedges, i)
188+ union! (relative_edge, arg_hyperedge)
189+ end
190+ end
191+ end
192+
193+ outdomain = output_timedomain (op)
194+ Moshi. Match. @match outdomain begin
195+ x:: SciMLBase.AbstractClock => begin
196+ push! (hyperedge, ClockVertex. Clock (x))
197+ end
198+ InferredClock. Inferred () => nothing
199+ InferredClock. InferredDiscrete (i) => begin
200+ buffer = get (relative_hyperedges, i, nothing )
201+ if buffer != = nothing
202+ union! (hyperedge, buffer)
203+ delete! (relative_hyperedges, i)
204+ end
205+ end
206+ end
207+
208+ for (_, relative_edge) in relative_hyperedges
209+ add_edge! (inference_graph, relative_edge)
210+ end
96211 end
212+
213+ add_edge! (inference_graph, hyperedge)
97214 end
98- cc = connected_components (var_graph)
99- for c′ in cc
100- c = BitSet (c′)
101- idxs = intersect (c, inferred)
102- isempty (idxs) && continue
103- if ! allequal (iscontinuous (var_domain[i]) for i in idxs)
104- display (fullvars[c′])
105- throw (ClockInferenceException (" Clocks are not consistent in connected component $(fullvars[c′]) " ))
215+
216+ clock_partitions = connectionsets (inference_graph)
217+ for partition in clock_partitions
218+ clockidxs = findall (vert -> Moshi. Data. isa_variant (vert, ClockVertex. Clock), partition)
219+ if isempty (clockidxs)
220+ vidxs = Int[vert.:1
221+ for vert in partition
222+ if Moshi. Data. isa_variant (vert, ClockVertex. Variable)]
223+ throw (ArgumentError ("""
224+ Found clock partion with no associated clock. Involved variables: $(fullvars[vidxs]) .
225+ """ ))
106226 end
107- vd = var_domain[first (idxs)]
108- for v in c′
109- var_domain[v] = vd
227+ if length (clockidxs) > 1
228+ vidxs = Int[vert.:1
229+ for vert in partition
230+ if Moshi. Data. isa_variant (vert, ClockVertex. Variable)]
231+ clks = [vert.:1 for vert in view (partition, clockidxs)]
232+ throw (ArgumentError ("""
233+ Found clock partition with multiple associated clocks. Involved variables: \
234+ $(fullvars[vidxs]) . Involved clocks: $(clks) .
235+ """ ))
110236 end
111- end
112237
113- for v in 𝑑vertices (graph)
114- vd = var_domain[v]
115- eqs = 𝑑neighbors (graph, v)
116- isempty (eqs) && continue
117- for eq in eqs
118- eq_domain[eq] = vd
238+ clock = partition[only (clockidxs)]. :1
239+ for vert in partition
240+ Moshi. Match. @match vert begin
241+ ClockVertex. Variable (i) => (var_domain[i] = clock)
242+ ClockVertex. Equation (i) => (eq_domain[i] = clock)
243+ ClockVertex. Clock (_) => nothing
244+ end
119245 end
120246 end
121247
0 commit comments