11@data ClockVertex begin
22 Variable (Int)
33 Equation (Int)
4+ InitEquation (Int)
45 Clock (SciMLBase. AbstractClock)
56end
67
@@ -9,6 +10,8 @@ struct ClockInference{S}
910 ts:: S
1011 """ The time domain (discrete clock, continuous) of each equation."""
1112 eq_domain:: Vector{TimeDomain}
13+ """ The time domain (discrete clock, continuous) of each initialization equation."""
14+ init_eq_domain:: Vector{TimeDomain}
1215 """ The output time domain (discrete clock, continuous) of each variable."""
1316 var_domain:: Vector{TimeDomain}
1417 inference_graph:: HyperGraph{ClockVertex.Type}
@@ -20,6 +23,8 @@ function ClockInference(ts::TransformationState)
2023 @unpack structure = ts
2124 @unpack graph = structure
2225 eq_domain = TimeDomain[ContinuousClock () for _ in 1 : nsrcs (graph)]
26+ init_eq_domain = TimeDomain[ContinuousClock ()
27+ for _ in 1 : length (initialization_equations (ts. sys))]
2328 var_domain = TimeDomain[ContinuousClock () for _ in 1 : ndsts (graph)]
2429 inferred = BitSet ()
2530 for (i, v) in enumerate (get_fullvars (ts))
@@ -33,6 +38,9 @@ function ClockInference(ts::TransformationState)
3338 for i in 1 : nsrcs (graph)
3439 add_vertex! (inference_graph, ClockVertex. Equation (i))
3540 end
41+ for i in eachindex (initialization_equations (ts. sys))
42+ add_vertex! (inference_graph, ClockVertex. InitEquation (i))
43+ end
3644 for i in 1 : ndsts (graph)
3745 varvert = ClockVertex. Variable (i)
3846 add_vertex! (inference_graph, varvert)
@@ -43,7 +51,7 @@ function ClockInference(ts::TransformationState)
4351 add_vertex! (inference_graph, dvert)
4452 add_edge! (inference_graph, (varvert, dvert))
4553 end
46- ClockInference (ts, eq_domain, var_domain, inference_graph, inferred)
54+ ClockInference (ts, eq_domain, init_eq_domain, var_domain, inference_graph, inferred)
4755end
4856
4957struct NotInferredTimeDomain end
96104Update the equation-to-time domain mapping by inferring the time domain from the variables.
97105"""
98106function infer_clocks! (ci:: ClockInference )
99- @unpack ts, eq_domain, var_domain, inferred, inference_graph = ci
107+ @unpack ts, eq_domain, init_eq_domain, var_domain, inferred, inference_graph = ci
100108 @unpack var_to_diff, graph = ts. structure
101109 fullvars = get_fullvars (ts)
102110 isempty (inferred) && return ci
@@ -124,13 +132,18 @@ function infer_clocks!(ci::ClockInference)
124132 # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
125133 relative_hyperedges = Dict {Int, Set{ClockVertex.Type}} ()
126134
127- for (ieq, eq) in enumerate ( equations (ts) )
135+ function infer_equation (ieq, eq, is_initialization_equation )
128136 empty! (varsbuf)
129137 empty! (hyperedge)
130138 # get variables in equation
131139 vars! (varsbuf, eq; op = Symbolics. Operator)
132140 # add the equation to the hyperedge
133- push! (hyperedge, ClockVertex. Equation (ieq))
141+ eq_node = if is_initialization_equation
142+ ClockVertex. InitEquation (ieq)
143+ else
144+ ClockVertex. Equation (ieq)
145+ end
146+ push! (hyperedge, eq_node)
134147 for var in varsbuf
135148 idx = get (var_to_idx, var, nothing )
136149 # if this is just a single variable, add it to the hyperedge
@@ -215,6 +228,12 @@ function infer_clocks!(ci::ClockInference)
215228
216229 add_edge! (inference_graph, hyperedge)
217230 end
231+ for (ieq, eq) in enumerate (equations (ts))
232+ infer_equation (ieq, eq, false )
233+ end
234+ for (ieq, eq) in enumerate (initialization_equations (ts. sys))
235+ infer_equation (ieq, eq, true )
236+ end
218237
219238 clock_partitions = connectionsets (inference_graph)
220239 for partition in clock_partitions
@@ -243,6 +262,7 @@ function infer_clocks!(ci::ClockInference)
243262 Moshi. Match. @match vert begin
244263 ClockVertex. Variable (i) => (var_domain[i] = clock)
245264 ClockVertex. Equation (i) => (eq_domain[i] = clock)
265+ ClockVertex. InitEquation (i) => (init_eq_domain[i] = clock)
246266 ClockVertex. Clock (_) => nothing
247267 end
248268 end
@@ -278,14 +298,16 @@ end
278298For multi-clock systems, create a separate system for each clock in the system, along with associated equations. Return the updated tearing state, and the sets of clocked variables associated with each time domain.
279299"""
280300function split_system (ci:: ClockInference{S} ) where {S}
281- @unpack ts, eq_domain, var_domain, inferred = ci
301+ @unpack ts, eq_domain, init_eq_domain, var_domain, inferred = ci
282302 fullvars = get_fullvars (ts)
283303 @unpack graph = ts. structure
284304 continuous_id = Ref (0 )
285305 clock_to_id = Dict {TimeDomain, Int} ()
286306 id_to_clock = TimeDomain[]
287307 eq_to_cid = Vector {Int} (undef, nsrcs (graph))
288308 cid_to_eq = Vector{Int}[]
309+ init_eq_to_cid = Vector {Int} (undef, length (initialization_equations (ts. sys)))
310+ cid_to_init_eq = Vector{Int}[]
289311 var_to_cid = Vector {Int} (undef, ndsts (graph))
290312 cid_to_var = Vector{Int}[]
291313 # cid_counter = number of clocks
@@ -311,6 +333,15 @@ function split_system(ci::ClockInference{S}) where {S}
311333 eq_to_cid[i] = cid
312334 resize_or_push! (cid_to_eq, i, cid)
313335 end
336+ # NOTE: This assumes that there is at least one equation for each clock
337+ for _ in 1 : length (cid_to_eq)
338+ push! (cid_to_init_eq, Int[])
339+ end
340+ for (i, d) in enumerate (init_eq_domain)
341+ cid = clock_to_id[d]
342+ init_eq_to_cid[i] = cid
343+ push! (cid_to_init_eq[cid], i)
344+ end
314345 continuous_id = continuous_id[]
315346 # for each clock partition what are the input (indexes/vars)
316347 input_idxs = map (_ -> Int[], 1 : cid_counter[])
@@ -334,8 +365,8 @@ function split_system(ci::ClockInference{S}) where {S}
334365
335366 # breaks the system up into a continous and 0 or more discrete systems
336367 tss = similar (cid_to_eq, S)
337- for (id, (ieqs, ivars)) in enumerate (zip (cid_to_eq, cid_to_var))
338- ts_i = system_subset (ts, ieqs, ivars)
368+ for (id, (ieqs, iieqs, ivars)) in enumerate (zip (cid_to_eq, cid_to_init_eq , cid_to_var))
369+ ts_i = system_subset (ts, ieqs, iieqs, ivars)
339370 if id != continuous_id
340371 ts_i = shift_discrete_system (ts_i)
341372 @set! ts_i. structure. only_discrete = true
0 commit comments