5858
5959struct NotInferredTimeDomain end
6060
61+ struct InferEquationClosure
62+ varsbuf:: Set{SymbolicT}
63+ # variables in each argument to an operator
64+ arg_varsbuf:: Set{SymbolicT}
65+ # hyperedge for each equation
66+ hyperedge:: Set{ClockVertex.Type}
67+ # hyperedge for each argument to an operator
68+ arg_hyperedge:: Set{ClockVertex.Type}
69+ # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
70+ relative_hyperedges:: Dict{Int, Set{ClockVertex.Type}}
71+ var_to_idx:: Dict{SymbolicT, Int}
72+ inference_graph:: HyperGraph{ClockVertex.Type}
73+ end
74+
75+ function InferEquationClosure(var_to_idx, inference_graph)
76+ InferEquationClosure(Set{SymbolicT}(), Set{SymbolicT}(), Set{ClockVertex. Type}(),
77+ Set{ClockVertex. Type}(), Dict{Int, Set{ClockVertex. Type}}(),
78+ var_to_idx, inference_graph)
79+ end
80+
81+ function (iec:: InferEquationClosure )(ieq:: Int , eq:: Equation , is_initialization_equation:: Bool )
82+ (; varsbuf, arg_varsbuf, hyperedge, arg_hyperedge, relative_hyperedges) = iec
83+ (; var_to_idx, inference_graph) = iec
84+ empty!(varsbuf)
85+ empty!(hyperedge)
86+ # get variables in equation
87+ SU. search_variables!(varsbuf, eq; is_atomic = MTKBase. OperatorIsAtomic{SU. Operator}())
88+ # add the equation to the hyperedge
89+ eq_node = if is_initialization_equation
90+ ClockVertex. InitEquation(ieq)
91+ else
92+ ClockVertex. Equation(ieq)
93+ end
94+ push!(hyperedge, eq_node)
95+ for var in varsbuf
96+ idx = get(var_to_idx, var, nothing )
97+ # if this is just a single variable, add it to the hyperedge
98+ if idx isa Int
99+ push!(hyperedge, ClockVertex. Variable(idx))
100+ # we don't immediately `continue` here because this variable might be a
101+ # `Sample` or similar and we want the clock information from it if it is.
102+ end
103+ # now we only care about synchronous operators
104+ op, args = @match var begin
105+ BSImpl. Term(; f, args) && if is_timevarying_operator(f):: Bool end => (f, args)
106+ _ => continue
107+ end
108+
109+ # arguments and corresponding time domains
110+ tdomains = input_timedomain(op):: Vector{InputTimeDomainElT}
111+ nargs = length(args)
112+ ndoms = length(tdomains)
113+ if nargs != ndoms
114+ throw(ArgumentError("""
115+ Operator $op applied to $nargs arguments $args but only returns $ndoms \
116+ domains $tdomains from `input_timedomain`.
117+ """ ))
118+ end
119+
120+ # each relative clock mapping is only valid per operator application
121+ empty!(relative_hyperedges)
122+ for (arg, domain) in zip(args, tdomains)
123+ empty!(arg_varsbuf)
124+ empty!(arg_hyperedge)
125+ # get variables in argument
126+ SU. search_variables!(arg_varsbuf, arg; is_atomic = MTKBase. OperatorIsAtomic{Union{Differential, MTKBase. Shift}}())
127+ # get hyperedge for involved variables
128+ for v in arg_varsbuf
129+ vidx = get(var_to_idx, v, nothing )
130+ vidx === nothing && continue
131+ push!(arg_hyperedge, ClockVertex. Variable(vidx))
132+ end
133+
134+ @match domain begin
135+ # If the time domain for this argument is a clock, then all variables in this edge have that clock.
136+ x:: SciMLBase.AbstractClock => begin
137+ # add the clock to the edge
138+ push!(arg_hyperedge, ClockVertex. Clock(x))
139+ # add the edge to the graph
140+ add_edge!(inference_graph, arg_hyperedge)
141+ end
142+ # We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
143+ # involved variables have the same clock.
144+ InferredClock. Inferred() => add_edge!(inference_graph, arg_hyperedge)
145+ # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
146+ # add the edge, and instead add this to the `relative_hyperedges` mapping.
147+ InferredClock. InferredDiscrete(i) => begin
148+ relative_edge = get!(Set{ClockVertex. Type}, relative_hyperedges, i)
149+ union!(relative_edge, arg_hyperedge)
150+ end
151+ end
152+ end
153+
154+ outdomain = output_timedomain(op)
155+ @match outdomain begin
156+ x:: SciMLBase.AbstractClock => begin
157+ push!(hyperedge, ClockVertex. Clock(x))
158+ end
159+ InferredClock. Inferred() => nothing
160+ InferredClock. InferredDiscrete(i) => begin
161+ buffer = get(relative_hyperedges, i, nothing )
162+ if buffer != = nothing
163+ union!(hyperedge, buffer)
164+ delete!(relative_hyperedges, i)
165+ end
166+ end
167+ end
168+
169+ for (_, relative_edge) in relative_hyperedges
170+ add_edge!(inference_graph, relative_edge)
171+ end
172+ end
173+
174+ add_edge!(inference_graph, hyperedge)
175+ end
176+
61177"""
62178Update the equation-to-time domain mapping by inferring the time domain from the variables.
63179"""
64180function infer_clocks!(ci:: ClockInference )
65181 (; ts, eq_domain, init_eq_domain, var_domain, inferred, inference_graph) = ci
66- (; var_to_diff, graph) = ts. structure
67182 sys = get_sys(ts)
68183 fullvars = StateSelection. get_fullvars(ts)
69184 isempty(inferred) && return ci
70185
71- var_to_idx = Dict{SymbolicT, Int}(fullvars .=> eachindex(fullvars))
186+ var_to_idx = Dict{SymbolicT, Int}()
187+ sizehint!(var_to_idx, length(fullvars))
188+ for (i, v) in enumerate(fullvars)
189+ var_to_idx[v] = i
190+ end
72191
73192 # all shifted variables have the same clock as the unshifted variant
74193 for (i, v) in enumerate(fullvars)
@@ -81,112 +200,8 @@ function infer_clocks!(ci::ClockInference)
81200 _ => nothing
82201 end
83202 end
203+ infer_equation = InferEquationClosure(var_to_idx, inference_graph)
84204
85- # preallocated buffers:
86- # variables in each equation
87- varsbuf = Set{SymbolicT}()
88- # variables in each argument to an operator
89- arg_varsbuf = Set{SymbolicT}()
90- # hyperedge for each equation
91- hyperedge = Set{ClockVertex. Type}()
92- # hyperedge for each argument to an operator
93- arg_hyperedge = Set{ClockVertex. Type}()
94- # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
95- relative_hyperedges = Dict{Int, Set{ClockVertex. Type}}()
96-
97- function infer_equation(ieq, eq, is_initialization_equation)
98- empty!(varsbuf)
99- empty!(hyperedge)
100- # get variables in equation
101- SU. search_variables!(varsbuf, eq; is_atomic = MTKBase. OperatorIsAtomic{SU. Operator}())
102- # add the equation to the hyperedge
103- eq_node = if is_initialization_equation
104- ClockVertex. InitEquation(ieq)
105- else
106- ClockVertex. Equation(ieq)
107- end
108- push!(hyperedge, eq_node)
109- for var in varsbuf
110- idx = get(var_to_idx, var, nothing )
111- # if this is just a single variable, add it to the hyperedge
112- if idx isa Int
113- push!(hyperedge, ClockVertex. Variable(idx))
114- # we don't immediately `continue` here because this variable might be a
115- # `Sample` or similar and we want the clock information from it if it is.
116- end
117- # now we only care about synchronous operators
118- op, args = @match var begin
119- BSImpl. Term(; f, args) && if is_timevarying_operator(f):: Bool end => (f, args)
120- _ => continue
121- end
122-
123- # arguments and corresponding time domains
124- tdomains = input_timedomain(op):: Vector{InputTimeDomainElT}
125- nargs = length(args)
126- ndoms = length(tdomains)
127- if nargs != ndoms
128- throw(ArgumentError("""
129- Operator $op applied to $nargs arguments $args but only returns $ndoms \
130- domains $tdomains from `input_timedomain`.
131- """ ))
132- end
133-
134- # each relative clock mapping is only valid per operator application
135- empty!(relative_hyperedges)
136- for (arg, domain) in zip(args, tdomains)
137- empty!(arg_varsbuf)
138- empty!(arg_hyperedge)
139- # get variables in argument
140- SU. search_variables!(arg_varsbuf, arg; is_atomic = MTKBase. OperatorIsAtomic{Union{Differential, MTKBase. Shift}}())
141- # get hyperedge for involved variables
142- for v in arg_varsbuf
143- vidx = get(var_to_idx, v, nothing )
144- vidx === nothing && continue
145- push!(arg_hyperedge, ClockVertex. Variable(vidx))
146- end
147-
148- @match domain begin
149- # If the time domain for this argument is a clock, then all variables in this edge have that clock.
150- x:: SciMLBase.AbstractClock => begin
151- # add the clock to the edge
152- push!(arg_hyperedge, ClockVertex. Clock(x))
153- # add the edge to the graph
154- add_edge!(inference_graph, arg_hyperedge)
155- end
156- # We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
157- # involved variables have the same clock.
158- InferredClock. Inferred() => add_edge!(inference_graph, arg_hyperedge)
159- # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
160- # add the edge, and instead add this to the `relative_hyperedges` mapping.
161- InferredClock. InferredDiscrete(i) => begin
162- relative_edge = get!(Set{ClockVertex. Type}, relative_hyperedges, i)
163- union!(relative_edge, arg_hyperedge)
164- end
165- end
166- end
167-
168- outdomain = output_timedomain(op)
169- @match outdomain begin
170- x:: SciMLBase.AbstractClock => begin
171- push!(hyperedge, ClockVertex. Clock(x))
172- end
173- InferredClock. Inferred() => nothing
174- InferredClock. InferredDiscrete(i) => begin
175- buffer = get(relative_hyperedges, i, nothing )
176- if buffer != = nothing
177- union!(hyperedge, buffer)
178- delete!(relative_hyperedges, i)
179- end
180- end
181- end
182-
183- for (_, relative_edge) in relative_hyperedges
184- add_edge!(inference_graph, relative_edge)
185- end
186- end
187-
188- add_edge!(inference_graph, hyperedge)
189- end
190205 for (ieq, eq) in enumerate(MTKBase. equations(sys))
191206 infer_equation(ieq, eq, false )
192207 end
@@ -212,7 +227,9 @@ function infer_clocks!(ci::ClockInference)
212227 """ ))
213228 end
214229
215- clock = partition[only(clockidxs)]. :1
230+ clock = Moshi. Match. @match partition[only(clockidxs)] begin
231+ ClockVertex. Clock(clk) => clk
232+ end
216233 for vert in partition
217234 Moshi. Match. @match vert begin
218235 ClockVertex. Variable(i) => (var_domain[i] = clock)
@@ -275,19 +292,15 @@ function split_system(ci::ClockInference{S}) where {S}
275292 # populates clock_to_id and id_to_clock
276293 # checks if there is a continuous_id (for some reason? clock to id does this too)
277294 for (i, d) in enumerate(eq_domain)
278- cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
279- continuous_id = continuous_id
280-
281- # Fill the clock_to_id dict as you go,
282- # ContinuousClock() => 1, ...
283- get!(clock_to_id, d) do
284- cid = (cid_counter[] += 1 )
285- push!(id_to_clock, d)
286- if d == SciMLBase. ContinuousClock()
287- continuous_id[] = cid
288- end
289- cid
295+ # We don't use `get!` here because that desperately wants to box things
296+ cid = get(clock_to_id, d, 0 )
297+ if iszero(cid)
298+ cid = (cid_counter[] += 1 )
299+ push!(id_to_clock, d)
300+ if d === SciMLBase. ContinuousClock()
301+ continuous_id[] = cid
290302 end
303+ clock_to_id[d] = cid
291304 end
292305 eq_to_cid[i] = cid
293306 resize_or_push!(cid_to_eq, i, cid)
0 commit comments