Skip to content

Commit 1f71583

Browse files
feat: preemptively tear some trivial equations in mtkcompile
1 parent 58972cd commit 1f71583

File tree

2 files changed

+106
-2
lines changed

2 files changed

+106
-2
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ function update_simplified_system!(
960960
obs_sub[eq.lhs] = eq.rhs
961961
end
962962
# TODO: compute the dependency correctly so that we don't have to do this
963-
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
963+
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs; state.additional_observed]
964964

965965
unknown_idxs = filter(
966966
i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars))

src/systems/systemstructure.jl

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,19 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
208208
structure::SystemStructure
209209
extra_eqs::Vector
210210
param_derivative_map::Dict{BasicSymbolic, Any}
211+
original_eqs::Vector{Equation}
212+
"""
213+
Additional user-provided observed equations. The variables calculated here
214+
are not used in the rest of the system.
215+
"""
216+
additional_observed::Vector{Equation}
211217
end
212218

213219
TransformationState(sys::AbstractSystem) = TearingState(sys)
214220
function system_subset(ts::TearingState, ieqs::Vector{Int})
215221
eqs = equations(ts)
216222
@set! ts.sys.eqs = eqs[ieqs]
223+
@set! ts.original_eqs = ts.original_eqs[ieqs]
217224
@set! ts.structure = system_subset(ts.structure, ieqs)
218225
ts
219226
end
@@ -276,6 +283,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
276283
iv = length(ivs) == 1 ? ivs[1] : nothing
277284
# flatten array equations
278285
eqs = flatten_equations(equations(sys))
286+
original_eqs = copy(eqs)
279287
neqs = length(eqs)
280288
param_derivative_map = Dict{BasicSymbolic, Any}()
281289
# * Scalarize unknowns
@@ -320,6 +328,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
320328
varsbuf = Set()
321329
eqs_to_retain = trues(length(eqs))
322330
for (i, eq) in enumerate(eqs)
331+
_eq = eq
323332
if iscall(eq.lhs) && (op = operation(eq.lhs)) isa Differential &&
324333
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq.lhs)), ps, iv)
325334
# parameter derivatives are opted out by specifying `D(p) ~ missing`, but
@@ -415,6 +424,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
415424
end
416425
end
417426
eqs = eqs[eqs_to_retain]
427+
original_eqs = original_eqs[eqs_to_retain]
418428
neqs = length(eqs)
419429
symbolic_incidence = symbolic_incidence[eqs_to_retain]
420430

@@ -423,6 +433,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
423433
# depending on order due to NP-completeness of tearing.
424434
sortidxs = Base.sortperm(eqs, by = string)
425435
eqs = eqs[sortidxs]
436+
original_eqs = original_eqs[sortidxs]
426437
symbolic_incidence = symbolic_incidence[sortidxs]
427438
end
428439

@@ -516,11 +527,103 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
516527
ts = TearingState(sys, fullvars,
517528
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
518529
complete(graph), nothing, var_types, false),
519-
Any[], param_derivative_map)
530+
Any[], param_derivative_map, original_eqs, Equation[])
520531

521532
return ts
522533
end
523534

535+
"""
536+
$(TYPEDSIGNATURES)
537+
538+
Preemptively identify observed equations in the system and tear them. This happens before
539+
any simplification. The equations torn by this process are ones that are already given in
540+
an explicit form in the system and where the LHS is not present in any other equation of
541+
the system except for other such preempitvely torn equations.
542+
"""
543+
function trivial_tearing!(ts::TearingState)
544+
@assert length(ts.original_eqs) == length(equations(ts))
545+
# equations that can be trivially torn an observed equations
546+
trivial_idxs = BitSet()
547+
# equations to never check
548+
blacklist = BitSet()
549+
torn_eqs = Equation[]
550+
# variables that have been matched to trivially torn equations
551+
matched_vars = BitSet()
552+
# variable to index in fullvars
553+
var_to_idx = Dict{Any, Int}(ts.fullvars .=> eachindex(ts.fullvars))
554+
555+
complete!(ts.structure)
556+
var_to_diff = ts.structure.var_to_diff
557+
graph = ts.structure.graph
558+
while true
559+
# track whether we added an equation to the trivial list this iteration
560+
added_equation = false
561+
for (i, eq) in enumerate(ts.original_eqs)
562+
# don't check already torn equations
563+
i in trivial_idxs && continue
564+
i in blacklist && continue
565+
# ensure it is an observed equation matched to a variable in fullvars
566+
vari = get(var_to_idx, eq.lhs, 0)
567+
iszero(vari) && continue
568+
# if a variable was the LHS of two trivial observed equations, we wouldn't have
569+
# included it in the list. Error if somehow it made it through.
570+
@assert !(vari in matched_vars)
571+
# don't tear differential/shift equations (or differentiated/shifted variables)
572+
var_to_diff[vari] == 0 || continue
573+
invview(var_to_diff)[vari] == 0 || continue
574+
# get the equations that the candidate matched variable is present in, except
575+
# those equations which have already been torn as observed
576+
eqidxs = setdiff(𝑑neighbors(graph, vari), trivial_idxs)
577+
# it should only be present in this equation
578+
length(eqidxs) == 1 || continue
579+
eqi = only(eqidxs)
580+
@assert eqi == i
581+
582+
# for every variable present in this equation, make sure it isn't _only_
583+
# present in this equation
584+
isvalid = true
585+
for v in 𝑠neighbors(graph, eqi)
586+
v == vari && continue
587+
v in matched_vars && continue
588+
isvalid &= length(𝑑neighbors(graph, v)) > 1
589+
isvalid || break
590+
end
591+
isvalid || continue
592+
# skip if the LHS is present in the RHS, since then this isn't explicit
593+
if occursin(eq.lhs, eq.rhs)
594+
push!(blacklist, i)
595+
continue
596+
end
597+
598+
added_equation = true
599+
push!(trivial_idxs, eqi)
600+
push!(torn_eqs, eq)
601+
push!(matched_vars, vari)
602+
end
603+
604+
# if we didn't add an equation this iteration, we won't add one next iteration
605+
added_equation || break
606+
end
607+
608+
deleteat!(var_to_diff.primal_to_diff, matched_vars)
609+
deleteat!(var_to_diff.diff_to_primal, matched_vars)
610+
deleteat!(ts.structure.eq_to_diff.primal_to_diff, trivial_idxs)
611+
deleteat!(ts.structure.eq_to_diff.diff_to_primal, trivial_idxs)
612+
delete_srcs!(ts.structure.graph, trivial_idxs)
613+
delete_dsts!(ts.structure.graph, matched_vars)
614+
if ts.structure.solvable_graph !== nothing
615+
delete_srcs!(ts.structure.solvable_graph, trivial_idxs)
616+
delete_dsts!(ts.structure.solvable_graph, matched_vars)
617+
end
618+
if ts.structure.var_types !== nothing
619+
deleteat!(ts.structure.var_types, matched_vars)
620+
end
621+
deleteat!(ts.fullvars, matched_vars)
622+
deleteat!(ts.original_eqs, trivial_idxs)
623+
ts.additional_observed = torn_eqs
624+
return ts
625+
end
626+
524627
function lower_order_var(dervar, t)
525628
if isdifferential(dervar)
526629
diffvar = arguments(dervar)[1]
@@ -753,6 +856,7 @@ function _mtkcompile!(state::TearingState; simplify = false,
753856
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs)
754857
state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs])
755858
end
859+
trivial_tearing!(state)
756860
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
757861
if check_consistency
758862
fully_determined = ModelingToolkit.check_consistency(

0 commit comments

Comments
 (0)