Skip to content

Commit 7cf774d

Browse files
feat: propagate state machines in structural simplification
1 parent 7f8b8f2 commit 7cf774d

File tree

2 files changed

+76
-9
lines changed

2 files changed

+76
-9
lines changed

src/systems/systems.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
7575
return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify)
7676
end
7777

78+
sys, statemachines = extract_top_level_statemachines(sys)
7879
sys = expand_connections(sys)
79-
state = TearingState(sys; sort_eqs)
80+
state = TearingState(sys)
81+
append!(state.statemachines, statemachines)
8082

8183
@unpack structure, fullvars = state
8284
@unpack graph, var_to_diff, var_types = structure

src/systems/systemstructure.jl

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ end
203203
mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
204204
"""The system of equations."""
205205
sys::T
206-
original_eqs::Vector{Equation}
207206
"""The set of variables of the system."""
208207
fullvars::Vector{BasicSymbolic}
209208
structure::SystemStructure
@@ -215,6 +214,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
215214
are not used in the rest of the system.
216215
"""
217216
additional_observed::Vector{Equation}
217+
statemachines::Vector{T}
218218
end
219219

220220
TransformationState(sys::AbstractSystem) = TearingState(sys)
@@ -224,6 +224,22 @@ function system_subset(ts::TearingState, ieqs::Vector{Int})
224224
@set! ts.sys.eqs = eqs[ieqs]
225225
@set! ts.original_eqs = ts.original_eqs[ieqs]
226226
@set! ts.structure = system_subset(ts.structure, ieqs)
227+
if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys))
228+
names = Symbol[]
229+
for eq in get_eqs(ts.sys)
230+
if eq.lhs isa Transition
231+
push!(names, first(namespace_hierarchy(nameof(eq.rhs.from))))
232+
push!(names, first(namespace_hierarchy(nameof(eq.rhs.to))))
233+
elseif eq.lhs isa InitialState
234+
push!(names, first(namespace_hierarchy(nameof(eq.rhs.s))))
235+
else
236+
error("Unhandled state machine operator")
237+
end
238+
end
239+
@set! ts.statemachines = filter(x -> nameof(x) in names, ts.statemachines)
240+
else
241+
@set! ts.statemachines = eltype(ts.statemachines)[]
242+
end
227243
ts
228244
end
229245

@@ -277,6 +293,49 @@ function symbolic_contains(var, set)
277293
all(x -> x in set, Symbolics.scalarize(var))
278294
end
279295

296+
"""
297+
$(TYPEDSIGNATURES)
298+
299+
Descend through the system hierarchy and look for statemachines. Remove equations from
300+
the inner statemachine systems. Return the new `sys` and an array of top-level
301+
statemachines.
302+
"""
303+
function extract_top_level_statemachines(sys::AbstractSystem)
304+
eqs = get_eqs(sys)
305+
306+
if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs)
307+
# top-level statemachine
308+
with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys))
309+
return with_removed, [sys]
310+
elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs)
311+
# error: can't mix
312+
error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.")
313+
else
314+
# descend
315+
subsystems = get_systems(sys)
316+
newsubsystems = eltype(subsystems)[]
317+
statemachines = eltype(subsystems)[]
318+
for subsys in subsystems
319+
newsubsys, sub_statemachines = extract_top_level_statemachines(subsys)
320+
push!(newsubsystems, newsubsys)
321+
append!(statemachines, sub_statemachines)
322+
end
323+
@set! sys.systems = newsubsystems
324+
return sys, statemachines
325+
end
326+
end
327+
328+
"""
329+
$(TYPEDSIGNATURES)
330+
331+
Return `sys` with all equations (including those in subsystems) removed.
332+
"""
333+
function remove_child_equations(sys::AbstractSystem)
334+
@set! sys.eqs = eltype(get_eqs(sys))[]
335+
@set! sys.systems = map(remove_child_equations, get_systems(sys))
336+
return sys
337+
end
338+
280339
function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
281340
# flatten system
282341
sys = flatten(sys)
@@ -342,9 +401,16 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
342401
# change the equation if the RHS is `missing` so the rest of this loop works
343402
eq = 0.0 ~ coalesce(eq.rhs, 0.0)
344403
end
345-
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
346-
if !_iszero(eq.lhs)
404+
is_statemachine_equation = false
405+
if eq.lhs isa StateMachineOperator
406+
is_statemachine_equation = true
407+
eq = eq
408+
rhs = eq.rhs
409+
elseif _iszero(eq.lhs)
410+
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
411+
else
347412
lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs
413+
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
348414
eq = 0 ~ rhs - lhs
349415
end
350416
empty!(varsbuf)
@@ -408,8 +474,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
408474
addvar!(v, VARIABLE)
409475
end
410476
end
411-
412-
if isalgeq
477+
if isalgeq || is_statemachine_equation
413478
eqs[i] = eq
414479
else
415480
eqs[i] = eqs[i].lhs ~ rhs
@@ -526,11 +591,10 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
526591

527592
eq_to_diff = DiffGraph(nsrcs(graph))
528593

529-
ts = TearingState(sys, original_eqs, fullvars,
594+
ts = TearingState(sys, fullvars,
530595
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
531596
complete(graph), nothing, var_types, false),
532-
Any[], param_derivative_map, original_eqs, Equation[])
533-
597+
Any[], param_derivative_map, original_eqs, Equation[], typeof(sys)[])
534598
return ts
535599
end
536600

@@ -860,6 +924,7 @@ function mtkcompile!(state::TearingState; simplify = false,
860924
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
861925
check_consistency, fully_determined,
862926
kwargs...)
927+
additional_passes = get(kwargs, :additional_passes, nothing)
863928
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
864929
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
865930
discrete_compile = additional_passes[discrete_pass_idx]

0 commit comments

Comments
 (0)