Skip to content

Commit 4988ee1

Browse files
BenChungAayushSabharwal
authored andcommitted
Early work on the new discrete backend for MTK
1 parent 7b5c5f0 commit 4988ee1

File tree

3 files changed

+48
-13
lines changed

3 files changed

+48
-13
lines changed

src/systems/clock_inference.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ function infer_clocks!(ci::ClockInference)
100100
c = BitSet(c′)
101101
idxs = intersect(c, inferred)
102102
isempty(idxs) && continue
103-
if !allequal(var_domain[i] for i in idxs)
103+
if !allequal(iscontinuous(var_domain[i]) for i in idxs)
104104
display(fullvars[c′])
105105
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
106106
end
@@ -155,6 +155,9 @@ function split_system(ci::ClockInference{S}) where {S}
155155
cid_to_var = Vector{Int}[]
156156
# cid_counter = number of clocks
157157
cid_counter = Ref(0)
158+
159+
# populates clock_to_id and id_to_clock
160+
# checks if there is a continuous_id (for some reason? clock to id does this too)
158161
for (i, d) in enumerate(eq_domain)
159162
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
160163
continuous_id = continuous_id
@@ -174,9 +177,13 @@ function split_system(ci::ClockInference{S}) where {S}
174177
resize_or_push!(cid_to_eq, i, cid)
175178
end
176179
continuous_id = continuous_id[]
180+
# for each clock partition what are the input (indexes/vars)
177181
input_idxs = map(_ -> Int[], 1:cid_counter[])
178182
inputs = map(_ -> Any[], 1:cid_counter[])
183+
# var_domain corresponds to fullvars/all variables in the system
179184
nvv = length(var_domain)
185+
# put variables into the right clock partition
186+
# keep track of inputs to each partition
180187
for i in 1:nvv
181188
d = var_domain[i]
182189
cid = get(clock_to_id, d, 0)
@@ -190,6 +197,7 @@ function split_system(ci::ClockInference{S}) where {S}
190197
resize_or_push!(cid_to_var, i, cid)
191198
end
192199

200+
# breaks the system up into a continous and 0 or more discrete systems
193201
tss = similar(cid_to_eq, S)
194202
for (id, ieqs) in enumerate(cid_to_eq)
195203
ts_i = system_subset(ts, ieqs)
@@ -199,6 +207,7 @@ function split_system(ci::ClockInference{S}) where {S}
199207
end
200208
tss[id] = ts_i
201209
end
210+
# put the continous system at the back
202211
if continuous_id != 0
203212
tss[continuous_id], tss[end] = tss[end], tss[continuous_id]
204213
inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id]

src/systems/systems.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function mtkcompile(
3636
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
3737
newsys′ = __mtkcompile(sys; simplify,
3838
allow_symbolic, allow_parameter, conservative, fully_determined,
39-
inputs, outputs, disturbance_inputs,
39+
inputs, outputs, disturbance_inputs, additional_passes,
4040
kwargs...)
4141
if newsys′ isa Tuple
4242
@assert length(newsys′) == 2
@@ -292,3 +292,8 @@ function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivative
292292

293293
return mapping
294294
end
295+
296+
"""
297+
Mark whether an extra pass `p` can support compiling discrete systems.
298+
"""
299+
discrete_compile_pass(p) = false

src/systems/systemstructure.jl

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -820,19 +820,40 @@ function mtkcompile!(state::TearingState; simplify = false,
820820
time_domains = merge(Dict(state.fullvars .=> ci.var_domain),
821821
Dict(default_toterm.(state.fullvars) .=> ci.var_domain))
822822
tss, clocked_inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
823+
if continuous_id == 0
824+
# do a trait check here - handle fully discrete system
825+
additional_passes = get(kwargs, :additional_passes, nothing)
826+
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
827+
# take the first discrete compilation pass given for now
828+
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
829+
discrete_compile = additional_passes[discrete_pass_idx]
830+
deleteat!(additional_passes, discrete_pass_idx)
831+
return discrete_compile(tss, clocked_inputs)
832+
end
833+
throw(HybridSystemNotSupportedException("""
834+
Discrete systems with multiple clocks are not supported with the standard \
835+
MTK compiler.
836+
"""))
837+
end
823838
if length(tss) > 1
824-
if continuous_id == 0
825-
throw(HybridSystemNotSupportedException("""
826-
Discrete systems with multiple clocks are not supported with the standard \
827-
MTK compiler.
828-
"""))
829-
else
830-
throw(HybridSystemNotSupportedException("""
831-
Hybrid continuous-discrete systems are currently not supported with \
832-
the standard MTK compiler. This system requires JuliaSimCompiler.jl, \
833-
see https://help.juliahub.com/juliasimcompiler/stable/
834-
"""))
839+
# simplify as normal
840+
sys = _mtkcompile!(tss[continuous_id]; simplify,
841+
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
842+
check_consistency, fully_determined,
843+
kwargs...)
844+
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
845+
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
846+
discrete_compile = additional_passes[discrete_pass_idx]
847+
deleteat!(additional_passes, discrete_pass_idx)
848+
# in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems
849+
# and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed
850+
return discrete_compile(sys, tss[2:end], inputs)
835851
end
852+
throw(HybridSystemNotSupportedException("""
853+
Hybrid continuous-discrete systems are currently not supported with \
854+
the standard MTK compiler. This system requires JuliaSimCompiler.jl, \
855+
see https://help.juliahub.com/juliasimcompiler/stable/
856+
"""))
836857
end
837858
if get_is_discrete(state.sys) ||
838859
continuous_id == 1 && any(Base.Fix2(isoperator, Shift), state.fullvars)

0 commit comments

Comments
 (0)