Skip to content

Commit 4f7810c

Browse files
committed
Early work on the new discrete backend for MTK
1 parent a2db412 commit 4f7810c

File tree

3 files changed

+44
-24
lines changed

3 files changed

+44
-24
lines changed

src/systems/clock_inference.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function infer_clocks!(ci::ClockInference)
9393
c = BitSet(c′)
9494
idxs = intersect(c, inferred)
9595
isempty(idxs) && continue
96-
if !allequal(var_domain[i] for i in idxs)
96+
if !allequal(iscontinuous(var_domain[i]) for i in idxs)
9797
display(fullvars[c′])
9898
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
9999
end
@@ -144,6 +144,9 @@ function split_system(ci::ClockInference{S}) where {S}
144144
var_to_cid = Vector{Int}(undef, ndsts(graph))
145145
cid_to_var = Vector{Int}[]
146146
cid_counter = Ref(0)
147+
148+
# populates clock_to_id and id_to_clock
149+
# checks if there is a continuous_id (for some reason? clock to id does this too)
147150
for (i, d) in enumerate(eq_domain)
148151
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
149152
continuous_id = continuous_id
@@ -161,9 +164,13 @@ function split_system(ci::ClockInference{S}) where {S}
161164
resize_or_push!(cid_to_eq, i, cid)
162165
end
163166
continuous_id = continuous_id[]
167+
# for each clock partition what are the input (indexes/vars)
164168
input_idxs = map(_ -> Int[], 1:cid_counter[])
165169
inputs = map(_ -> Any[], 1:cid_counter[])
170+
# var_domain corresponds to fullvars/all variables in the system
166171
nvv = length(var_domain)
172+
# put variables into the right clock partition
173+
# keep track of inputs to each partition
167174
for i in 1:nvv
168175
d = var_domain[i]
169176
cid = get(clock_to_id, d, 0)
@@ -177,6 +184,7 @@ function split_system(ci::ClockInference{S}) where {S}
177184
resize_or_push!(cid_to_var, i, cid)
178185
end
179186

187+
# breaks the system up into a continous and 0 or more discrete systems
180188
tss = similar(cid_to_eq, S)
181189
for (id, ieqs) in enumerate(cid_to_eq)
182190
ts_i = system_subset(ts, ieqs)
@@ -186,6 +194,7 @@ function split_system(ci::ClockInference{S}) where {S}
186194
end
187195
tss[id] = ts_i
188196
end
197+
# put the continous system at the back
189198
if continuous_id != 0
190199
tss[continuous_id], tss[end] = tss[end], tss[continuous_id]
191200
inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id]

src/systems/systems.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function structural_simplify(
3131
kwargs...)
3232
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
3333
newsys′ = __structural_simplify(sys, io; simplify,
34-
allow_symbolic, allow_parameter, conservative, fully_determined,
34+
allow_symbolic, allow_parameter, conservative, fully_determined, additional_passes,
3535
kwargs...)
3636
if newsys′ isa Tuple
3737
@assert length(newsys′) == 2
@@ -169,3 +169,9 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
169169
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
170170
end
171171
end
172+
173+
"""
174+
Mark whether an extra pass `p` can support compiling discrete systems.
175+
"""
176+
discrete_compile_pass(p) = false
177+

src/systems/systemstructure.jl

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -626,40 +626,45 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
626626
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
627627
kwargs...)
628628
if state.sys isa ODESystem
629+
# split_system returns one or two systems and the inputs for each
630+
# mod clock inference to be binary
631+
# if it's continous keep going, if not then error unless given trait impl in additional passes
629632
ci = ModelingToolkit.ClockInference(state)
630633
ci = ModelingToolkit.infer_clocks!(ci)
631634
time_domains = merge(Dict(state.fullvars .=> ci.var_domain),
632635
Dict(default_toterm.(state.fullvars) .=> ci.var_domain))
633636
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
637+
if continuous_id == 0
638+
# do a trait check here - handle fully discrete system
639+
additional_passes = get(kwargs, :additional_passes, nothing)
640+
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
641+
# take the first discrete compilation pass given for now
642+
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
643+
discrete_compile = additional_passes[discrete_pass_idx]
644+
deleteat!(additional_passes, discrete_pass_idx)
645+
return discrete_compile(tss, inputs)
646+
else
647+
# error goes here! this is a purely discrete system
648+
throw(HybridSystemNotSupportedException("Discrete systems without JuliaSimCompiler are currently not supported in ODESystem."))
649+
end
650+
end
651+
# puts the ios passed in to the call into the continous system
634652
cont_io = merge_io(io, inputs[continuous_id])
653+
# simplify as normal
635654
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
636655
check_consistency, fully_determined,
637656
kwargs...)
638657
if length(tss) > 1
639-
if continuous_id > 0
658+
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
659+
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
660+
discrete_compile = additional_passes[discrete_pass_idx]
661+
deleteat!(additional_passes, discrete_pass_idx)
662+
# in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems
663+
# and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed
664+
sys = discrete_compile(sys, tss[2:end], inputs)
665+
else
640666
throw(HybridSystemNotSupportedException("Hybrid continuous-discrete systems are currently not supported with the standard MTK compiler. This system requires JuliaSimCompiler.jl, see https://help.juliahub.com/juliasimcompiler/stable/"))
641667
end
642-
# TODO: rename it to something else
643-
discrete_subsystems = Vector{ODESystem}(undef, length(tss))
644-
# Note that the appended_parameters must agree with
645-
# `generate_discrete_affect`!
646-
appended_parameters = parameters(sys)
647-
for (i, state) in enumerate(tss)
648-
if i == continuous_id
649-
discrete_subsystems[i] = sys
650-
continue
651-
end
652-
dist_io = merge_io(io, inputs[i])
653-
ss, = _structural_simplify!(state, dist_io; simplify, check_consistency,
654-
fully_determined, kwargs...)
655-
append!(appended_parameters, inputs[i], unknowns(ss))
656-
discrete_subsystems[i] = ss
657-
end
658-
@set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id,
659-
id_to_clock
660-
@set! sys.ps = appended_parameters
661-
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
662-
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
663668
end
664669
ps = [sym isa CallWithMetadata ? sym :
665670
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous()))

0 commit comments

Comments
 (0)