|
| 1 | +# Keeps compatibility with bariess code movoed to Base/stdlib on older releases |
| 2 | + |
| 3 | +using LinearAlgebra |
| 4 | +using SparseArrays |
| 5 | +using SparseArrays: AbstractSparseMatrixCSC |
| 6 | + |
| 7 | +macro swap(a, b) |
| 8 | + esc(:(($a, $b) = ($b, $a))) |
| 9 | +end |
| 10 | + |
| 11 | +function swaprows!(a::AbstractMatrix, i, j) |
| 12 | + i == j && return |
| 13 | + rows = axes(a,1) |
| 14 | + @boundscheck i in rows || throw(BoundsError(a, (:,i))) |
| 15 | + @boundscheck j in rows || throw(BoundsError(a, (:,j))) |
| 16 | + for k in axes(a,2) |
| 17 | + @inbounds a[i,k],a[j,k] = a[j,k],a[i,k] |
| 18 | + end |
| 19 | +end |
| 20 | +function Base.circshift!(a::AbstractVector, shift::Integer) |
| 21 | + n = length(a) |
| 22 | + n == 0 && return |
| 23 | + shift = mod(shift, n) |
| 24 | + shift == 0 && return |
| 25 | + reverse!(a, 1, shift) |
| 26 | + reverse!(a, shift+1, length(a)) |
| 27 | + reverse!(a) |
| 28 | + return a |
| 29 | +end |
| 30 | +function Base.swapcols!(A::AbstractSparseMatrixCSC, i, j) |
| 31 | + i == j && return |
| 32 | + |
| 33 | + # For simplicitly, let i denote the smaller of the two columns |
| 34 | + j < i && @swap(i, j) |
| 35 | + |
| 36 | + colptr = getcolptr(A) |
| 37 | + irow = colptr[i]:(colptr[i+1]-1) |
| 38 | + jrow = colptr[j]:(colptr[j+1]-1) |
| 39 | + |
| 40 | + function rangeexchange!(arr, irow, jrow) |
| 41 | + if length(irow) == length(jrow) |
| 42 | + for (a, b) in zip(irow, jrow) |
| 43 | + @inbounds @swap(arr[i], arr[j]) |
| 44 | + end |
| 45 | + return |
| 46 | + end |
| 47 | + # This is similar to the triple-reverse tricks for |
| 48 | + # circshift!, except that we have three ranges here, |
| 49 | + # so it ends up being 4 reverse calls (but still |
| 50 | + # 2 overall reversals for the memory range). Like |
| 51 | + # circshift!, there's also a cycle chasing algorithm |
| 52 | + # with optimal memory complexity, but the performance |
| 53 | + # tradeoffs against this implementation are non-trivial, |
| 54 | + # so let's just do this simple thing for now. |
| 55 | + # See https://github.com/JuliaLang/julia/pull/42676 for |
| 56 | + # discussion of circshift!-like algorithms. |
| 57 | + reverse!(@view arr[irow]) |
| 58 | + reverse!(@view arr[jrow]) |
| 59 | + reverse!(@view arr[(last(irow)+1):(first(jrow)-1)]) |
| 60 | + reverse!(@view arr[first(irow):last(jrow)]) |
| 61 | + end |
| 62 | + rangeexchange!(rowvals(A), irow, jrow) |
| 63 | + rangeexchange!(nonzeros(A), irow, jrow) |
| 64 | + |
| 65 | + if length(irow) != length(jrow) |
| 66 | + @inbounds colptr[i+1:j] .+= length(jrow) - length(irow) |
| 67 | + end |
| 68 | + return nothing |
| 69 | +end |
| 70 | +function swaprows!(A::AbstractSparseMatrixCSC, i, j) |
| 71 | + # For simplicitly, let i denote the smaller of the two rows |
| 72 | + j < i && @swap(i, j) |
| 73 | + |
| 74 | + rows = rowvals(A) |
| 75 | + vals = nonzeros(A) |
| 76 | + for col = 1:size(A, 2) |
| 77 | + rr = nzrange(A, col) |
| 78 | + iidx = searchsortedfirst(@view(rows[rr]), i) |
| 79 | + has_i = iidx <= length(rr) && rows[rr[iidx]] == i |
| 80 | + |
| 81 | + jrange = has_i ? (iidx:last(rr)) : rr |
| 82 | + jidx = searchsortedlast(@view(rows[jrange]), j) |
| 83 | + has_j = jidx != 0 && rows[jrange[jidx]] == j |
| 84 | + |
| 85 | + if !has_j && !has_i |
| 86 | + # Has neither row - nothing to do |
| 87 | + continue |
| 88 | + elseif has_i && has_j |
| 89 | + # This column had both i and j rows - swap them |
| 90 | + @swap(vals[rr[iidx]], vals[jrange[jidx]]) |
| 91 | + elseif has_i |
| 92 | + # Update the rowval and then rotate both nonzeros |
| 93 | + # and the remaining rowvals into the correct place |
| 94 | + rows[rr[iidx]] = j |
| 95 | + jidx == 0 && continue |
| 96 | + rotate_range = rr[iidx]:jrange[jidx] |
| 97 | + circshift!(@view(vals[rotate_range]), -1) |
| 98 | + circshift!(@view(rows[rotate_range]), -1) |
| 99 | + else |
| 100 | + # Same as i, but in the opposite direction |
| 101 | + @assert has_j |
| 102 | + rows[jrange[jidx]] = i |
| 103 | + iidx > length(rr) && continue |
| 104 | + rotate_range = rr[iidx]:jrange[jidx] |
| 105 | + circshift!(@view(vals[rotate_range]), 1) |
| 106 | + circshift!(@view(rows[rotate_range]), 1) |
| 107 | + end |
| 108 | + end |
| 109 | + return nothing |
| 110 | +end |
| 111 | + |
| 112 | +function bareiss_update!(zero!, M::StridedMatrix, k, swapto, pivot, prev_pivot) |
| 113 | + for i in k+1:size(M, 2), j in k+1:size(M, 1) |
| 114 | + M[j,i] = exactdiv(M[j,i]*pivot - M[j,k]*M[k,i], prev_pivot) |
| 115 | + end |
| 116 | + zero!(M, k+1:size(M, 1), k) |
| 117 | +end |
| 118 | + |
| 119 | +@views function bareiss_update!(zero!, M::AbstractMatrix, k, swapto, pivot, prev_pivot) |
| 120 | + V = M[k+1:end, k+1:end] |
| 121 | + V .= exactdiv.(V .* pivot .- M[k+1:end, k] * M[k, k+1:end]', prev_pivot) |
| 122 | + zero!(M, k+1:size(M, 1), k) |
| 123 | +end |
| 124 | + |
| 125 | +function bareiss_update_virtual_colswap!(zero!, M::AbstractMatrix, k, swapto, pivot, prev_pivot) |
| 126 | + V = @view M[k+1:end, :] |
| 127 | + V .= @views exactdiv.(V .* pivot .- M[k+1:end, swapto[2]] * M[k, :]', prev_pivot) |
| 128 | + zero!(M, k+1:size(M, 1), swapto[2]) |
| 129 | +end |
| 130 | + |
| 131 | +bareiss_zero!(M, i, j) = M[i,j] .= zero(eltype(M)) |
| 132 | + |
| 133 | +function find_pivot_col(M, i) |
| 134 | + p = findfirst(!iszero, @view M[i,i:end]) |
| 135 | + p === nothing && return nothing |
| 136 | + idx = CartesianIndex(i, p + i - 1) |
| 137 | + (idx, M[idx]) |
| 138 | +end |
| 139 | + |
| 140 | +function find_pivot_any(M, i) |
| 141 | + p = findfirst(!iszero, @view M[i:end,i:end]) |
| 142 | + p === nothing && return nothing |
| 143 | + idx = p + CartesianIndex(i - 1, i - 1) |
| 144 | + (idx, M[idx]) |
| 145 | +end |
| 146 | + |
| 147 | +const bareiss_colswap = (Base.swapcols!, swaprows!, bareiss_update!, bareiss_zero!) |
| 148 | +const bareiss_virtcolswap = ((M,i,j)->nothing, swaprows!, bareiss_update_virtual_colswap!, bareiss_zero!) |
| 149 | + |
| 150 | +""" |
| 151 | + bareiss!(M, [swap_strategy]) |
| 152 | +
|
| 153 | +Perform Bareiss's fraction-free row-reduction algorithm on the matrix `M`. |
| 154 | +Optionally, a specific pivoting method may be specified. |
| 155 | +
|
| 156 | +swap_strategy is an optional argument that determines how the swapping of rows and coulmns is performed. |
| 157 | +bareiss_colswap (the default) swaps the columns and rows normally. |
| 158 | +bareiss_virtcolswap pretends to swap the columns which can be faster for sparse matrices. |
| 159 | +""" |
| 160 | +function bareiss!(M::AbstractMatrix, swap_strategy=bareiss_colswap; |
| 161 | + find_pivot=find_pivot_any) |
| 162 | + swapcols!, swaprows!, update!, zero! = swap_strategy; |
| 163 | + prev = one(eltype(M)) |
| 164 | + n = size(M, 1) |
| 165 | + for k in 1:n |
| 166 | + r = find_pivot(M, k) |
| 167 | + r === nothing && return k - 1 |
| 168 | + (swapto, pivot) = r |
| 169 | + if CartesianIndex(k, k) != swapto |
| 170 | + swapcols!(M, k, swapto[2]) |
| 171 | + swaprows!(M, k, swapto[1]) |
| 172 | + end |
| 173 | + update!(zero!, M, k, swapto, pivot, prev) |
| 174 | + prev = pivot |
| 175 | + end |
| 176 | + return n |
| 177 | +end |
0 commit comments