Skip to content

Commit 32c2ce6

Browse files
Merge pull request #11 from JuliaComputing/as/precompilation
refactor: improve precompile-friendliness
2 parents ca73937 + 448e0f3 commit 32c2ce6

File tree

13 files changed

+273
-182
lines changed

13 files changed

+273
-182
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1313
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
15-
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1615

1716
[weakdeps]
1817
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
@@ -29,7 +28,6 @@ LinearAlgebra = "1.11.0"
2928
OrderedCollections = "1"
3029
Setfield = "1.1.1"
3130
SparseArrays = "1.11.0"
32-
UnPack = "1.0.2"
3331
julia = "1.9"
3432

3533
[extras]

lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl

Lines changed: 133 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,136 @@ end
5858

5959
struct NotInferredTimeDomain end
6060

61+
struct InferEquationClosure
62+
varsbuf::Set{SymbolicT}
63+
# variables in each argument to an operator
64+
arg_varsbuf::Set{SymbolicT}
65+
# hyperedge for each equation
66+
hyperedge::Set{ClockVertex.Type}
67+
# hyperedge for each argument to an operator
68+
arg_hyperedge::Set{ClockVertex.Type}
69+
# mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
70+
relative_hyperedges::Dict{Int, Set{ClockVertex.Type}}
71+
var_to_idx::Dict{SymbolicT, Int}
72+
inference_graph::HyperGraph{ClockVertex.Type}
73+
end
74+
75+
function InferEquationClosure(var_to_idx, inference_graph)
76+
InferEquationClosure(Set{SymbolicT}(), Set{SymbolicT}(), Set{ClockVertex.Type}(),
77+
Set{ClockVertex.Type}(), Dict{Int, Set{ClockVertex.Type}}(),
78+
var_to_idx, inference_graph)
79+
end
80+
81+
function (iec::InferEquationClosure)(ieq::Int, eq::Equation, is_initialization_equation::Bool)
82+
(; varsbuf, arg_varsbuf, hyperedge, arg_hyperedge, relative_hyperedges) = iec
83+
(; var_to_idx, inference_graph) = iec
84+
empty!(varsbuf)
85+
empty!(hyperedge)
86+
# get variables in equation
87+
SU.search_variables!(varsbuf, eq; is_atomic = MTKBase.OperatorIsAtomic{SU.Operator}())
88+
# add the equation to the hyperedge
89+
eq_node = if is_initialization_equation
90+
ClockVertex.InitEquation(ieq)
91+
else
92+
ClockVertex.Equation(ieq)
93+
end
94+
push!(hyperedge, eq_node)
95+
for var in varsbuf
96+
idx = get(var_to_idx, var, nothing)
97+
# if this is just a single variable, add it to the hyperedge
98+
if idx isa Int
99+
push!(hyperedge, ClockVertex.Variable(idx))
100+
# we don't immediately `continue` here because this variable might be a
101+
# `Sample` or similar and we want the clock information from it if it is.
102+
end
103+
# now we only care about synchronous operators
104+
op, args = @match var begin
105+
BSImpl.Term(; f, args) && if is_timevarying_operator(f)::Bool end => (f, args)
106+
_ => continue
107+
end
108+
109+
# arguments and corresponding time domains
110+
tdomains = input_timedomain(op)::Vector{InputTimeDomainElT}
111+
nargs = length(args)
112+
ndoms = length(tdomains)
113+
if nargs != ndoms
114+
throw(ArgumentError("""
115+
Operator $op applied to $nargs arguments $args but only returns $ndoms \
116+
domains $tdomains from `input_timedomain`.
117+
"""))
118+
end
119+
120+
# each relative clock mapping is only valid per operator application
121+
empty!(relative_hyperedges)
122+
for (arg, domain) in zip(args, tdomains)
123+
empty!(arg_varsbuf)
124+
empty!(arg_hyperedge)
125+
# get variables in argument
126+
SU.search_variables!(arg_varsbuf, arg; is_atomic = MTKBase.OperatorIsAtomic{Union{Differential, MTKBase.Shift}}())
127+
# get hyperedge for involved variables
128+
for v in arg_varsbuf
129+
vidx = get(var_to_idx, v, nothing)
130+
vidx === nothing && continue
131+
push!(arg_hyperedge, ClockVertex.Variable(vidx))
132+
end
133+
134+
@match domain begin
135+
# If the time domain for this argument is a clock, then all variables in this edge have that clock.
136+
x::SciMLBase.AbstractClock => begin
137+
# add the clock to the edge
138+
push!(arg_hyperedge, ClockVertex.Clock(x))
139+
# add the edge to the graph
140+
add_edge!(inference_graph, arg_hyperedge)
141+
end
142+
# We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
143+
# involved variables have the same clock.
144+
InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge)
145+
# All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
146+
# add the edge, and instead add this to the `relative_hyperedges` mapping.
147+
InferredClock.InferredDiscrete(i) => begin
148+
relative_edge = get!(Set{ClockVertex.Type}, relative_hyperedges, i)
149+
union!(relative_edge, arg_hyperedge)
150+
end
151+
end
152+
end
153+
154+
outdomain = output_timedomain(op)
155+
@match outdomain begin
156+
x::SciMLBase.AbstractClock => begin
157+
push!(hyperedge, ClockVertex.Clock(x))
158+
end
159+
InferredClock.Inferred() => nothing
160+
InferredClock.InferredDiscrete(i) => begin
161+
buffer = get(relative_hyperedges, i, nothing)
162+
if buffer !== nothing
163+
union!(hyperedge, buffer)
164+
delete!(relative_hyperedges, i)
165+
end
166+
end
167+
end
168+
169+
for (_, relative_edge) in relative_hyperedges
170+
add_edge!(inference_graph, relative_edge)
171+
end
172+
end
173+
174+
add_edge!(inference_graph, hyperedge)
175+
end
176+
61177
"""
62178
Update the equation-to-time domain mapping by inferring the time domain from the variables.
63179
"""
64180
function infer_clocks!(ci::ClockInference)
65181
(; ts, eq_domain, init_eq_domain, var_domain, inferred, inference_graph) = ci
66-
(; var_to_diff, graph) = ts.structure
67182
sys = get_sys(ts)
68183
fullvars = StateSelection.get_fullvars(ts)
69184
isempty(inferred) && return ci
70185

71-
var_to_idx = Dict{SymbolicT, Int}(fullvars .=> eachindex(fullvars))
186+
var_to_idx = Dict{SymbolicT, Int}()
187+
sizehint!(var_to_idx, length(fullvars))
188+
for (i, v) in enumerate(fullvars)
189+
var_to_idx[v] = i
190+
end
72191

73192
# all shifted variables have the same clock as the unshifted variant
74193
for (i, v) in enumerate(fullvars)
@@ -81,112 +200,8 @@ function infer_clocks!(ci::ClockInference)
81200
_ => nothing
82201
end
83202
end
203+
infer_equation = InferEquationClosure(var_to_idx, inference_graph)
84204

85-
# preallocated buffers:
86-
# variables in each equation
87-
varsbuf = Set{SymbolicT}()
88-
# variables in each argument to an operator
89-
arg_varsbuf = Set{SymbolicT}()
90-
# hyperedge for each equation
91-
hyperedge = Set{ClockVertex.Type}()
92-
# hyperedge for each argument to an operator
93-
arg_hyperedge = Set{ClockVertex.Type}()
94-
# mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
95-
relative_hyperedges = Dict{Int, Set{ClockVertex.Type}}()
96-
97-
function infer_equation(ieq, eq, is_initialization_equation)
98-
empty!(varsbuf)
99-
empty!(hyperedge)
100-
# get variables in equation
101-
SU.search_variables!(varsbuf, eq; is_atomic = MTKBase.OperatorIsAtomic{SU.Operator}())
102-
# add the equation to the hyperedge
103-
eq_node = if is_initialization_equation
104-
ClockVertex.InitEquation(ieq)
105-
else
106-
ClockVertex.Equation(ieq)
107-
end
108-
push!(hyperedge, eq_node)
109-
for var in varsbuf
110-
idx = get(var_to_idx, var, nothing)
111-
# if this is just a single variable, add it to the hyperedge
112-
if idx isa Int
113-
push!(hyperedge, ClockVertex.Variable(idx))
114-
# we don't immediately `continue` here because this variable might be a
115-
# `Sample` or similar and we want the clock information from it if it is.
116-
end
117-
# now we only care about synchronous operators
118-
op, args = @match var begin
119-
BSImpl.Term(; f, args) && if is_timevarying_operator(f)::Bool end => (f, args)
120-
_ => continue
121-
end
122-
123-
# arguments and corresponding time domains
124-
tdomains = input_timedomain(op)::Vector{InputTimeDomainElT}
125-
nargs = length(args)
126-
ndoms = length(tdomains)
127-
if nargs != ndoms
128-
throw(ArgumentError("""
129-
Operator $op applied to $nargs arguments $args but only returns $ndoms \
130-
domains $tdomains from `input_timedomain`.
131-
"""))
132-
end
133-
134-
# each relative clock mapping is only valid per operator application
135-
empty!(relative_hyperedges)
136-
for (arg, domain) in zip(args, tdomains)
137-
empty!(arg_varsbuf)
138-
empty!(arg_hyperedge)
139-
# get variables in argument
140-
SU.search_variables!(arg_varsbuf, arg; is_atomic = MTKBase.OperatorIsAtomic{Union{Differential, MTKBase.Shift}}())
141-
# get hyperedge for involved variables
142-
for v in arg_varsbuf
143-
vidx = get(var_to_idx, v, nothing)
144-
vidx === nothing && continue
145-
push!(arg_hyperedge, ClockVertex.Variable(vidx))
146-
end
147-
148-
@match domain begin
149-
# If the time domain for this argument is a clock, then all variables in this edge have that clock.
150-
x::SciMLBase.AbstractClock => begin
151-
# add the clock to the edge
152-
push!(arg_hyperedge, ClockVertex.Clock(x))
153-
# add the edge to the graph
154-
add_edge!(inference_graph, arg_hyperedge)
155-
end
156-
# We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
157-
# involved variables have the same clock.
158-
InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge)
159-
# All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
160-
# add the edge, and instead add this to the `relative_hyperedges` mapping.
161-
InferredClock.InferredDiscrete(i) => begin
162-
relative_edge = get!(Set{ClockVertex.Type}, relative_hyperedges, i)
163-
union!(relative_edge, arg_hyperedge)
164-
end
165-
end
166-
end
167-
168-
outdomain = output_timedomain(op)
169-
@match outdomain begin
170-
x::SciMLBase.AbstractClock => begin
171-
push!(hyperedge, ClockVertex.Clock(x))
172-
end
173-
InferredClock.Inferred() => nothing
174-
InferredClock.InferredDiscrete(i) => begin
175-
buffer = get(relative_hyperedges, i, nothing)
176-
if buffer !== nothing
177-
union!(hyperedge, buffer)
178-
delete!(relative_hyperedges, i)
179-
end
180-
end
181-
end
182-
183-
for (_, relative_edge) in relative_hyperedges
184-
add_edge!(inference_graph, relative_edge)
185-
end
186-
end
187-
188-
add_edge!(inference_graph, hyperedge)
189-
end
190205
for (ieq, eq) in enumerate(MTKBase.equations(sys))
191206
infer_equation(ieq, eq, false)
192207
end
@@ -212,7 +227,9 @@ function infer_clocks!(ci::ClockInference)
212227
"""))
213228
end
214229

215-
clock = partition[only(clockidxs)].:1
230+
clock = Moshi.Match.@match partition[only(clockidxs)] begin
231+
ClockVertex.Clock(clk) => clk
232+
end
216233
for vert in partition
217234
Moshi.Match.@match vert begin
218235
ClockVertex.Variable(i) => (var_domain[i] = clock)
@@ -275,19 +292,15 @@ function split_system(ci::ClockInference{S}) where {S}
275292
# populates clock_to_id and id_to_clock
276293
# checks if there is a continuous_id (for some reason? clock to id does this too)
277294
for (i, d) in enumerate(eq_domain)
278-
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
279-
continuous_id = continuous_id
280-
281-
# Fill the clock_to_id dict as you go,
282-
# ContinuousClock() => 1, ...
283-
get!(clock_to_id, d) do
284-
cid = (cid_counter[] += 1)
285-
push!(id_to_clock, d)
286-
if d == SciMLBase.ContinuousClock()
287-
continuous_id[] = cid
288-
end
289-
cid
295+
# We don't use `get!` here because that desperately wants to box things
296+
cid = get(clock_to_id, d, 0)
297+
if iszero(cid)
298+
cid = (cid_counter[] += 1)
299+
push!(id_to_clock, d)
300+
if d === SciMLBase.ContinuousClock()
301+
continuous_id[] = cid
290302
end
303+
clock_to_id[d] = cid
291304
end
292305
eq_to_cid[i] = cid
293306
resize_or_push!(cid_to_eq, i, cid)

lib/ModelingToolkitTearing/src/reassemble.jl

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,7 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState,
10761076
dummy_sub = Dict{SymbolicT, SymbolicT}()
10771077

10781078
if MTKBase.has_iv(state.sys) && MTKBase.get_iv(state.sys) !== nothing
1079-
iv = MTKBase.get_iv(state.sys)
1079+
iv = MTKBase.get_iv(state.sys)::SymbolicT
10801080
if !StateSelection.is_only_discrete(state.structure)
10811081
D = Differential(iv)
10821082
else
@@ -1089,29 +1089,50 @@ function (alg::DefaultReassembleAlgorithm)(state::TearingState,
10891089
extra_unknowns = state.fullvars[extra_eqs_vars[2]]
10901090
if StateSelection.is_only_discrete(state.structure)
10911091
var_sccs = add_additional_history!(
1092-
state, var_eq_matching, full_var_eq_matching, var_sccs, iv)
1092+
state, var_eq_matching, full_var_eq_matching, var_sccs, iv::SymbolicT)
10931093
end
10941094

10951095
# Structural simplification
1096-
substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub, iv, D)
1097-
1098-
var_sccs = generate_derivative_variables!(
1099-
state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs, mm, iv)
1100-
neweqs, solved_eqs,
1101-
eq_ordering,
1102-
var_ordering,
1103-
nelim_eq,
1104-
nelim_var = generate_system_equations!(
1105-
state, neweqs, var_eq_matching, full_var_eq_matching,
1106-
var_sccs, extra_eqs_vars, iv, D; simplify, inline_linear_sccs,
1107-
analytical_linear_scc_limit)
1108-
1109-
state = reorder_vars!(
1110-
state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var)
1111-
# var_eq_matching and full_var_eq_matching are now invalidated
1112-
1113-
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs,
1114-
extra_unknowns, iv, D; array_hack)
1096+
if iv isa SymbolicT # Without iv we don't have derivatives
1097+
D = D::Union{Differential, Shift}
1098+
substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub, iv, D)
1099+
1100+
var_sccs = generate_derivative_variables!(
1101+
state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs, mm, iv)
1102+
end
1103+
if iv isa SymbolicT
1104+
D = D::Union{Differential, Shift}
1105+
neweqs, solved_eqs,
1106+
eq_ordering,
1107+
var_ordering,
1108+
nelim_eq,
1109+
nelim_var = generate_system_equations!(
1110+
state, neweqs, var_eq_matching, full_var_eq_matching,
1111+
var_sccs, extra_eqs_vars, iv, D; simplify, inline_linear_sccs,
1112+
analytical_linear_scc_limit)
1113+
state = reorder_vars!(
1114+
state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var)
1115+
# var_eq_matching and full_var_eq_matching are now invalidated
1116+
1117+
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs,
1118+
extra_unknowns, iv, D; array_hack)
1119+
else
1120+
D = D::Nothing
1121+
neweqs, solved_eqs,
1122+
eq_ordering,
1123+
var_ordering,
1124+
nelim_eq,
1125+
nelim_var = generate_system_equations!(
1126+
state, neweqs, var_eq_matching, full_var_eq_matching,
1127+
var_sccs, extra_eqs_vars, iv, D; simplify, inline_linear_sccs,
1128+
analytical_linear_scc_limit)
1129+
state = reorder_vars!(
1130+
state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var)
1131+
# var_eq_matching and full_var_eq_matching are now invalidated
1132+
1133+
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs,
1134+
extra_unknowns, iv, D; array_hack)
1135+
end
11151136

11161137
@set! state.sys = sys
11171138
@set! sys.tearing_state = state

0 commit comments

Comments
 (0)