diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 9eccc309bd..f6c394ee1c 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -687,8 +687,8 @@ function update_simplified_system!( unknowns = [unknowns; extra_unknowns] @set! sys.unknowns = unknowns - obs, subeqs, deps = cse_and_array_hacks( - sys, obs, solved_eqs, unknowns, neweqs; cse = cse_hack, array = array_hack) + obs = cse_and_array_hacks( + sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack) @set! sys.eqs = neweqs @set! sys.observed = obs @@ -790,7 +790,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs not) we first count the number of times the scalarized form of each observed variable occurs in observed equations (and unknowns if it's split). """ -function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, array = true) +function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = true) # HACK 1 # mapping of rhs to temporary CSE variable # `f(...) => tmpvar` in above example @@ -818,7 +818,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr tempeq = tempvar ~ rhs_arr rhs_to_tempvar[rhs_arr] = tempvar push!(obs, tempeq) - push!(subeqs, tempeq) end # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different, @@ -827,10 +826,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr neweq = lhs ~ getindex_wrapper( rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end])) obs[i] = neweq - subeqi = findfirst(isequal(eq), subeqs) - if subeqi !== nothing - subeqs[subeqi] = neweq - end end # end HACK 1 @@ -860,7 +855,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr tempeq = tempvar ~ rhs_arr rhs_to_tempvar[rhs_arr] = tempvar push!(obs, tempeq) - push!(subeqs, tempeq) end # don't need getindex_wrapper, but do it anyway to know that this # hack took place @@ -900,15 +894,8 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr push!(obs_arr_eqs, arrvar ~ rhs) end append!(obs, obs_arr_eqs) - append!(subeqs, obs_arr_eqs) - - # need to re-sort subeqs - subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs]) - - deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1)) - for i in 1:length(subeqs)] - return obs, subeqs, deps + return obs end function is_getindexed_array(rhs) diff --git a/src/utils.jl b/src/utils.jl index 8fcf8d7a25..dd6971dbe1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -700,7 +700,8 @@ Get a dictionary mapping variables eliminated from the system during `mtkcompile expressions used to calculate them. """ function get_substitutions(sys) - Dict([eq.lhs => eq.rhs for eq in observed(sys)]) + obs, _ = unhack_observed(observed(sys), equations(sys)) + Dict([eq.lhs => eq.rhs for eq in obs]) end @noinline function throw_missingvars_in_sys(vars) diff --git a/test/odesystem.jl b/test/odesystem.jl index be301f34f1..58315f0cff 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1560,3 +1560,32 @@ end @mtkcompile sys = SysC() @test length(unknowns(sys)) == 3 end + +@testset "`full_equations` doesn't recurse infinitely" begin + code = """ + using ModelingToolkit + using ModelingToolkit: t_nounits as t, D_nounits as D + @variables x(t)[1:3]=[0,0,1] + @variables u1(t)=0 u2(t)=0 + y₁, y₂, y₃ = x + k₁, k₂, k₃ = 1,1,1 + eqs = [ + D(y₁) ~ -k₁*y₁ + k₃*y₂*y₃ + u1 + D(y₂) ~ k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2 + u2 + y₁ + y₂ + y₃ ~ 1 + ] + + @named sys = System(eqs, t) + + inputs = [u1, u2] + outputs = [y₁, y₂, y₃] + ss = mtkcompile(sys; inputs) + full_equations(ss) + """ + + cmd = `$(Base.julia_cmd()) --project=$(@__DIR__) -e $code` + proc = run(cmd, stdin, stdout, stderr; wait = false) + sleep(120) + @test !process_running(proc) + kill(proc, Base.SIGKILL) +end