@@ -209,6 +209,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
209209 structure:: SystemStructure
210210 extra_eqs:: Vector
211211 param_derivative_map:: Dict{BasicSymbolic, Any}
212+ statemachines:: Vector{T}
212213end
213214
214215TransformationState (sys:: AbstractSystem ) = TearingState (sys)
@@ -217,6 +218,22 @@ function system_subset(ts::TearingState, ieqs::Vector{Int})
217218 @set! ts. original_eqs = ts. original_eqs[ieqs]
218219 @set! ts. sys. eqs = eqs[ieqs]
219220 @set! ts. structure = system_subset (ts. structure, ieqs)
221+ if all (eq -> eq. rhs isa StateMachineOperator, get_eqs (ts. sys))
222+ names = Symbol[]
223+ for eq in get_eqs (ts. sys)
224+ if eq. lhs isa Transition
225+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. from))))
226+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. to))))
227+ elseif eq. lhs isa InitialState
228+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. s))))
229+ else
230+ error (" Unhandled state machine operator" )
231+ end
232+ end
233+ @set! ts. statemachines = filter (x -> nameof (x) in names, ts. statemachines)
234+ else
235+ @set! ts. statemachines = eltype (ts. statemachines)[]
236+ end
220237 ts
221238end
222239
@@ -270,6 +287,49 @@ function symbolic_contains(var, set)
270287 all (x -> x in set, Symbolics. scalarize (var))
271288end
272289
290+ """
291+ $(TYPEDSIGNATURES)
292+
293+ Descend through the system hierarchy and look for statemachines. Remove equations from
294+ the inner statemachine systems. Return the new `sys` and an array of top-level
295+ statemachines.
296+ """
297+ function extract_top_level_statemachines (sys:: AbstractSystem )
298+ eqs = get_eqs (sys)
299+
300+ if ! isempty (eqs) && all (eq -> eq. lhs isa StateMachineOperator, eqs)
301+ # top-level statemachine
302+ with_removed = @set sys. systems = map (remove_child_equations, get_systems (sys))
303+ return with_removed, [sys]
304+ elseif ! isempty (eqs) && any (eq -> eq. lhs isa StateMachineOperator, eqs)
305+ # error: can't mix
306+ error (" Mixing statemachine equations and standard equations in a top-level statemachine is not allowed." )
307+ else
308+ # descend
309+ subsystems = get_systems (sys)
310+ newsubsystems = eltype (subsystems)[]
311+ statemachines = eltype (subsystems)[]
312+ for subsys in subsystems
313+ newsubsys, sub_statemachines = extract_top_level_statemachines (subsys)
314+ push! (newsubsystems, newsubsys)
315+ append! (statemachines, sub_statemachines)
316+ end
317+ @set! sys. systems = newsubsystems
318+ return sys, statemachines
319+ end
320+ end
321+
322+ """
323+ $(TYPEDSIGNATURES)
324+
325+ Return `sys` with all equations (including those in subsystems) removed.
326+ """
327+ function remove_child_equations (sys:: AbstractSystem )
328+ @set! sys. eqs = eltype (get_eqs (sys))[]
329+ @set! sys. systems = map (remove_child_equations, get_systems (sys))
330+ return sys
331+ end
332+
273333function TearingState (sys; quick_cancel = false , check = true , sort_eqs = true )
274334 # flatten system
275335 sys = flatten (sys)
@@ -334,9 +394,16 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
334394 # change the equation if the RHS is `missing` so the rest of this loop works
335395 eq = 0.0 ~ coalesce (eq. rhs, 0.0 )
336396 end
337- rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
338- if ! _iszero (eq. lhs)
397+ is_statemachine_equation = false
398+ if eq. lhs isa StateMachineOperator
399+ is_statemachine_equation = true
400+ eq = eq
401+ rhs = eq. rhs
402+ elseif _iszero (eq. lhs)
403+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
404+ else
339405 lhs = quick_cancel ? quick_cancel_expr (eq. lhs) : eq. lhs
406+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
340407 eq = 0 ~ rhs - lhs
341408 end
342409 empty! (varsbuf)
@@ -400,8 +467,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
400467 addvar! (v, VARIABLE)
401468 end
402469 end
403-
404- if isalgeq
470+ if isalgeq || is_statemachine_equation
405471 eqs[i] = eq
406472 else
407473 eqs[i] = eqs[i]. lhs ~ rhs
@@ -521,8 +587,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
521587 ts = TearingState (sys, original_eqs, fullvars,
522588 SystemStructure (complete (var_to_diff), complete (eq_to_diff),
523589 complete (graph), nothing , var_types, false ),
524- Any[], param_derivative_map)
525-
590+ Any[], param_derivative_map, typeof (sys)[])
526591 return ts
527592end
528593
@@ -749,6 +814,7 @@ function mtkcompile!(state::TearingState; simplify = false,
749814 inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
750815 check_consistency, fully_determined,
751816 kwargs... )
817+ additional_passes = get (kwargs, :additional_passes , nothing )
752818 if ! isnothing (additional_passes) && any (discrete_compile_pass, additional_passes)
753819 discrete_pass_idx = findfirst (discrete_compile_pass, additional_passes)
754820 discrete_compile = additional_passes[discrete_pass_idx]
0 commit comments