Skip to content

Commit 1660b9b

Browse files
committed
Handle zeroing of variables better
1 parent 4e78e6a commit 1660b9b

File tree

1 file changed

+50
-9
lines changed

1 file changed

+50
-9
lines changed

src/systems/alias_elimination.jl

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ end
1616
# For debug purposes
1717
function aag_bareiss(sys::AbstractSystem)
1818
state = TearingState(sys)
19+
complete!(state.structure)
1920
mm = linear_subsys_adjmat(state)
20-
return aag_bareiss!(state.structure.graph, complete(state.structure.var_to_diff), mm)
21+
return aag_bareiss!(state.structure.graph, state.structure.var_to_diff, mm)
2122
end
2223

2324
function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true);
@@ -441,8 +442,6 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL, irreducible
441442
for v in irreducibles
442443
is_reducible[v] = false
443444
end
444-
# TODO/FIXME: This needs a proper recursion to compute the transitive
445-
# closure.
446445
is_linear_variables = find_linear_variables(graph, linear_equations, var_to_diff,
447446
irreducibles)
448447
solvable_variables = findall(is_linear_variables)
@@ -652,16 +651,51 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
652651
end
653652
end
654653
end
654+
# If a non-differentiated variable equals to 0, then we can eliminate
655+
# the whole differentiation chain. Otherwise, we can have to keep the
656+
# lowest differentiate variable in the differentiation chain.
657+
# E.g.
658+
# ```
659+
# D(x) ~ 0
660+
# D(D(x)) ~ y
661+
# ```
662+
# reduces to
663+
# ```
664+
# D(x) ~ 0
665+
# y := 0
666+
# ```
667+
# but
668+
# ```
669+
# x ~ 0
670+
# D(x) ~ y
671+
# ```
672+
# reduces to
673+
# ```
674+
# x := 0
675+
# y := 0
676+
# ```
677+
zero_vars_set = BitSet()
655678
for v in zero_vars
656679
for a in Iterators.flatten((v, outneighbors(eqg, v)))
657680
while true
658-
dag[a] = 0
659-
da = var_to_diff[a]
660-
da === nothing && break
661-
a = da
681+
push!(zero_vars_set, a)
682+
a = var_to_diff[a]
683+
a === nothing && break
662684
end
663685
end
664686
end
687+
for v in zero_vars_set
688+
while (iv = diff_to_var[v]) in zero_vars_set
689+
v = iv
690+
end
691+
if diff_to_var[v] === nothing # `v` is reducible
692+
dag[v] = 0
693+
end
694+
# reducible after v
695+
while (v = var_to_diff[v]) !== nothing
696+
dag[v] = 0
697+
end
698+
end
665699

666700
# clean up
667701
for v in dls.visited
@@ -715,10 +749,17 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
715749
end
716750
update_graph_neighbors!(graph, ag)
717751
finalag = AliasGraph(nvars)
718-
# RHS must still exist in the system to be valid aliases.
752+
# RHS or its derivaitves must still exist in the system to be valid aliases.
719753
needs_update = false
754+
function contains_v_or_dv(var_to_diff, graph, v)
755+
while true
756+
isempty(𝑑neighbors(graph, v)) || return true
757+
v = var_to_diff[v]
758+
v === nothing && return false
759+
end
760+
end
720761
for (v, (c, a)) in ag
721-
if iszero(a) || !isempty(𝑑neighbors(graph, a))
762+
if iszero(a) || contains_v_or_dv(var_to_diff, graph, a)
722763
finalag[v] = c => a
723764
else
724765
needs_update = true

0 commit comments

Comments
 (0)