1+ @data  ClockVertex begin 
2+     Variable (Int)
3+     Equation (Int)
4+     Clock (SciMLBase. AbstractClock)
5+ end 
6+ 
17struct  ClockInference{S}
28    """ Tearing state.""" 
39    ts:: S 
410    """ The time domain (discrete clock, continuous) of each equation.""" 
511    eq_domain:: Vector{TimeDomain} 
612    """ The output time domain (discrete clock, continuous) of each variable.""" 
713    var_domain:: Vector{TimeDomain} 
14+     inference_graph:: HyperGraph{ClockVertex.Type} 
815    """ The set of variables with concrete domains.""" 
916    inferred:: BitSet 
1017end 
@@ -22,7 +29,21 @@ function ClockInference(ts::TransformationState)
2229            var_domain[i] =  d
2330        end 
2431    end 
25-     ClockInference (ts, eq_domain, var_domain, inferred)
32+     inference_graph =  HyperGraph {ClockVertex.Type} ()
33+     for  i in  1 : nsrcs (graph)
34+         add_vertex! (inference_graph, ClockVertex. Equation (i))
35+     end 
36+     for  i in  1 : ndsts (graph)
37+         varvert =  ClockVertex. Variable (i)
38+         add_vertex! (inference_graph, varvert)
39+         v =  ts. fullvars[i]
40+         d =  get_time_domain (v)
41+         is_concrete_time_domain (d) ||  continue 
42+         dvert =  ClockVertex. Clock (d)
43+         add_vertex! (inference_graph, dvert)
44+         add_edge! (inference_graph, (varvert, dvert))
45+     end 
46+     ClockInference (ts, eq_domain, var_domain, inference_graph, inferred)
2647end 
2748
2849struct  NotInferredTimeDomain end 
7596Update the equation-to-time domain mapping by inferring the time domain from the variables. 
7697""" 
7798function  infer_clocks! (ci:: ClockInference )
78-     @unpack  ts, eq_domain, var_domain, inferred =  ci
99+     @unpack  ts, eq_domain, var_domain, inferred, inference_graph  =  ci
79100    @unpack  var_to_diff, graph =  ts. structure
80101    fullvars =  get_fullvars (ts)
81102    isempty (inferred) &&  return  ci
82-     #  TODO : add a graph type to do this lazily
83-     var_graph =  SimpleGraph (ndsts (graph))
84-     for  eq in  𝑠vertices (graph)
85-         vvs =  𝑠neighbors (graph, eq)
86-         if  ! isempty (vvs)
87-             fv, vs =  Iterators. peel (vvs)
88-             for  v in  vs
89-                 add_edge! (var_graph, fv, v)
90-             end 
91-         end 
103+ 
104+     var_to_idx =  Dict (fullvars .=>  eachindex (fullvars))
105+ 
106+     #  all shifted variables have the same clock as the unshifted variant
107+     for  (i, v) in  enumerate (fullvars)
108+         iscall (v) ||  continue 
109+         operation (v) isa  Shift ||  continue 
110+         unshifted =  only (arguments (v))
111+         add_edge! (inference_graph, (ClockVertex. Variable (i), ClockVertex. Variable (var_to_idx[unshifted])))
92112    end 
93-     for  v in  vertices (var_to_diff)
94-         if  (v′ =  var_to_diff[v]) != =  nothing 
95-             add_edge! (var_graph, v, v′)
113+ 
114+     #  preallocated buffers:
115+     #  variables in each equation
116+     varsbuf =  Set ()
117+     #  variables in each argument to an operator
118+     arg_varsbuf =  Set ()
119+     #  hyperedge for each equation
120+     hyperedge =  Set {ClockVertex.Type} ()
121+     #  hyperedge for each argument to an operator
122+     arg_hyperedge =  Set {ClockVertex.Type} ()
123+     #  mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
124+     relative_hyperedges =  Dict {Int, Set{ClockVertex.Type}} ()
125+ 
126+     for  (ieq, eq) in  enumerate (equations (ts))
127+         empty! (varsbuf)
128+         empty! (hyperedge)
129+         #  get variables in equation
130+         vars! (varsbuf, eq; op =  Symbolics. Operator)
131+         #  add the equation to the hyperedge
132+         push! (hyperedge, ClockVertex. Equation (ieq))
133+         for  var in  varsbuf
134+             idx =  get (var_to_idx, var, nothing )
135+             #  if this is just a single variable, add it to the hyperedge
136+             if  idx isa  Int
137+                 push! (hyperedge, ClockVertex. Variable (idx))
138+                 #  we don't immediately `continue` here because this variable might be a
139+                 #  `Sample` or similar and we want the clock information from it if it is.
140+             end 
141+             #  now we only care about synchronous operators
142+             iscall (var) ||  continue 
143+             op =  operation (var)
144+             is_synchronous_operator (op) ||  continue 
145+ 
146+             #  arguments and corresponding time domains
147+             args =  arguments (var)
148+             tdomains =  input_timedomain (op)
149+             nargs =  length (args)
150+             ndoms =  length (tdomains)
151+             if  nargs !=  ndoms
152+                 throw (ArgumentError (""" 
153+                 Operator $op  applied to $nargs  arguments $args  but only returns $ndoms  \ 
154+                 domains $tdomains  from `input_timedomain`. 
155+                 """  ))
156+             end 
157+ 
158+             #  each relative clock mapping is only valid per operator application
159+             empty! (relative_hyperedges)
160+             for  (arg, domain) in  zip (args, tdomains)
161+                 empty! (arg_varsbuf)
162+                 empty! (arg_hyperedge)
163+                 #  get variables in argument
164+                 vars! (arg_varsbuf, arg; op =  Union{Differential, Shift})
165+                 #  get hyperedge for involved variables
166+                 for  v in  arg_varsbuf
167+                     vidx =  get (var_to_idx, v, nothing )
168+                     vidx ===  nothing  &&  continue 
169+                     push! (arg_hyperedge, ClockVertex. Variable (vidx))
170+                 end 
171+ 
172+                 Moshi. Match. @match  domain begin 
173+                     #  If the time domain for this argument is a clock, then all variables in this edge have that clock.
174+                     x:: SciMLBase.AbstractClock  =>  begin 
175+                         #  add the clock to the edge
176+                         push! (arg_hyperedge, ClockVertex. Clock (x))
177+                         #  add the edge to the graph
178+                         add_edge! (inference_graph, arg_hyperedge)
179+                     end 
180+                     #  We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
181+                     #  involved variables have the same clock.
182+                     InferredClock. Inferred () =>  add_edge! (inference_graph, arg_hyperedge)
183+                     #  All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
184+                     #  add the edge, and instead add this to the `relative_hyperedges` mapping.
185+                     InferredClock. InferredDiscrete (i) =>  begin 
186+                         relative_edge =  get! (() ->  Set {ClockVertex.Type} (), relative_hyperedges, i)
187+                         union! (relative_edge, arg_hyperedge)
188+                     end 
189+                 end 
190+             end 
191+ 
192+             outdomain =  output_timedomain (op)
193+             Moshi. Match. @match  outdomain begin 
194+                 x:: SciMLBase.AbstractClock  =>  begin 
195+                     push! (hyperedge, ClockVertex. Clock (x))
196+                 end 
197+                 InferredClock. Inferred () =>  nothing 
198+                 InferredClock. InferredDiscrete (i) =>  begin 
199+                     buffer =  get (relative_hyperedges, i, nothing )
200+                     if  buffer != =  nothing 
201+                         union! (hyperedge, buffer)
202+                         delete! (relative_hyperedges, i)
203+                     end 
204+                 end 
205+             end 
206+ 
207+             for  (_, relative_edge) in  relative_hyperedges
208+                 add_edge! (inference_graph, relative_edge)
209+             end 
96210        end 
211+ 
212+         add_edge! (inference_graph, hyperedge)
97213    end 
98-     cc =  connected_components (var_graph)
99-     for  c′ in  cc
100-         c =  BitSet (c′)
101-         idxs =  intersect (c, inferred)
102-         isempty (idxs) &&  continue 
103-         if  ! allequal (iscontinuous (var_domain[i]) for  i in  idxs)
104-             display (fullvars[c′])
105-             throw (ClockInferenceException (" Clocks are not consistent in connected component $(fullvars[c′]) " 
214+ 
215+     clock_partitions =  connectionsets (inference_graph)
216+     for  partition in  clock_partitions
217+         clockidxs =  findall (vert ->  Moshi. Data. isa_variant (vert, ClockVertex. Clock), partition)
218+         if  isempty (clockidxs)
219+             vidxs =  Int[vert.:1  for  vert in  partition if  Moshi. Data. isa_variant (vert, ClockVertex. Variable)]
220+             throw (ArgumentError (""" 
221+             Found clock partion with no associated clock. Involved variables: $(fullvars[vidxs]) . 
222+             """  ))
106223        end 
107-         vd =  var_domain[first (idxs)]
108-         for  v in  c′
109-             var_domain[v] =  vd
224+         if  length (clockidxs) >  1 
225+             vidxs =  Int[vert.:1  for  vert in  partition if  Moshi. Data. isa_variant (vert, ClockVertex. Variable)]
226+             clks =  [vert.:1  for  vert in  view (partition, clockidxs)]
227+             throw (ArgumentError (""" 
228+             Found clock partition with multiple associated clocks. Involved variables: \ 
229+             $(fullvars[vidxs]) . Involved clocks: $(clks) . 
230+             """  ))
110231        end 
111-     end 
112232
113-     for  v in  𝑑vertices (graph)
114-         vd =  var_domain[v]
115-         eqs =  𝑑neighbors (graph, v)
116-         isempty (eqs) &&  continue 
117-         for  eq in  eqs
118-             eq_domain[eq] =  vd
233+         clock =  partition[only (clockidxs)]. :1 
234+         for  vert in  partition
235+             Moshi. Match. @match  vert begin 
236+                 ClockVertex. Variable (i) =>  (var_domain[i] =  clock)
237+                 ClockVertex. Equation (i) =>  (eq_domain[i] =  clock)
238+                 ClockVertex. Clock (_) =>  nothing 
239+             end 
119240        end 
120241    end 
121242
0 commit comments