203203mutable struct TearingState{T <: AbstractSystem } <: AbstractTearingState{T}
204204 """ The system of equations."""
205205 sys:: T
206+ original_eqs:: Vector{Equation}
206207 """ The set of variables of the system."""
207208 fullvars:: Vector{BasicSymbolic}
208209 structure:: SystemStructure
213214TransformationState (sys:: AbstractSystem ) = TearingState (sys)
214215function system_subset (ts:: TearingState , ieqs:: Vector{Int} )
215216 eqs = equations (ts)
217+ @set! ts. original_eqs = ts. original_eqs[ieqs]
216218 @set! ts. sys. eqs = eqs[ieqs]
217219 @set! ts. structure = system_subset (ts. structure, ieqs)
218220 ts
@@ -274,8 +276,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
274276 sys = process_parameter_equations (sys)
275277 ivs = independent_variables (sys)
276278 iv = length (ivs) == 1 ? ivs[1 ] : nothing
277- # flatten array equations
278- eqs = flatten_equations (equations (sys))
279+ # scalarize array equations, without scalarizing arguments to registered functions
280+ original_eqs = flatten_equations (copy (equations (sys)))
281+ eqs = copy (original_eqs)
279282 neqs = length (eqs)
280283 param_derivative_map = Dict {BasicSymbolic, Any} ()
281284 # * Scalarize unknowns
@@ -513,7 +516,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
513516
514517 eq_to_diff = DiffGraph (nsrcs (graph))
515518
516- ts = TearingState (sys, fullvars,
519+ ts = TearingState (sys, original_eqs, fullvars,
517520 SystemStructure (complete (var_to_diff), complete (eq_to_diff),
518521 complete (graph), nothing , var_types, false ),
519522 Any[], param_derivative_map)
@@ -696,6 +699,22 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
696699 printstyled (io, " SelectedState" )
697700end
698701
702+ function make_eqs_zero_equals! (ts:: TearingState )
703+ neweqs = map (enumerate (get_eqs (ts. sys))) do kvp
704+ i, eq = kvp
705+ isalgeq = true
706+ for j in 𝑠neighbors (ts. structure. graph, i)
707+ isalgeq &= invview (ts. structure. var_to_diff)[j] === nothing
708+ end
709+ if isalgeq
710+ return 0 ~ eq. rhs - eq. lhs
711+ else
712+ return eq
713+ end
714+ end
715+ copyto! (get_eqs (ts. sys), neweqs)
716+ end
717+
699718function mtkcompile! (state:: TearingState ; simplify = false ,
700719 check_consistency = true , fully_determined = true , warn_initialize_determined = true ,
701720 inputs = Any[], outputs = Any[],
@@ -722,6 +741,7 @@ function mtkcompile!(state::TearingState; simplify = false,
722741 """ ))
723742 end
724743 if length (tss) > 1
744+ make_eqs_zero_equals! (tss[continuous_id])
725745 # simplify as normal
726746 sys = _mtkcompile! (tss[continuous_id]; simplify,
727747 inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
0 commit comments