Skip to content

Commit 5aeb26c

Browse files
committed
More robust aliasing check and clean up
1 parent 4a4a059 commit 5aeb26c

File tree

3 files changed

+38
-46
lines changed

3 files changed

+38
-46
lines changed

src/bipartite_graph.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,9 @@ function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors)
435435
for n in old_neighbors
436436
@inbounds list = g.badjlist[n]
437437
index = searchsortedfirst(list, i)
438-
deleteat!(list, index)
438+
if 1 <= index <= length(list) && list[index] == i
439+
deleteat!(list, index)
440+
end
439441
end
440442
for n in new_neighbors
441443
@inbounds list = g.badjlist[n]

src/systems/alias_elimination.jl

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ end
3434

3535
alias_elimination(sys) = alias_elimination!(TearingState(sys; quick_cancel = true))
3636
function alias_elimination!(state::TearingState)
37-
Main._state[] = state
3837
sys = state.sys
3938
complete!(state.structure)
4039
ag, mm, updated_diff_vars = alias_eliminate_graph!(state)
@@ -510,7 +509,8 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
510509
linear_variables = falses(length(var_to_diff))
511510
var_to_lineq = Dict{Int, BitSet}()
512511
mark_not_linear! = let linear_variables = linear_variables, stack = stack,
513-
var_to_lineq = var_to_lineq
512+
var_to_lineq = var_to_lineq
513+
514514
v -> begin
515515
linear_variables[v] = false
516516
push!(stack, v)
@@ -529,7 +529,7 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
529529
end
530530
for eq in linear_equations, v in 𝑠neighbors(graph, eq)
531531
linear_variables[v] = true
532-
vlineqs = get!(()->BitSet(), var_to_lineq, v)
532+
vlineqs = get!(() -> BitSet(), var_to_lineq, v)
533533
push!(vlineqs, eq)
534534
end
535535
for v in irreducibles
@@ -586,7 +586,8 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL, irreducible
586586
end
587587
# TODO/FIXME: This needs a proper recursion to compute the transitive
588588
# closure.
589-
is_linear_variables = find_linear_variables(graph, linear_equations, var_to_diff, irreducibles)
589+
is_linear_variables = find_linear_variables(graph, linear_equations, var_to_diff,
590+
irreducibles)
590591
solvable_variables = findall(is_linear_variables)
591592

592593
function do_bareiss!(M, Mold = nothing)
@@ -646,11 +647,6 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
646647

647648
# Step 2: Simplify the system using the Bareiss factorization
648649
rk1vars = BitSet(@view pivots[1:rank1])
649-
fullvars = Main._state[].fullvars
650-
@info "" mm_orig.nzrows mm_orig
651-
@show fullvars
652-
@show fullvars[pivots[1:rank1]]
653-
@show fullvars[solvable_variables]
654650
for v in solvable_variables
655651
v in rk1vars && continue
656652
ag[v] = 0
@@ -702,13 +698,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
702698
nvars = ndsts(graph)
703699
ag = AliasGraph(nvars)
704700
mm, echelon_mm = simple_aliases!(ag, graph, var_to_diff, mm_orig)
705-
state = Main._state[]
706-
fullvars = state.fullvars
707-
for (v, (c, a)) in ag
708-
a = a == 0 ? 0 : c * fullvars[a]
709-
v = fullvars[v]
710-
@info "simple alias" v => a
711-
end
712701

713702
# Step 3: Handle differentiated variables
714703
# At this point, `var_to_diff` and `ag` form a tree structure like the
@@ -745,15 +734,14 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
745734
newinvag = SimpleDiGraph(nvars)
746735
removed_aliases = BitSet()
747736
updated_diff_vars = Int[]
748-
irreducibles = Int[]
749737
for (v, dv) in enumerate(var_to_diff)
750738
processed[v] && continue
751739
(dv === nothing && diff_to_var[v] === nothing) && continue
752740

753741
r, _ = find_root!(iag, v)
754-
sv = fullvars[v]
755-
root = fullvars[r]
756-
@info "Found root $r" sv=>root
742+
# sv = fullvars[v]
743+
# root = fullvars[r]
744+
# @info "Found root $r" sv=>root
757745
level_to_var = Int[]
758746
extreme_var(var_to_diff, r, nothing, Val(false),
759747
callback = Base.Fix1(push!, level_to_var))
@@ -799,7 +787,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
799787
extreme_var(var_to_diff, v, nothing, Val(false), callback = add_alias!)
800788
end
801789

802-
@show processed
803790
len = length(level_to_var)
804791
set_v_zero! = let dag = dag
805792
v -> dag[v] = 0
@@ -820,14 +807,8 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
820807
# zero. Irreducible variables are highest differentiated variables (with
821808
# order >= 1) that are not zero.
822809
if zero_av_idx > 0
823-
extreme_var(var_to_diff, level_to_var[zero_av_idx], nothing, Val(false), callback = set_v_zero!)
824-
if zero_av_idx > 2
825-
@warn "1"
826-
push!(irreducibles, level_to_var[zero_av_idx - 1])
827-
end
828-
elseif len >= 2
829-
@warn "2"
830-
push!(irreducibles, level_to_var[len])
810+
extreme_var(var_to_diff, level_to_var[zero_av_idx], nothing, Val(false),
811+
callback = set_v_zero!)
831812
end
832813
# Handle virtual variables
833814
if nlevels < len
@@ -839,17 +820,11 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
839820
end
840821
end
841822

842-
# Merge dag and ag
823+
# Step 4: Merge dag and ag
843824
freshag = AliasGraph(nvars)
844-
@show irreducibles
845-
@show dag
846825
for (v, (c, a)) in dag
847-
# TODO: make sure that `irreducibles` are
848826
# D(x) ~ D(y) cannot be removed if x and y are not aliases
849-
if v != a && a in irreducibles
850-
push!(removed_aliases, v)
851-
@goto NEXT_ITER
852-
elseif v != a && !iszero(a)
827+
if v != a && !iszero(a)
853828
vv = v
854829
aa = a
855830
while true
@@ -873,16 +848,31 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
873848
ag = freshag
874849
mm = reduce!(copy(echelon_mm), ag)
875850
end
876-
@info "" echelon_mm mm
877851

852+
# Step 5: Reflect our update decisions back into the graph, and make sure
853+
# that the RHS of observable variables are defined.
854+
for (ei, e) in enumerate(mm.nzrows)
855+
set_neighbors!(graph, e, mm.row_cols[ei])
856+
end
857+
update_graph_neighbors!(graph, ag)
858+
finalag = AliasGraph(nvars)
859+
# RHS must still exist in the system to be valid aliases.
860+
needs_update = false
878861
for (v, (c, a)) in ag
879-
va = iszero(a) ? a : fullvars[a]
880-
@info "new alias" fullvars[v]=>(c, va)
862+
if iszero(a) || !isempty(𝑑neighbors(graph, a))
863+
finalag[v] = c => a
864+
else
865+
needs_update = true
866+
end
881867
end
868+
ag = finalag
882869

883-
# Step 4: Reflect our update decisions back into the graph
884-
for (ei, e) in enumerate(mm.nzrows)
885-
set_neighbors!(graph, e, mm.row_cols[ei])
870+
if needs_update
871+
mm = reduce!(copy(echelon_mm), ag)
872+
for (ei, e) in enumerate(mm.nzrows)
873+
set_neighbors!(graph, e, mm.row_cols[ei])
874+
end
875+
update_graph_neighbors!(graph, ag)
886876
end
887877

888878
return ag, mm, updated_diff_vars

test/reduction.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,11 @@ eqs = [x ~ 0
286286
@named sys = ODESystem(eqs, t, [x, y, a, b], [])
287287
ss = alias_elimination(sys)
288288
@test equations(ss) == [0 ~ b - a]
289-
@test sort(observed(ss), by=string) == ([D(x), x, y] .~ 0)
289+
@test sort(observed(ss), by = string) == ([D(x), x, y] .~ 0)
290290

291291
eqs = [x ~ 0
292292
D(x) ~ x + y]
293293
@named sys = ODESystem(eqs, t, [x, y], [])
294294
ss = alias_elimination(sys)
295295
@test isempty(equations(ss))
296-
@test sort(observed(ss), by=string) == ([D(x), x, y] .~ 0)
296+
@test sort(observed(ss), by = string) == ([D(x), x, y] .~ 0)

0 commit comments

Comments
 (0)