Skip to content

Commit 47324ef

Browse files
feat: preemptively tear some trivial equation in mtkcompile
1 parent d1e9f58 commit 47324ef

File tree

1 file changed

+99
-12
lines changed

1 file changed

+99
-12
lines changed

src/systems/systemstructure.jl

Lines changed: 99 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
208208
structure::SystemStructure
209209
extra_eqs::Vector
210210
param_derivative_map::Dict{BasicSymbolic, Any}
211+
original_eqs::Vector{Equation}
211212
"""
212213
Additional user-provided observed equations. The variables calculated here
213214
are not used in the rest of the system.
@@ -219,6 +220,7 @@ TransformationState(sys::AbstractSystem) = TearingState(sys)
219220
function system_subset(ts::TearingState, ieqs::Vector{Int})
220221
eqs = equations(ts)
221222
@set! ts.sys.eqs = eqs[ieqs]
223+
@set! ts.original_eqs = ts.original_eqs[ieqs]
222224
@set! ts.structure = system_subset(ts.structure, ieqs)
223225
ts
224226
end
@@ -281,10 +283,8 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
281283
iv = length(ivs) == 1 ? ivs[1] : nothing
282284
# flatten array equations
283285
eqs = flatten_equations(equations(sys))
286+
original_eqs = copy(eqs)
284287
neqs = length(eqs)
285-
obseqs = observed(sys)
286-
obsvars = Set([eq.lhs for eq in obseqs])
287-
@set! sys.observed = Equation[]
288288
param_derivative_map = Dict{BasicSymbolic, Any}()
289289
# * Scalarize unknowns
290290
dvs = Set{BasicSymbolic}()
@@ -356,14 +356,6 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
356356
push!(incidence, v)
357357
end
358358

359-
if v in obsvars
360-
error("""
361-
Observed equations in unsimplified systems cannot compute quantities that \
362-
are involved in the equations of the system. Encountered observed \
363-
variable `$v` in equation `$_eq`.
364-
""")
365-
end
366-
367359
# TODO: Can we handle this without `isparameter`?
368360
if symbolic_contains(v, ps) ||
369361
getmetadata(v, SymScope, LocalScope()) isa GlobalScope && isparameter(v)
@@ -432,6 +424,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
432424
end
433425
end
434426
eqs = eqs[eqs_to_retain]
427+
original_eqs = original_eqs[eqs_to_retain]
435428
neqs = length(eqs)
436429
symbolic_incidence = symbolic_incidence[eqs_to_retain]
437430

@@ -440,6 +433,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
440433
# depending on order due to NP-completeness of tearing.
441434
sortidxs = Base.sortperm(eqs, by = string)
442435
eqs = eqs[sortidxs]
436+
original_eqs = original_eqs[sortidxs]
443437
symbolic_incidence = symbolic_incidence[sortidxs]
444438
end
445439

@@ -533,8 +527,100 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
533527
ts = TearingState(sys, fullvars,
534528
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
535529
complete(graph), nothing, var_types, false),
536-
Any[], param_derivative_map, obseqs)
530+
Any[], param_derivative_map, original_eqs, Equation[])
531+
532+
return ts
533+
end
534+
535+
"""
536+
$(TYPEDSIGNATURES)
537+
538+
Preemptively identify observed equations in the system and tear them. This happens before
539+
any simplification. The equations torn by this process are ones that are already given in
540+
an explicit form in the system and where the LHS is not present in any other equation of
541+
the system except for other such preempitvely torn equations.
542+
"""
543+
function trivial_tearing!(ts::TearingState)
544+
@assert length(ts.original_eqs) == length(equations(ts))
545+
# equations that can be trivially torn an observed equations
546+
trivial_idxs = BitSet()
547+
# equations to never check
548+
blacklist = BitSet()
549+
torn_eqs = Equation[]
550+
# variables that have been matched to trivially torn equations
551+
matched_vars = BitSet()
552+
# variable to index in fullvars
553+
var_to_idx = Dict{Any, Int}(ts.fullvars .=> eachindex(ts.fullvars))
554+
555+
complete!(ts.structure)
556+
var_to_diff = ts.structure.var_to_diff
557+
graph = ts.structure.graph
558+
while true
559+
# track whether we added an equation to the trivial list this iteration
560+
added_equation = false
561+
for (i, eq) in enumerate(ts.original_eqs)
562+
# don't check already torn equations
563+
i in trivial_idxs && continue
564+
i in blacklist && continue
565+
# ensure it is an observed equation matched to a variable in fullvars
566+
vari = get(var_to_idx, eq.lhs, 0)
567+
iszero(vari) && continue
568+
# if a variable was the LHS of two trivial observed equations, we wouldn't have
569+
# included it in the list. Error if somehow it made it through.
570+
@assert !(vari in matched_vars)
571+
# don't tear differential/shift equations (or differentiated/shifted variables)
572+
var_to_diff[vari] == 0 || continue
573+
invview(var_to_diff)[vari] == 0 || continue
574+
# get the equations that the candidate matched variable is present in, except
575+
# those equations which have already been torn as observed
576+
eqidxs = setdiff(𝑑neighbors(graph, vari), trivial_idxs)
577+
# it should only be present in this equation
578+
length(eqidxs) == 1 || continue
579+
eqi = only(eqidxs)
580+
@assert eqi == i
581+
582+
# for every variable present in this equation, make sure it isn't _only_
583+
# present in this equation
584+
isvalid = true
585+
for v in 𝑠neighbors(graph, eqi)
586+
v == vari && continue
587+
v in matched_vars && continue
588+
isvalid &= length(𝑑neighbors(graph, v)) > 1
589+
isvalid || break
590+
end
591+
isvalid || continue
592+
# skip if the LHS is present in the RHS, since then this isn't explicit
593+
if occursin(eq.lhs, eq.rhs)
594+
push!(blacklist, i)
595+
continue
596+
end
597+
598+
added_equation = true
599+
push!(trivial_idxs, eqi)
600+
push!(torn_eqs, eq)
601+
push!(matched_vars, vari)
602+
end
537603

604+
# if we didn't add an equation this iteration, we won't add one next iteration
605+
added_equation || break
606+
end
607+
608+
deleteat!(var_to_diff.primal_to_diff, matched_vars)
609+
deleteat!(var_to_diff.diff_to_primal, matched_vars)
610+
deleteat!(ts.structure.eq_to_diff.primal_to_diff, trivial_idxs)
611+
deleteat!(ts.structure.eq_to_diff.diff_to_primal, trivial_idxs)
612+
delete_srcs!(ts.structure.graph, trivial_idxs)
613+
delete_dsts!(ts.structure.graph, matched_vars)
614+
if ts.structure.solvable_graph !== nothing
615+
delete_srcs!(ts.structure.solvable_graph, trivial_idxs)
616+
delete_dsts!(ts.structure.solvable_graph, matched_vars)
617+
end
618+
if ts.structure.var_types !== nothing
619+
deleteat!(ts.structure.var_types, matched_vars)
620+
end
621+
deleteat!(ts.fullvars, matched_vars)
622+
deleteat!(ts.original_eqs, trivial_idxs)
623+
ts.additional_observed = torn_eqs
538624
return ts
539625
end
540626

@@ -770,6 +856,7 @@ function _mtkcompile!(state::TearingState; simplify = false,
770856
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs)
771857
state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs])
772858
end
859+
trivial_tearing!(state)
773860
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
774861
if check_consistency
775862
fully_determined = ModelingToolkit.check_consistency(

0 commit comments

Comments
 (0)