203
203
mutable struct TearingState{T <: AbstractSystem } <: AbstractTearingState{T}
204
204
""" The system of equations."""
205
205
sys:: T
206
- original_eqs:: Vector{Equation}
207
206
""" The set of variables of the system."""
208
207
fullvars:: Vector{BasicSymbolic}
209
208
structure:: SystemStructure
@@ -215,6 +214,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
215
214
are not used in the rest of the system.
216
215
"""
217
216
additional_observed:: Vector{Equation}
217
+ statemachines:: Vector{T}
218
218
end
219
219
220
220
TransformationState (sys:: AbstractSystem ) = TearingState (sys)
@@ -224,6 +224,22 @@ function system_subset(ts::TearingState, ieqs::Vector{Int})
224
224
@set! ts. sys. eqs = eqs[ieqs]
225
225
@set! ts. original_eqs = ts. original_eqs[ieqs]
226
226
@set! ts. structure = system_subset (ts. structure, ieqs)
227
+ if all (eq -> eq. rhs isa StateMachineOperator, get_eqs (ts. sys))
228
+ names = Symbol[]
229
+ for eq in get_eqs (ts. sys)
230
+ if eq. lhs isa Transition
231
+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. from))))
232
+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. to))))
233
+ elseif eq. lhs isa InitialState
234
+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. s))))
235
+ else
236
+ error (" Unhandled state machine operator" )
237
+ end
238
+ end
239
+ @set! ts. statemachines = filter (x -> nameof (x) in names, ts. statemachines)
240
+ else
241
+ @set! ts. statemachines = eltype (ts. statemachines)[]
242
+ end
227
243
ts
228
244
end
229
245
@@ -277,6 +293,49 @@ function symbolic_contains(var, set)
277
293
all (x -> x in set, Symbolics. scalarize (var))
278
294
end
279
295
296
+ """
297
+ $(TYPEDSIGNATURES)
298
+
299
+ Descend through the system hierarchy and look for statemachines. Remove equations from
300
+ the inner statemachine systems. Return the new `sys` and an array of top-level
301
+ statemachines.
302
+ """
303
+ function extract_top_level_statemachines (sys:: AbstractSystem )
304
+ eqs = get_eqs (sys)
305
+
306
+ if ! isempty (eqs) && all (eq -> eq. lhs isa StateMachineOperator, eqs)
307
+ # top-level statemachine
308
+ with_removed = @set sys. systems = map (remove_child_equations, get_systems (sys))
309
+ return with_removed, [sys]
310
+ elseif ! isempty (eqs) && any (eq -> eq. lhs isa StateMachineOperator, eqs)
311
+ # error: can't mix
312
+ error (" Mixing statemachine equations and standard equations in a top-level statemachine is not allowed." )
313
+ else
314
+ # descend
315
+ subsystems = get_systems (sys)
316
+ newsubsystems = eltype (subsystems)[]
317
+ statemachines = eltype (subsystems)[]
318
+ for subsys in subsystems
319
+ newsubsys, sub_statemachines = extract_top_level_statemachines (subsys)
320
+ push! (newsubsystems, newsubsys)
321
+ append! (statemachines, sub_statemachines)
322
+ end
323
+ @set! sys. systems = newsubsystems
324
+ return sys, statemachines
325
+ end
326
+ end
327
+
328
+ """
329
+ $(TYPEDSIGNATURES)
330
+
331
+ Return `sys` with all equations (including those in subsystems) removed.
332
+ """
333
+ function remove_child_equations (sys:: AbstractSystem )
334
+ @set! sys. eqs = eltype (get_eqs (sys))[]
335
+ @set! sys. systems = map (remove_child_equations, get_systems (sys))
336
+ return sys
337
+ end
338
+
280
339
function TearingState (sys; quick_cancel = false , check = true , sort_eqs = true )
281
340
# flatten system
282
341
sys = flatten (sys)
@@ -342,9 +401,16 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
342
401
# change the equation if the RHS is `missing` so the rest of this loop works
343
402
eq = 0.0 ~ coalesce (eq. rhs, 0.0 )
344
403
end
345
- rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
346
- if ! _iszero (eq. lhs)
404
+ is_statemachine_equation = false
405
+ if eq. lhs isa StateMachineOperator
406
+ is_statemachine_equation = true
407
+ eq = eq
408
+ rhs = eq. rhs
409
+ elseif _iszero (eq. lhs)
410
+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
411
+ else
347
412
lhs = quick_cancel ? quick_cancel_expr (eq. lhs) : eq. lhs
413
+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
348
414
eq = 0 ~ rhs - lhs
349
415
end
350
416
empty! (varsbuf)
@@ -408,8 +474,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
408
474
addvar! (v, VARIABLE)
409
475
end
410
476
end
411
-
412
- if isalgeq
477
+ if isalgeq || is_statemachine_equation
413
478
eqs[i] = eq
414
479
else
415
480
eqs[i] = eqs[i]. lhs ~ rhs
@@ -526,11 +591,10 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
526
591
527
592
eq_to_diff = DiffGraph (nsrcs (graph))
528
593
529
- ts = TearingState (sys, original_eqs, fullvars,
594
+ ts = TearingState (sys, fullvars,
530
595
SystemStructure (complete (var_to_diff), complete (eq_to_diff),
531
596
complete (graph), nothing , var_types, false ),
532
- Any[], param_derivative_map, original_eqs, Equation[])
533
-
597
+ Any[], param_derivative_map, original_eqs, Equation[], typeof (sys)[])
534
598
return ts
535
599
end
536
600
@@ -860,6 +924,7 @@ function mtkcompile!(state::TearingState; simplify = false,
860
924
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
861
925
check_consistency, fully_determined,
862
926
kwargs... )
927
+ additional_passes = get (kwargs, :additional_passes , nothing )
863
928
if ! isnothing (additional_passes) && any (discrete_compile_pass, additional_passes)
864
929
discrete_pass_idx = findfirst (discrete_compile_pass, additional_passes)
865
930
discrete_compile = additional_passes[discrete_pass_idx]
0 commit comments