475
475
it === nothing && return nothing
476
476
e, ns = it
477
477
# c * a = b <=> a = c * b when -1 <= c <= 1
478
- return (ag[e][1 ], RootedAliasTree (iag, e)), (stage, iterate (it, ns))
478
+ return (ag[e][1 ], RootedAliasTree (iag, e)), (stage,
479
+ iterate (neighbors (invag, root), ns))
479
480
end
480
481
481
482
count_nonzeros (a:: AbstractArray ) = count (! iszero, a)
@@ -605,6 +606,15 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
605
606
return mm
606
607
end
607
608
609
+ function mark_processed! (processed, var_to_diff, v)
610
+ diff_to_var = invview (var_to_diff)
611
+ processed[v] = true
612
+ while (v = diff_to_var[v]) != = nothing
613
+ processed[v] = true
614
+ end
615
+ return nothing
616
+ end
617
+
608
618
function alias_eliminate_graph! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL )
609
619
# Step 1: Perform bareiss factorization on the adjacency matrix of the linear
610
620
# subsystem of the system we're interested in.
@@ -674,7 +684,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
674
684
@assert length (level_to_var) == level
675
685
push! (level_to_var, v)
676
686
end
677
- processed[v] = true
687
+ mark_processed! ( processed, var_to_diff, v)
678
688
current_coeff_level[] = (coeff, level + 1 )
679
689
end
680
690
end
@@ -684,7 +694,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
684
694
max_lv = max (max_lv, lv)
685
695
v = nodevalue (t)
686
696
iszero (v) && continue
687
- processed[v] = true
697
+ mark_processed! ( processed, var_to_diff, v)
688
698
v == r && continue
689
699
if lv < length (level_to_var)
690
700
if level_to_var[lv + 1 ] == v
@@ -730,6 +740,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
730
740
731
741
if ! isempty (irreducibles)
732
742
ag = newag
743
+ for k in keys (ag)
744
+ push! (irreducibles, k)
745
+ end
733
746
mm_orig2 = isempty (ag) ? mm_orig : reduce! (copy (mm_orig), ag)
734
747
mm = simple_aliases! (ag, graph, var_to_diff, mm_orig2, irreducibles)
735
748
end
0 commit comments