Skip to content

Commit 25099f8

Browse files
author
Oscar Smith
authored
Fix Barreiss algorithm (#2089)
This has a number of changes. The first is that we were using ÷ instead of exact_div in bareiss_update_virtual_colswap_mtk which led to undetected overflows. We also had a number of bugs that arise when using BigInt precsion in bareiss (e.g. checking `!==0` vs `!=0`. This PR makes it so we first try to bareiss with the regular Int64 matrix, but if that gets an overflow, we fallback to the BigInt version.
1 parent 2efd9e1 commit 25099f8

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

src/systems/alias_elimination.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
using SymbolicUtils: Rewriters
22
using Graphs.Experimental.Traversals
33

4-
const KEEP = typemin(Int)
5-
64
function alias_eliminate_graph!(state::TransformationState; kwargs...)
75
mm = linear_subsys_adjmat!(state; kwargs...)
86
if size(mm, 1) == 0
@@ -225,8 +223,9 @@ the `constraint`.
225223
vertices = eadj[i]
226224
if constraint(length(vertices))
227225
for (j, v) in enumerate(vertices)
228-
(mask === nothing || mask[v]) &&
226+
if (mask === nothing || mask[v])
229227
return (CartesianIndex(i, v), M.row_vals[i][j])
228+
end
230229
end
231230
end
232231
end
@@ -241,7 +240,6 @@ end
241240
row = @view M[i, :]
242241
if constraint(count(!iszero, row))
243242
for (v, val) in enumerate(row)
244-
iszero(val) && continue
245243
if mask === nothing || mask[v]
246244
return CartesianIndex(i, v), val
247245
end
@@ -325,7 +323,8 @@ function Base.setindex!(ag::AliasGraph, v::Integer, i::Integer)
325323
return 0 => 0
326324
end
327325

328-
function Base.setindex!(ag::AliasGraph, p::Union{Pair{Int, Int}, Tuple{Int, Int}},
326+
function Base.setindex!(ag::AliasGraph,
327+
p::Union{Pair{<:Integer, Int}, Tuple{<:Integer, Int}},
329328
i::Integer)
330329
(c, v) = p
331330
if c == 0 || v == 0
@@ -530,7 +529,7 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
530529
return linear_variables
531530
end
532531

533-
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
532+
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
534533
mm = copy(mm_orig)
535534
linear_equations_set = BitSet(mm_orig.nzrows)
536535

@@ -554,7 +553,16 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
554553
end
555554
solvable_variables = findall(is_linear_variables)
556555

557-
return mm, solvable_variables, do_bareiss!(mm, mm_orig, is_linear_variables)
556+
local bar
557+
try
558+
bar = do_bareiss!(mm, mm_orig, is_linear_variables)
559+
catch e
560+
e isa OverflowError || rethrow(e)
561+
mm = convert(SparseMatrixCLIL{BigInt, Ti}, mm_orig)
562+
bar = do_bareiss!(mm, mm_orig, is_linear_variables)
563+
end
564+
565+
return mm, solvable_variables, bar
558566
end
559567

560568
function do_bareiss!(M, Mold, is_linear_variables)
@@ -589,6 +597,7 @@ function do_bareiss!(M, Mold, is_linear_variables)
589597
end
590598
bareiss_ops = ((M, i, j) -> nothing, myswaprows!,
591599
bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
600+
592601
rank2, = bareiss!(M, bareiss_ops; find_pivot = find_and_record_pivot)
593602
rank1 = something(rank1r[], rank2)
594603
(rank1, rank2, pivots)
@@ -983,7 +992,7 @@ function locally_structure_simplify!(adj_row, pivot_var, ag)
983992
nirreducible = 0
984993
# When this row only as the pivot element, the pivot is zero by homogeneity
985994
# of the linear system.
986-
alias_candidate::Union{Int, Pair{Int, Int}} = 0
995+
alias_candidate::Union{Int, Pair{eltype(adj_row), Int}} = 0
987996

988997
# N.B.: Assumes that the non-zeros iterator is robust to modification
989998
# of the underlying array datastructure.

src/systems/sparsematrixclil.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,20 @@ function Base.copy(S::SparseMatrixCLIL{T, Ti}) where {T, Ti}
2929
map(copy, S.row_vals))
3030
end
3131
function swaprows!(S::SparseMatrixCLIL, i, j)
32+
i == j && return
3233
swap!(S.nzrows, i, j)
3334
swap!(S.row_cols, i, j)
3435
swap!(S.row_vals, i, j)
3536
end
3637

38+
function Base.convert(::Type{SparseMatrixCLIL{T, Ti}}, S::SparseMatrixCLIL) where {T, Ti}
39+
return SparseMatrixCLIL(S.nparentrows,
40+
S.ncols,
41+
copy.(S.nzrows),
42+
copy.(S.row_cols),
43+
[T.(row) for row in S.row_vals])
44+
end
45+
3746
function SparseMatrixCLIL(mm::AbstractMatrix)
3847
nrows, ncols = size(mm)
3948
row_cols = [findall(!iszero, row) for row in eachrow(mm)]
@@ -59,6 +68,7 @@ function Base.setindex!(S::SparseMatrixCLIL, v::CLILVector, i::Integer, c::Colon
5968
if v.vec.n != S.ncols
6069
throw(BoundsError(v, 1:(S.ncols)))
6170
end
71+
any(iszero, v.vec.nzval) && error("setindex failed")
6272
S.row_cols[i] = copy(v.vec.nzind)
6373
S.row_vals[i] = copy(v.vec.nzval)
6474
return v
@@ -91,7 +101,7 @@ function Base.iterate(nzp::NonZerosPairs{<:CLILVector}, (idx, col))
91101
idx = length(col)
92102
end
93103
oldcol = nzind[idx]
94-
if col !== oldcol
104+
if col != oldcol
95105
# The vector was changed since the last iteration. Find our
96106
# place in the vector again.
97107
tail = col > oldcol ? (@view nzind[(idx + 1):end]) : (@view nzind[1:idx])
@@ -189,13 +199,14 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
189199
v == vpivot && continue
190200
ck = getcoeff(kvars, kcoeffs, v)
191201
ci = getcoeff(ivars, icoeffs, v)
192-
ci = (pivot * ci - coeff * ck) ÷ last_pivot
193-
if ci !== 0
202+
p1 = Base.Checked.checked_mul(pivot, ci)
203+
p2 = Base.Checked.checked_mul(coeff, ck)
204+
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
205+
if !iszero(ci)
194206
push!(tmp_incidence, v)
195207
push!(tmp_coeffs, ci)
196208
end
197209
end
198-
199210
eadj[ei] = tmp_incidence
200211
old_cadj[ei] = tmp_coeffs
201212
end

0 commit comments

Comments
 (0)