11@data  ClockVertex begin 
22    Variable (Int)
33    Equation (Int)
4+     InitEquation (Int)
45    Clock (SciMLBase. AbstractClock)
56end 
67
@@ -9,6 +10,8 @@ struct ClockInference{S}
910    ts:: S 
1011    """ The time domain (discrete clock, continuous) of each equation.""" 
1112    eq_domain:: Vector{TimeDomain} 
13+     """ The time domain (discrete clock, continuous) of each initialization equation.""" 
14+     init_eq_domain:: Vector{TimeDomain} 
1215    """ The output time domain (discrete clock, continuous) of each variable.""" 
1316    var_domain:: Vector{TimeDomain} 
1417    inference_graph:: HyperGraph{ClockVertex.Type} 
@@ -20,6 +23,8 @@ function ClockInference(ts::TransformationState)
2023    @unpack  structure =  ts
2124    @unpack  graph =  structure
2225    eq_domain =  TimeDomain[ContinuousClock () for  _ in  1 : nsrcs (graph)]
26+     init_eq_domain =  TimeDomain[ContinuousClock ()
27+                                 for  _ in  1 : length (initialization_equations (ts. sys))]
2328    var_domain =  TimeDomain[ContinuousClock () for  _ in  1 : ndsts (graph)]
2429    inferred =  BitSet ()
2530    for  (i, v) in  enumerate (get_fullvars (ts))
@@ -33,6 +38,9 @@ function ClockInference(ts::TransformationState)
3338    for  i in  1 : nsrcs (graph)
3439        add_vertex! (inference_graph, ClockVertex. Equation (i))
3540    end 
41+     for  i in  eachindex (initialization_equations (ts. sys))
42+         add_vertex! (inference_graph, ClockVertex. InitEquation (i))
43+     end 
3644    for  i in  1 : ndsts (graph)
3745        varvert =  ClockVertex. Variable (i)
3846        add_vertex! (inference_graph, varvert)
@@ -43,7 +51,7 @@ function ClockInference(ts::TransformationState)
4351        add_vertex! (inference_graph, dvert)
4452        add_edge! (inference_graph, (varvert, dvert))
4553    end 
46-     ClockInference (ts, eq_domain, var_domain, inference_graph, inferred)
54+     ClockInference (ts, eq_domain, init_eq_domain,  var_domain, inference_graph, inferred)
4755end 
4856
4957struct  NotInferredTimeDomain end 
96104Update the equation-to-time domain mapping by inferring the time domain from the variables. 
97105""" 
98106function  infer_clocks! (ci:: ClockInference )
99-     @unpack  ts, eq_domain, var_domain, inferred, inference_graph =  ci
107+     @unpack  ts, eq_domain, init_eq_domain,  var_domain, inferred, inference_graph =  ci
100108    @unpack  var_to_diff, graph =  ts. structure
101109    fullvars =  get_fullvars (ts)
102110    isempty (inferred) &&  return  ci
@@ -124,13 +132,18 @@ function infer_clocks!(ci::ClockInference)
124132    #  mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
125133    relative_hyperedges =  Dict {Int, Set{ClockVertex.Type}} ()
126134
127-     for   (ieq, eq)  in   enumerate ( equations (ts) )
135+     function   infer_equation (ieq, eq, is_initialization_equation )
128136        empty! (varsbuf)
129137        empty! (hyperedge)
130138        #  get variables in equation
131139        vars! (varsbuf, eq; op =  Symbolics. Operator)
132140        #  add the equation to the hyperedge
133-         push! (hyperedge, ClockVertex. Equation (ieq))
141+         eq_node =  if  is_initialization_equation
142+             ClockVertex. InitEquation (ieq)
143+         else 
144+             ClockVertex. Equation (ieq)
145+         end 
146+         push! (hyperedge, eq_node)
134147        for  var in  varsbuf
135148            idx =  get (var_to_idx, var, nothing )
136149            #  if this is just a single variable, add it to the hyperedge
@@ -215,6 +228,12 @@ function infer_clocks!(ci::ClockInference)
215228
216229        add_edge! (inference_graph, hyperedge)
217230    end 
231+     for  (ieq, eq) in  enumerate (equations (ts))
232+         infer_equation (ieq, eq, false )
233+     end 
234+     for  (ieq, eq) in  enumerate (initialization_equations (ts. sys))
235+         infer_equation (ieq, eq, true )
236+     end 
218237
219238    clock_partitions =  connectionsets (inference_graph)
220239    for  partition in  clock_partitions
@@ -243,6 +262,7 @@ function infer_clocks!(ci::ClockInference)
243262            Moshi. Match. @match  vert begin 
244263                ClockVertex. Variable (i) =>  (var_domain[i] =  clock)
245264                ClockVertex. Equation (i) =>  (eq_domain[i] =  clock)
265+                 ClockVertex. InitEquation (i) =>  (init_eq_domain[i] =  clock)
246266                ClockVertex. Clock (_) =>  nothing 
247267            end 
248268        end 
@@ -278,14 +298,16 @@ end
278298For multi-clock systems, create a separate system for each clock in the system, along with associated equations. Return the updated tearing state, and the sets of clocked variables associated with each time domain. 
279299""" 
280300function  split_system (ci:: ClockInference{S} ) where  {S}
281-     @unpack  ts, eq_domain, var_domain, inferred =  ci
301+     @unpack  ts, eq_domain, init_eq_domain,  var_domain, inferred =  ci
282302    fullvars =  get_fullvars (ts)
283303    @unpack  graph =  ts. structure
284304    continuous_id =  Ref (0 )
285305    clock_to_id =  Dict {TimeDomain, Int} ()
286306    id_to_clock =  TimeDomain[]
287307    eq_to_cid =  Vector {Int} (undef, nsrcs (graph))
288308    cid_to_eq =  Vector{Int}[]
309+     init_eq_to_cid =  Vector {Int} (undef, length (initialization_equations (ts. sys)))
310+     cid_to_init_eq =  Vector{Int}[]
289311    var_to_cid =  Vector {Int} (undef, ndsts (graph))
290312    cid_to_var =  Vector{Int}[]
291313    #  cid_counter = number of clocks
@@ -311,6 +333,15 @@ function split_system(ci::ClockInference{S}) where {S}
311333        eq_to_cid[i] =  cid
312334        resize_or_push! (cid_to_eq, i, cid)
313335    end 
336+     #  NOTE: This assumes that there is at least one equation for each clock
337+     for  _ in  1 : length (cid_to_eq)
338+         push! (cid_to_init_eq, Int[])
339+     end 
340+     for  (i, d) in  enumerate (init_eq_domain)
341+         cid =  clock_to_id[d]
342+         init_eq_to_cid[i] =  cid
343+         push! (cid_to_init_eq[cid], i)
344+     end 
314345    continuous_id =  continuous_id[]
315346    #  for each clock partition what are the input (indexes/vars)
316347    input_idxs =  map (_ ->  Int[], 1 : cid_counter[])
@@ -334,8 +365,8 @@ function split_system(ci::ClockInference{S}) where {S}
334365
335366    #  breaks the system up into a continous and 0 or more discrete systems
336367    tss =  similar (cid_to_eq, S)
337-     for  (id, (ieqs, ivars)) in  enumerate (zip (cid_to_eq, cid_to_var))
338-         ts_i =  system_subset (ts, ieqs, ivars)
368+     for  (id, (ieqs, iieqs,  ivars)) in  enumerate (zip (cid_to_eq, cid_to_init_eq , cid_to_var))
369+         ts_i =  system_subset (ts, ieqs, iieqs,  ivars)
339370        if  id !=  continuous_id
340371            ts_i =  shift_discrete_system (ts_i)
341372            @set!  ts_i. structure. only_discrete =  true 
0 commit comments