Skip to content

Commit dea08c6

Browse files
feat: retain original equations of the system in TearingState
1 parent 404d7dc commit dea08c6

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/systems/systems.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
8787

8888
@unpack structure, fullvars = state
8989
@unpack graph, var_to_diff, var_types = structure
90-
eqs = equations(state)
9190
brown_vars = Int[]
9291
new_idxs = zeros(Int, length(var_types))
9392
idx = 0
@@ -104,7 +103,8 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
104103
Is = Int[]
105104
Js = Int[]
106105
vals = Num[]
107-
new_eqs = copy(eqs)
106+
make_eqs_zero_equals!(state)
107+
new_eqs = copy(equations(state))
108108
dvar2eq = Dict{Any, Int}()
109109
for (v, dv) in enumerate(var_to_diff)
110110
dv === nothing && continue

src/systems/systemstructure.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ end
198198

199199
mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
200200
sys::T
201+
original_eqs::Vector{Equation}
201202
fullvars::Vector
202203
structure::SystemStructure
203204
extra_eqs::Vector
@@ -206,6 +207,7 @@ end
206207
TransformationState(sys::AbstractSystem) = TearingState(sys)
207208
function system_subset(ts::TearingState, ieqs::Vector{Int})
208209
eqs = equations(ts)
210+
@set! ts.original_eqs = ts.original_eqs[ieqs]
209211
@set! ts.sys.eqs = eqs[ieqs]
210212
@set! ts.structure = system_subset(ts.structure, ieqs)
211213
ts
@@ -252,7 +254,8 @@ function TearingState(sys; quick_cancel = false, check = true)
252254
ivs = independent_variables(sys)
253255
iv = length(ivs) == 1 ? ivs[1] : nothing
254256
# scalarize array equations, without scalarizing arguments to registered functions
255-
eqs = flatten_equations(copy(equations(sys)))
257+
original_eqs = flatten_equations(copy(equations(sys)))
258+
eqs = copy(original_eqs)
256259
neqs = length(eqs)
257260
dervaridxs = OrderedSet{Int}()
258261
var2idx = Dict{Any, Int}()
@@ -428,7 +431,7 @@ function TearingState(sys; quick_cancel = false, check = true)
428431

429432
eq_to_diff = DiffGraph(nsrcs(graph))
430433

431-
ts = TearingState(sys, fullvars,
434+
ts = TearingState(sys, original_eqs, fullvars,
432435
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
433436
complete(graph), nothing, var_types, sys isa DiscreteSystem),
434437
Any[])
@@ -622,6 +625,22 @@ function merge_io(io, inputs)
622625
return io
623626
end
624627

628+
function make_eqs_zero_equals!(ts::TearingState)
629+
neweqs = map(enumerate(get_eqs(ts.sys))) do kvp
630+
i, eq = kvp
631+
isalgeq = true
632+
for j in 𝑠neighbors(ts.structure.graph, i)
633+
isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing
634+
end
635+
if isalgeq
636+
return 0 ~ eq.rhs - eq.lhs
637+
else
638+
return eq
639+
end
640+
end
641+
copyto!(get_eqs(ts.sys), neweqs)
642+
end
643+
625644
function structural_simplify!(state::TearingState, io = nothing; simplify = false,
626645
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
627646
kwargs...)
@@ -649,6 +668,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
649668
throw(HybridSystemNotSupportedException("Discrete systems without JuliaSimCompiler are currently not supported in ODESystem."))
650669
end
651670
end
671+
make_eqs_zero_equals!(tss[continuous_id])
652672
# puts the ios passed in to the call into the continous system
653673
cont_io = merge_io(io, inputs[continuous_id])
654674
# simplify as normal

0 commit comments

Comments
 (0)