Skip to content

Commit 182b928

Browse files
authored
Merge pull request #1759 from SciML/myb/alias_fix
Remember to mark the predecessors as processed as well to avoid cycles in alias elimination
2 parents d6195ee + eb5cd3e commit 182b928

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

src/systems/alias_elimination.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,8 @@ end
475475
it === nothing && return nothing
476476
e, ns = it
477477
# c * a = b <=> a = c * b when -1 <= c <= 1
478-
return (ag[e][1], RootedAliasTree(iag, e)), (stage, iterate(it, ns))
478+
return (ag[e][1], RootedAliasTree(iag, e)), (stage,
479+
iterate(neighbors(invag, root), ns))
479480
end
480481

481482
count_nonzeros(a::AbstractArray) = count(!iszero, a)
@@ -605,6 +606,15 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
605606
return mm
606607
end
607608

609+
function mark_processed!(processed, var_to_diff, v)
610+
diff_to_var = invview(var_to_diff)
611+
processed[v] = true
612+
while (v = diff_to_var[v]) !== nothing
613+
processed[v] = true
614+
end
615+
return nothing
616+
end
617+
608618
function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
609619
# Step 1: Perform bareiss factorization on the adjacency matrix of the linear
610620
# subsystem of the system we're interested in.
@@ -674,7 +684,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
674684
@assert length(level_to_var) == level
675685
push!(level_to_var, v)
676686
end
677-
processed[v] = true
687+
mark_processed!(processed, var_to_diff, v)
678688
current_coeff_level[] = (coeff, level + 1)
679689
end
680690
end
@@ -684,7 +694,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
684694
max_lv = max(max_lv, lv)
685695
v = nodevalue(t)
686696
iszero(v) && continue
687-
processed[v] = true
697+
mark_processed!(processed, var_to_diff, v)
688698
v == r && continue
689699
if lv < length(level_to_var)
690700
if level_to_var[lv + 1] == v
@@ -730,6 +740,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
730740

731741
if !isempty(irreducibles)
732742
ag = newag
743+
for k in keys(ag)
744+
push!(irreducibles, k)
745+
end
733746
mm_orig2 = isempty(ag) ? mm_orig : reduce!(copy(mm_orig), ag)
734747
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig2, irreducibles)
735748
end

test/reduction.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,17 @@ eqs = [a ~ D(w)
255255
ss = alias_elimination(sys)
256256
@test equations(ss) == [0 ~ D(D(phi)) - a, 0 ~ sin(t) - D(phi)]
257257
@test observed(ss) == [w ~ D(phi)]
258+
259+
@variables t x(t) y(t)
260+
D = Differential(t)
261+
@named sys = ODESystem([D(x) ~ 1 - x,
262+
D(y) + D(x) ~ 0])
263+
new_sys = structural_simplify(sys)
264+
@test equations(new_sys) == [D(x) ~ 1 - x]
265+
@test observed(new_sys) == [D(y) ~ -D(x)]
266+
267+
@named sys = ODESystem([D(x) ~ 1 - x,
268+
y + D(x) ~ 0])
269+
new_sys = structural_simplify(sys)
270+
@test equations(new_sys) == [D(x) ~ 1 - x]
271+
@test observed(new_sys) == [y ~ -D(x)]

0 commit comments

Comments
 (0)