204204mutable struct TearingState{T <: AbstractSystem } <: AbstractTearingState{T}
205205 """ The system of equations."""
206206 sys:: T
207- original_eqs:: Vector{Equation}
208207 """ The set of variables of the system."""
209208 fullvars:: Vector{BasicSymbolic}
210209 structure:: SystemStructure
@@ -216,6 +215,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
216215 are not used in the rest of the system.
217216 """
218217 additional_observed:: Vector{Equation}
218+ statemachines:: Vector{T}
219219end
220220
221221TransformationState (sys:: AbstractSystem ) = TearingState (sys)
@@ -224,6 +224,22 @@ function system_subset(ts::TearingState, ieqs::Vector{Int})
224224 @set! ts. sys. eqs = eqs[ieqs]
225225 @set! ts. original_eqs = ts. original_eqs[ieqs]
226226 @set! ts. structure = system_subset (ts. structure, ieqs)
227+ if all (eq -> eq. rhs isa StateMachineOperator, get_eqs (ts. sys))
228+ names = Symbol[]
229+ for eq in get_eqs (ts. sys)
230+ if eq. lhs isa Transition
231+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. from))))
232+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. to))))
233+ elseif eq. lhs isa InitialState
234+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. s))))
235+ else
236+ error (" Unhandled state machine operator" )
237+ end
238+ end
239+ @set! ts. statemachines = filter (x -> nameof (x) in names, ts. statemachines)
240+ else
241+ @set! ts. statemachines = eltype (ts. statemachines)[]
242+ end
227243 ts
228244end
229245
@@ -277,6 +293,49 @@ function symbolic_contains(var, set)
277293 all (x -> x in set, Symbolics. scalarize (var))
278294end
279295
296+ """
297+ $(TYPEDSIGNATURES)
298+
299+ Descend through the system hierarchy and look for statemachines. Remove equations from
300+ the inner statemachine systems. Return the new `sys` and an array of top-level
301+ statemachines.
302+ """
303+ function extract_top_level_statemachines (sys:: AbstractSystem )
304+ eqs = get_eqs (sys)
305+
306+ if ! isempty (eqs) && all (eq -> eq. lhs isa StateMachineOperator, eqs)
307+ # top-level statemachine
308+ with_removed = @set sys. systems = map (remove_child_equations, get_systems (sys))
309+ return with_removed, [sys]
310+ elseif ! isempty (eqs) && any (eq -> eq. lhs isa StateMachineOperator, eqs)
311+ # error: can't mix
312+ error (" Mixing statemachine equations and standard equations in a top-level statemachine is not allowed." )
313+ else
314+ # descend
315+ subsystems = get_systems (sys)
316+ newsubsystems = eltype (subsystems)[]
317+ statemachines = eltype (subsystems)[]
318+ for subsys in subsystems
319+ newsubsys, sub_statemachines = extract_top_level_statemachines (subsys)
320+ push! (newsubsystems, newsubsys)
321+ append! (statemachines, sub_statemachines)
322+ end
323+ @set! sys. systems = newsubsystems
324+ return sys, statemachines
325+ end
326+ end
327+
328+ """
329+ $(TYPEDSIGNATURES)
330+
331+ Return `sys` with all equations (including those in subsystems) removed.
332+ """
333+ function remove_child_equations (sys:: AbstractSystem )
334+ @set! sys. eqs = eltype (get_eqs (sys))[]
335+ @set! sys. systems = map (remove_child_equations, get_systems (sys))
336+ return sys
337+ end
338+
280339function TearingState (sys; quick_cancel = false , check = true , sort_eqs = true )
281340 # flatten system
282341 sys = flatten (sys)
@@ -342,9 +401,16 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
342401 # change the equation if the RHS is `missing` so the rest of this loop works
343402 eq = 0.0 ~ coalesce (eq. rhs, 0.0 )
344403 end
345- rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
346- if ! _iszero (eq. lhs)
404+ is_statemachine_equation = false
405+ if eq. lhs isa StateMachineOperator
406+ is_statemachine_equation = true
407+ eq = eq
408+ rhs = eq. rhs
409+ elseif _iszero (eq. lhs)
410+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
411+ else
347412 lhs = quick_cancel ? quick_cancel_expr (eq. lhs) : eq. lhs
413+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
348414 eq = 0 ~ rhs - lhs
349415 end
350416 empty! (varsbuf)
@@ -409,8 +475,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
409475 addvar! (v, VARIABLE)
410476 end
411477 end
412-
413- if isalgeq
478+ if isalgeq || is_statemachine_equation
414479 eqs[i] = eq
415480 else
416481 eqs[i] = eqs[i]. lhs ~ rhs
@@ -528,11 +593,10 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
528593
529594 eq_to_diff = DiffGraph (nsrcs (graph))
530595
531- ts = TearingState (sys, original_eqs, fullvars,
596+ ts = TearingState (sys, fullvars,
532597 SystemStructure (complete (var_to_diff), complete (eq_to_diff),
533598 complete (graph), nothing , var_types, false ),
534- Any[], param_derivative_map, original_eqs, Equation[])
535-
599+ Any[], param_derivative_map, original_eqs, Equation[], typeof (sys)[])
536600 return ts
537601end
538602
@@ -862,6 +926,7 @@ function mtkcompile!(state::TearingState; simplify = false,
862926 inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
863927 check_consistency, fully_determined,
864928 kwargs... )
929+ additional_passes = get (kwargs, :additional_passes , nothing )
865930 if ! isnothing (additional_passes) && any (discrete_compile_pass, additional_passes)
866931 discrete_pass_idx = findfirst (discrete_compile_pass, additional_passes)
867932 discrete_compile = additional_passes[discrete_pass_idx]
0 commit comments