Skip to content

Commit fc51f56

Browse files
committed
Set redundant linear algebraic variables to zero and further clean up
1 parent 9594e65 commit fc51f56

File tree

1 file changed

+79
-65
lines changed

1 file changed

+79
-65
lines changed

src/systems/alias_elimination.jl

Lines changed: 79 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -426,98 +426,82 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
426426
return linear_variables
427427
end
428428

429-
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL, irreducibles = ())
429+
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
430430
mm = copy(mm_orig)
431-
linear_equations = mm_orig.nzrows
432-
is_linear_equations = falses(size(AsSubMatrix(mm_orig), 1))
433-
for e in mm_orig.nzrows
434-
is_linear_equations[e] = true
435-
end
431+
linear_equations_set = BitSet(mm_orig.nzrows)
436432

437-
# If linear highest differentiated variables cannot be assigned to a pivot,
438-
# then we can set it to zero. We use `rank1` to track this.
439-
#
440-
# We only use alias graph to record reducible variables. We use `rank2` to
441-
# track this.
433+
# All unassigned (not a pivot) algebraic variables that only appears in
434+
# linear algebraic equations can be set to 0.
442435
#
443436
# For all the other variables, we can update the original system with
444437
# Bareiss'ed coefficients as Gaussian elimination is nullspace perserving
445438
# and we are only working on linear homogeneous subsystem.
446-
is_reducible = trues(length(var_to_diff))
447-
#TODO: what's the correct criterion here?
448-
is_linear_variables = isnothing.(var_to_diff) .& isnothing.(invview(var_to_diff))
439+
440+
is_algebraic = let var_to_diff = var_to_diff
441+
v -> var_to_diff[v] === nothing === invview(var_to_diff)[v]
442+
end
443+
is_linear_variables = is_algebraic.(1:length(var_to_diff))
449444
for i in 𝑠vertices(graph)
450-
is_linear_equations[i] && continue
445+
# only consider linear algebraic equations
446+
(i in linear_equations_set && all(is_algebraic, 𝑠neighbors(graph, i))) && continue
451447
for j in 𝑠neighbors(graph, i)
452448
is_linear_variables[j] = false
453449
end
454450
end
455451
solvable_variables = findall(is_linear_variables)
456452

457-
function do_bareiss!(M, Mold = nothing)
458-
rank1 = rank2 = nothing
459-
pivots = Int[]
460-
function find_pivot(M, k)
461-
if rank1 === nothing
453+
return mm, solvable_variables, do_bareiss!(mm, mm_orig, is_linear_variables)
454+
end
455+
456+
function do_bareiss!(M, Mold, is_linear_variables)
457+
rank1r = Ref{Union{Nothing, Int}}(nothing)
458+
find_pivot = let rank1r = rank1r
459+
(M, k) -> begin
460+
if rank1r[] === nothing
462461
r = find_masked_pivot(is_linear_variables, M, k)
463462
r !== nothing && return r
464-
rank1 = k - 1
465-
end
466-
if rank2 === nothing
467-
r = find_masked_pivot(is_reducible, M, k)
468-
r !== nothing && return r
469-
rank2 = k - 1
463+
rank1r[] = k - 1
470464
end
471465
# TODO: It would be better to sort the variables by
472466
# derivative order here to enable more elimination
473467
# opportunities.
474468
return find_masked_pivot(nothing, M, k)
475469
end
476-
function find_and_record_pivot(M, k)
470+
end
471+
pivots = Int[]
472+
find_and_record_pivot = let pivots = pivots
473+
(M, k) -> begin
477474
r = find_pivot(M, k)
478475
r === nothing && return nothing
479476
push!(pivots, r[1][2])
480477
return r
481478
end
482-
function myswaprows!(M, i, j)
479+
end
480+
myswaprows! = let Mold = Mold
481+
(M, i, j) -> begin
483482
Mold !== nothing && swaprows!(Mold, i, j)
484483
swaprows!(M, i, j)
485484
end
486-
bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
487-
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
488-
rank3, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
489-
rank2 = something(rank2, rank3)
490-
rank1 = something(rank1, rank2)
491-
(rank1, rank2, rank3, pivots)
492485
end
493-
494-
return mm, solvable_variables, do_bareiss!(mm, mm_orig)
486+
bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
487+
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
488+
rank2, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
489+
rank1 = something(rank1r[], rank2)
490+
(rank1, rank2, pivots)
495491
end
496492

497493
# Kind of like the backward substitution, but we don't actually rely on it
498494
# being lower triangular. We eliminate a variable if there are at most 2
499495
# variables left after the substitution.
500-
function lss(mm, pivots, ag)
496+
function lss(mm, ag, pivots)
501497
ei -> let mm = mm, pivots = pivots, ag = ag
502-
vi = pivots[ei]
498+
vi = pivots === nothing ? nothing : pivots[ei]
503499
locally_structure_simplify!((@view mm[ei, :]), vi, ag)
504500
end
505501
end
506502

507-
function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
508-
mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(graph, var_to_diff,
509-
mm_orig,
510-
irreducibles)
511-
512-
# Step 2: Simplify the system using the Bareiss factorization
513-
rk1vars = BitSet(@view pivots[1:rank1])
514-
for v in solvable_variables
515-
v in rk1vars && continue
516-
ag[v] = 0
517-
end
518-
519-
echelon_mm = copy(mm)
520-
lss! = lss(mm, pivots, ag)
503+
function reduce!(mm, mm_orig, ag, rank2, pivots = nothing)
504+
lss! = lss(mm, ag, pivots)
521505
# Step 2.1: Go backwards, collecting eliminated variables and substituting
522506
# alias as we go.
523507
foreach(lss!, reverse(1:rank2))
@@ -543,6 +527,20 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, irreducibles = ())
543527
reduced && while any(lss!, 1:rank2)
544528
end
545529

530+
return mm
531+
end
532+
533+
function simple_aliases!(ag, graph, var_to_diff, mm_orig)
534+
echelon_mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss!(graph, var_to_diff, mm_orig)
535+
536+
# Step 2: Simplify the system using the Bareiss factorization
537+
rk1vars = BitSet(@view pivots[1:rank1])
538+
for v in solvable_variables
539+
v in rk1vars && continue
540+
ag[v] = 0
541+
end
542+
543+
mm = reduce!(copy(echelon_mm), mm_orig, ag, rank2, pivots)
546544
return mm, echelon_mm
547545
end
548546

@@ -587,13 +585,13 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
587585
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
588586
end
589587
diff_aliases = Vector{Pair{Int, Int}}[]
588+
stem = Int[]
589+
stem_set = BitSet()
590590
for (v, dv) in enumerate(var_to_diff)
591591
processed[v] && continue
592592
(dv === nothing && diff_to_var[v] === nothing) && continue
593593
r = find_root!(dls, g, v)
594594
prev_r = -1
595-
stem = Int[]
596-
stem_set = BitSet()
597595
for _ in 1:10_000 # just to make sure that we don't stuck in an infinite loop
598596
reach₌ = Pair{Int, Int}[]
599597
r === nothing || for n in neighbors(eqg, r)
@@ -650,6 +648,21 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
650648
# edges.
651649
weighted_transitiveclosure!(eqg)
652650
# Canonicalize by preferring the lower differentiated variable
651+
# If we have the system
652+
# ```
653+
# D(x) ~ x
654+
# D(x) + D(y) ~ 0
655+
# ```
656+
# preferring the lower variable would lead to
657+
# ```
658+
# D(x) ~ x <== added back because `x := D(x)` removes `D(x)`
659+
# D(y) ~ -x
660+
# ```
661+
# while preferring the higher variable would lead to
662+
# ```
663+
# D(x) + D(y) ~ 0
664+
# ```
665+
# which is not correct.
653666
for i in 1:(length(stem) - 1)
654667
r = stem[i]
655668
for dr in @view stem[(i + 1):end]
@@ -712,6 +725,8 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
712725
end
713726
empty!(dls.visited)
714727
empty!(diff_aliases)
728+
empty!(stem)
729+
empty!(stem_set)
715730
end
716731
# update `dag`
717732
for k in keys(dag)
@@ -720,7 +735,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
720735

721736
# Step 4: Merge dag and ag
722737
removed_aliases = BitSet()
723-
freshag = AliasGraph(nvars)
738+
merged_ag = AliasGraph(nvars)
724739
for (v, (c, a)) in dag
725740
# D(x) ~ D(y) cannot be removed if x and y are not aliases
726741
if v != a && !iszero(a)
@@ -732,23 +747,21 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
732747
vv === nothing && break
733748
if !(haskey(dag, vv) && dag[vv][2] == diff_to_var[aa])
734749
push!(removed_aliases, vv′)
735-
@goto SKIP_FRESHAG
750+
@goto SKIP_merged_ag
736751
end
737752
end
738753
end
739-
freshag[v] = c => a
740-
@label SKIP_FRESHAG
754+
merged_ag[v] = c => a
755+
@label SKIP_merged_ag
741756
push!(removed_aliases, a)
742757
end
743758
for (v, (c, a)) in ag
744759
(processed[v] || (!iszero(a) && processed[a])) && continue
745760
v in removed_aliases && continue
746-
freshag[v] = c => a
747-
end
748-
if freshag != ag
749-
ag = freshag
750-
mm = reduce!(copy(echelon_mm), ag)
761+
merged_ag[v] = c => a
751762
end
763+
ag = merged_ag
764+
mm = reduce!(copy(echelon_mm), mm_orig, ag, size(echelon_mm, 1))
752765

753766
# Step 5: Reflect our update decisions back into the graph, and make sure
754767
# that the RHS of observable variables are defined.
@@ -776,7 +789,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
776789
ag = finalag
777790

778791
if needs_update
779-
mm = reduce!(copy(echelon_mm), ag)
792+
mm = reduce!(copy(echelon_mm), mm_orig, ag, size(echelon_mm, 1))
780793
for (ei, e) in enumerate(mm.nzrows)
781794
set_neighbors!(graph, e, mm.row_cols[ei])
782795
end
@@ -803,6 +816,7 @@ function exactdiv(a::Integer, b)
803816
end
804817

805818
function locally_structure_simplify!(adj_row, pivot_var, ag)
819+
# If `pivot_var === nothing`, then we only apply `ag` to `adj_row`
806820
if pivot_var === nothing
807821
pivot_val = nothing
808822
else
@@ -857,7 +871,7 @@ function locally_structure_simplify!(adj_row, pivot_var, ag)
857871
else
858872
dropzeros!(adj_row)
859873
end
860-
return true
874+
return false
861875
end
862876
# If there were only one or two terms left in the equation (including the
863877
# pivot variable). We can eliminate the pivot variable. Note that when

0 commit comments

Comments
 (0)