Skip to content

Commit ed0612b

Browse files
feat: retain original equations of the system in TearingState
1 parent 4988ee1 commit ed0612b

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

src/systems/systems.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
8080

8181
@unpack structure, fullvars = state
8282
@unpack graph, var_to_diff, var_types = structure
83-
eqs = equations(state)
8483
brown_vars = Int[]
8584
new_idxs = zeros(Int, length(var_types))
8685
idx = 0
@@ -98,7 +97,8 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
9897
Is = Int[]
9998
Js = Int[]
10099
vals = Num[]
101-
new_eqs = copy(eqs)
100+
make_eqs_zero_equals!(state)
101+
new_eqs = copy(equations(state))
102102
dvar2eq = Dict{Any, Int}()
103103
for (v, dv) in enumerate(var_to_diff)
104104
dv === nothing && continue

src/systems/systemstructure.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ end
203203
mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
204204
"""The system of equations."""
205205
sys::T
206+
original_eqs::Vector{Equation}
206207
"""The set of variables of the system."""
207208
fullvars::Vector{BasicSymbolic}
208209
structure::SystemStructure
@@ -219,6 +220,7 @@ end
219220
TransformationState(sys::AbstractSystem) = TearingState(sys)
220221
function system_subset(ts::TearingState, ieqs::Vector{Int})
221222
eqs = equations(ts)
223+
@set! ts.original_eqs = ts.original_eqs[ieqs]
222224
@set! ts.sys.eqs = eqs[ieqs]
223225
@set! ts.original_eqs = ts.original_eqs[ieqs]
224226
@set! ts.structure = system_subset(ts.structure, ieqs)
@@ -524,7 +526,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
524526

525527
eq_to_diff = DiffGraph(nsrcs(graph))
526528

527-
ts = TearingState(sys, fullvars,
529+
ts = TearingState(sys, original_eqs, fullvars,
528530
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
529531
complete(graph), nothing, var_types, false),
530532
Any[], param_derivative_map, original_eqs, Equation[])
@@ -810,6 +812,22 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
810812
printstyled(io, " SelectedState")
811813
end
812814

815+
function make_eqs_zero_equals!(ts::TearingState)
816+
neweqs = map(enumerate(get_eqs(ts.sys))) do kvp
817+
i, eq = kvp
818+
isalgeq = true
819+
for j in 𝑠neighbors(ts.structure.graph, i)
820+
isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing
821+
end
822+
if isalgeq
823+
return 0 ~ eq.rhs - eq.lhs
824+
else
825+
return eq
826+
end
827+
end
828+
copyto!(get_eqs(ts.sys), neweqs)
829+
end
830+
813831
function mtkcompile!(state::TearingState; simplify = false,
814832
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
815833
inputs = Any[], outputs = Any[],
@@ -836,6 +854,7 @@ function mtkcompile!(state::TearingState; simplify = false,
836854
"""))
837855
end
838856
if length(tss) > 1
857+
make_eqs_zero_equals!(tss[continuous_id])
839858
# simplify as normal
840859
sys = _mtkcompile!(tss[continuous_id]; simplify,
841860
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,

0 commit comments

Comments
 (0)