Skip to content

Commit 1286318

Browse files
feat: retain original equations of the system in TearingState
1 parent a581b96 commit 1286318

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

src/systems/systems.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
8080

8181
@unpack structure, fullvars = state
8282
@unpack graph, var_to_diff, var_types = structure
83-
eqs = equations(state)
8483
brown_vars = Int[]
8584
new_idxs = zeros(Int, length(var_types))
8685
idx = 0
@@ -98,7 +97,8 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
9897
Is = Int[]
9998
Js = Int[]
10099
vals = Num[]
101-
new_eqs = copy(eqs)
100+
make_eqs_zero_equals!(state)
101+
new_eqs = copy(equations(state))
102102
dvar2eq = Dict{Any, Int}()
103103
for (v, dv) in enumerate(var_to_diff)
104104
dv === nothing && continue

src/systems/systemstructure.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ end
204204
mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
205205
"""The system of equations."""
206206
sys::T
207+
original_eqs::Vector{Equation}
207208
"""The set of variables of the system."""
208209
fullvars::Vector{BasicSymbolic}
209210
structure::SystemStructure
@@ -527,7 +528,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
527528

528529
eq_to_diff = DiffGraph(nsrcs(graph))
529530

530-
ts = TearingState(sys, fullvars,
531+
ts = TearingState(sys, original_eqs, fullvars,
531532
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
532533
complete(graph), nothing, var_types, false),
533534
Any[], param_derivative_map, original_eqs, Equation[])
@@ -813,6 +814,22 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
813814
printstyled(io, " SelectedState")
814815
end
815816

817+
function make_eqs_zero_equals!(ts::TearingState)
818+
neweqs = map(enumerate(get_eqs(ts.sys))) do kvp
819+
i, eq = kvp
820+
isalgeq = true
821+
for j in 𝑠neighbors(ts.structure.graph, i)
822+
isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing
823+
end
824+
if isalgeq
825+
return 0 ~ eq.rhs - eq.lhs
826+
else
827+
return eq
828+
end
829+
end
830+
copyto!(get_eqs(ts.sys), neweqs)
831+
end
832+
816833
function mtkcompile!(state::TearingState; simplify = false,
817834
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
818835
inputs = Any[], outputs = Any[],
@@ -839,6 +856,7 @@ function mtkcompile!(state::TearingState; simplify = false,
839856
"""))
840857
end
841858
if length(tss) > 1
859+
make_eqs_zero_equals!(tss[continuous_id])
842860
# simplify as normal
843861
sys = _mtkcompile!(tss[continuous_id]; simplify,
844862
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,

0 commit comments

Comments
 (0)