Skip to content

Commit 0eee182

Browse files
authored
Merge pull request #1738 from SciML/myb/alias_irr
Make sure that irreducible variables don't get eliminated by accident
2 parents a09b914 + 473d57d commit 0eee182

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

src/systems/alias_elimination.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,9 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
570570
irreducibles)
571571

572572
# Step 2: Simplify the system using the Bareiss factorization
573-
for v in setdiff(solvable_variables, @view pivots[1:rank1])
573+
rk1vars = BitSet(@view pivots[1:rank1])
574+
for v in solvable_variables
575+
v in rk1vars && continue
574576
ag[v] = 0
575577
end
576578

@@ -699,6 +701,16 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
699701
end
700702
for (i, v) in enumerate(level_to_var)
701703
_alias = get(ag, v, nothing)
704+
v_eqs = 𝑑neighbors(graph, v)
705+
# if an irreducible appears in only one equation, we need to make
706+
# sure that the other variables don't get eliminated
707+
if length(v_eqs) == 1
708+
eq = v_eqs[1]
709+
for av in 𝑠neighbors(graph, eq)
710+
push!(irreducibles, av)
711+
end
712+
ag[v] = nothing
713+
end
702714
push!(irreducibles, v)
703715
if _alias !== nothing && iszero(_alias[1]) && i < length(level_to_var)
704716
# we have `x = 0`
@@ -709,7 +721,8 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
709721
end
710722
if nlevels < (new_nlevels = length(level_to_var))
711723
for i in (nlevels + 1):new_nlevels
712-
var_to_diff[level_to_var[i - 1]] = level_to_var[i]
724+
li = level_to_var[i]
725+
var_to_diff[level_to_var[i - 1]] = li
713726
push!(updated_diff_vars, level_to_var[i - 1])
714727
end
715728
end

test/odesystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -839,8 +839,10 @@ let
839839

840840
sys_alias = alias_elimination(sys_con)
841841
D = Differential(t)
842-
true_eqs = [0 ~ sys.v - D(sys.x)
843-
0 ~ ctrl.kv * D(sys.x) + ctrl.kx * sys.x - D(sys.v)]
842+
true_eqs = [0 ~ D(sys.v) - sys.u
843+
0 ~ sys.x - ctrl.x
844+
0 ~ sys.v - D(sys.x)
845+
0 ~ ctrl.kv * D(sys.x) + ctrl.kx * ctrl.x - D(sys.v)]
844846
@test isequal(full_equations(sys_alias), true_eqs)
845847

846848
sys_simp = structural_simplify(sys_con)

test/reduction.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,14 @@ eqs = [D(x) ~ σ * (y - x)
244244
lorenz1 = ODESystem(eqs, t, name = :lorenz1)
245245
lorenz1_reduced = structural_simplify(lorenz1)
246246
@test z in Set(parameters(lorenz1_reduced))
247+
248+
# MWE for #1722
249+
@variables t
250+
vars = @variables a(t) w(t) phi(t)
251+
eqs = [a ~ D(w)
252+
w ~ D(phi)
253+
w ~ sin(t)]
254+
@named sys = ODESystem(eqs, t, vars, [])
255+
ss = alias_elimination(sys)
256+
@test equations(ss) == [0 ~ D(D(phi)) - a, 0 ~ sin(t) - D(phi)]
257+
@test observed(ss) == [w ~ D(phi)]

0 commit comments

Comments
 (0)