Skip to content

Commit 3e8a6b4

Browse files
committed
Split Bareiss
1 parent 82cac6f commit 3e8a6b4

File tree

1 file changed

+32
-66
lines changed

1 file changed

+32
-66
lines changed

src/systems/alias_elimination.jl

Lines changed: 32 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -420,25 +420,23 @@ count_nonzeros(a::AbstractArray) = count(!iszero, a)
420420
# Here we have a guarantee that they won't, so we can make this identification
421421
count_nonzeros(a::SparseVector) = nnz(a)
422422

423-
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL,
424-
only_linear_algebraic = false, irreducibles = ())
423+
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL, only_algebraic,
424+
irreducibles)
425425
mm = copy(mm_orig)
426426
is_linear_equations = falses(size(AsSubMatrix(mm_orig), 1))
427-
diff_to_var = invview(var_to_diff)
428-
islowest = let diff_to_var = diff_to_var
429-
v -> diff_to_var[v] === nothing
430-
end
431-
for e in mm_orig.nzrows
432-
is_linear_equations[e] = all(islowest, 𝑠neighbors(graph, e))
433-
end
434-
435-
var_to_eq = let is_linear_equations = is_linear_equations, islowest = islowest,
436-
irreducibles = irreducibles
437427

438-
maximal_matching(graph, eq -> is_linear_equations[eq],
439-
var -> islowest(var) && !(var in irreducibles))
428+
is_not_potential_state = isnothing.(var_to_diff)
429+
for v in irreducibles
430+
is_not_potential_state[v] = false
431+
end
432+
is_linear_variables = only_algebraic ? copy(is_not_potential_state) :
433+
is_not_potential_state
434+
for i in 𝑠vertices(graph)
435+
is_linear_equations[i] && continue
436+
for j in 𝑠neighbors(graph, i)
437+
is_linear_variables[j] = false
438+
end
440439
end
441-
is_linear_variables = isa.(var_to_eq, Int)
442440
solvable_variables = findall(is_linear_variables)
443441

444442
function do_bareiss!(M, Mold = nothing)
@@ -450,7 +448,13 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL,
450448
r !== nothing && return r
451449
rank1 = k - 1
452450
end
453-
only_linear_algebraic && return nothing
451+
if only_algebraic
452+
if rank2 === nothing
453+
r = find_masked_pivot(is_not_potential_state, M, k)
454+
r !== nothing && return r
455+
rank2 = k - 1
456+
end
457+
end
454458
# TODO: It would be better to sort the variables by
455459
# derivative order here to enable more elimination
456460
# opportunities.
@@ -468,9 +472,10 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL,
468472
end
469473
bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
470474
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
471-
rank2, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
475+
rank3, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
472476
rank1 = something(rank1, rank2)
473-
(rank1, rank2, pivots)
477+
rank2 = something(rank2, rank3)
478+
(rank1, rank2, rank3, pivots)
474479
end
475480

476481
return mm, solvable_variables, do_bareiss!(mm, mm_orig)
@@ -486,8 +491,7 @@ function lss(mm, pivots, ag)
486491
end
487492
end
488493

489-
function simple_aliases!(ag, graph, var_to_diff, mm_orig, only_linear_algebraic = false,
490-
irreducibles = ())
494+
function simple_aliases!(ag, graph, var_to_diff, mm_orig, only_algebraic, irreducibles = ())
491495
# Let `m = the number of linear equations` and `n = the number of
492496
# variables`.
493497
#
@@ -503,23 +507,15 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, only_linear_algebraic
503507
# contribute to the equations, but are not solved by the linear system. Note
504508
# that the complete system may be larger than the linear subsystem and
505509
# include variables that do not appear here.
506-
mm, solvable_variables, (rank1, rank2, pivots) = aag_bareiss!(graph, var_to_diff,
507-
mm_orig,
508-
only_linear_algebraic,
509-
irreducibles)
510+
mm, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss!(graph, var_to_diff,
511+
mm_orig,
512+
only_algebraic,
513+
irreducibles)
510514

511515
# Step 2: Simplify the system using the Bareiss factorization
512516
ks = keys(ag)
513-
if !only_linear_algebraic
514-
for v in setdiff(solvable_variables, @view pivots[1:rank1])
515-
ag[v] = 0
516-
end
517-
else
518-
for v in setdiff(solvable_variables, @view pivots[1:rank1])
519-
if !(v in ks)
520-
ag[v] = 0
521-
end
522-
end
517+
for v in setdiff(solvable_variables, @view pivots[1:rank1])
518+
ag[v] = 0
523519
end
524520

525521
lss! = lss(mm, pivots, ag)
@@ -558,7 +554,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
558554
#
559555
nvars = ndsts(graph)
560556
ag = AliasGraph(nvars)
561-
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig)
557+
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig, false)
562558

563559
# Step 3: Handle differentiated variables
564560
# At this point, `var_to_diff` and `ag` form a tree structure like the
@@ -665,39 +661,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL;
665661
end
666662
end
667663

668-
# There might be "cycles" like `D(x) = x`
669-
for v in irreducibles
670-
if v in keys(newag)
671-
newag[v] = nothing
672-
end
673-
if v in keys(ag)
674-
ag[v] = nothing
675-
end
676-
end
677-
for v in keys(ag)
678-
push!(irreducibles, v)
679-
end
680-
681-
for (v, (c, a)) in newag
682-
va = iszero(a) ? a : fullvars[a]
683-
@info "new alias" fullvars[v]=>(c, va)
684-
end
685-
newkeys = keys(newag)
686664
if !isempty(irreducibles)
687-
for (v, (c, a)) in ag
688-
(a in irreducibles || v in irreducibles) && continue
689-
if iszero(c)
690-
newag[v] = c
691-
else
692-
newag[v] = c => a
693-
end
694-
end
695665
ag = newag
696-
697-
# We cannot use the `mm` from Bareiss because it doesn't consider
698-
# irreducibles
699-
mm_orig2 = reduce!(copy(mm_orig), ag)
700-
mm = mm_orig2
666+
mm_orig2 = isempty(ag) ? mm_orig : reduce!(copy(mm_orig), ag)
701667
mm = simple_aliases!(ag, graph, var_to_diff, mm_orig2, true, irreducibles)
702668
end
703669

0 commit comments

Comments
 (0)