|
2 | 2 |
|
3 | 3 | using LinearAlgebra
|
4 | 4 | using SparseArrays
|
5 |
| -using SparseArrays: AbstractSparseMatrixCSC |
| 5 | +using SparseArrays: AbstractSparseMatrixCSC, getcolptr |
6 | 6 |
|
7 | 7 | macro swap(a, b)
|
8 | 8 | esc(:(($a, $b) = ($b, $a)))
|
9 | 9 | end
|
10 | 10 |
|
11 |
| -function swaprows!(a::AbstractMatrix, i, j) |
12 |
| - i == j && return |
13 |
| - rows = axes(a,1) |
14 |
| - @boundscheck i in rows || throw(BoundsError(a, (:,i))) |
15 |
| - @boundscheck j in rows || throw(BoundsError(a, (:,j))) |
16 |
| - for k in axes(a,2) |
17 |
| - @inbounds a[i,k],a[j,k] = a[j,k],a[i,k] |
| 11 | +# https://github.com/JuliaLang/julia/pull/42678 |
| 12 | +@static if VERSION > v"1.8.0-DEV.762" |
| 13 | + import Base: swaprows! |
| 14 | +else |
| 15 | + function swaprows!(a::AbstractMatrix, i, j) |
| 16 | + i == j && return |
| 17 | + rows = axes(a,1) |
| 18 | + @boundscheck i in rows || throw(BoundsError(a, (:,i))) |
| 19 | + @boundscheck j in rows || throw(BoundsError(a, (:,j))) |
| 20 | + for k in axes(a,2) |
| 21 | + @inbounds a[i,k],a[j,k] = a[j,k],a[i,k] |
| 22 | + end |
18 | 23 | end
|
19 |
| -end |
20 |
| -function Base.circshift!(a::AbstractVector, shift::Integer) |
21 |
| - n = length(a) |
22 |
| - n == 0 && return |
23 |
| - shift = mod(shift, n) |
24 |
| - shift == 0 && return |
25 |
| - reverse!(a, 1, shift) |
26 |
| - reverse!(a, shift+1, length(a)) |
27 |
| - reverse!(a) |
28 |
| - return a |
29 |
| -end |
30 |
| -function Base.swapcols!(A::AbstractSparseMatrixCSC, i, j) |
31 |
| - i == j && return |
32 |
| - |
33 |
| - # For simplicitly, let i denote the smaller of the two columns |
34 |
| - j < i && @swap(i, j) |
35 |
| - |
36 |
| - colptr = getcolptr(A) |
37 |
| - irow = colptr[i]:(colptr[i+1]-1) |
38 |
| - jrow = colptr[j]:(colptr[j+1]-1) |
39 |
| - |
40 |
| - function rangeexchange!(arr, irow, jrow) |
41 |
| - if length(irow) == length(jrow) |
42 |
| - for (a, b) in zip(irow, jrow) |
43 |
| - @inbounds @swap(arr[i], arr[j]) |
| 24 | + function Base.circshift!(a::AbstractVector, shift::Integer) |
| 25 | + n = length(a) |
| 26 | + n == 0 && return |
| 27 | + shift = mod(shift, n) |
| 28 | + shift == 0 && return |
| 29 | + reverse!(a, 1, shift) |
| 30 | + reverse!(a, shift+1, length(a)) |
| 31 | + reverse!(a) |
| 32 | + return a |
| 33 | + end |
| 34 | + function Base.swapcols!(A::AbstractSparseMatrixCSC, i, j) |
| 35 | + i == j && return |
| 36 | + |
| 37 | + # For simplicitly, let i denote the smaller of the two columns |
| 38 | + j < i && @swap(i, j) |
| 39 | + |
| 40 | + colptr = getcolptr(A) |
| 41 | + irow = colptr[i]:(colptr[i+1]-1) |
| 42 | + jrow = colptr[j]:(colptr[j+1]-1) |
| 43 | + |
| 44 | + function rangeexchange!(arr, irow, jrow) |
| 45 | + if length(irow) == length(jrow) |
| 46 | + for (a, b) in zip(irow, jrow) |
| 47 | + @inbounds @swap(arr[i], arr[j]) |
| 48 | + end |
| 49 | + return |
44 | 50 | end
|
45 |
| - return |
| 51 | + # This is similar to the triple-reverse tricks for |
| 52 | + # circshift!, except that we have three ranges here, |
| 53 | + # so it ends up being 4 reverse calls (but still |
| 54 | + # 2 overall reversals for the memory range). Like |
| 55 | + # circshift!, there's also a cycle chasing algorithm |
| 56 | + # with optimal memory complexity, but the performance |
| 57 | + # tradeoffs against this implementation are non-trivial, |
| 58 | + # so let's just do this simple thing for now. |
| 59 | + # See https://github.com/JuliaLang/julia/pull/42676 for |
| 60 | + # discussion of circshift!-like algorithms. |
| 61 | + reverse!(@view arr[irow]) |
| 62 | + reverse!(@view arr[jrow]) |
| 63 | + reverse!(@view arr[(last(irow)+1):(first(jrow)-1)]) |
| 64 | + reverse!(@view arr[first(irow):last(jrow)]) |
46 | 65 | end
|
47 |
| - # This is similar to the triple-reverse tricks for |
48 |
| - # circshift!, except that we have three ranges here, |
49 |
| - # so it ends up being 4 reverse calls (but still |
50 |
| - # 2 overall reversals for the memory range). Like |
51 |
| - # circshift!, there's also a cycle chasing algorithm |
52 |
| - # with optimal memory complexity, but the performance |
53 |
| - # tradeoffs against this implementation are non-trivial, |
54 |
| - # so let's just do this simple thing for now. |
55 |
| - # See https://github.com/JuliaLang/julia/pull/42676 for |
56 |
| - # discussion of circshift!-like algorithms. |
57 |
| - reverse!(@view arr[irow]) |
58 |
| - reverse!(@view arr[jrow]) |
59 |
| - reverse!(@view arr[(last(irow)+1):(first(jrow)-1)]) |
60 |
| - reverse!(@view arr[first(irow):last(jrow)]) |
61 |
| - end |
62 |
| - rangeexchange!(rowvals(A), irow, jrow) |
63 |
| - rangeexchange!(nonzeros(A), irow, jrow) |
| 66 | + rangeexchange!(rowvals(A), irow, jrow) |
| 67 | + rangeexchange!(nonzeros(A), irow, jrow) |
64 | 68 |
|
65 |
| - if length(irow) != length(jrow) |
66 |
| - @inbounds colptr[i+1:j] .+= length(jrow) - length(irow) |
| 69 | + if length(irow) != length(jrow) |
| 70 | + @inbounds colptr[i+1:j] .+= length(jrow) - length(irow) |
| 71 | + end |
| 72 | + return nothing |
67 | 73 | end
|
68 |
| - return nothing |
69 |
| -end |
70 |
| -function swaprows!(A::AbstractSparseMatrixCSC, i, j) |
71 |
| - # For simplicitly, let i denote the smaller of the two rows |
72 |
| - j < i && @swap(i, j) |
73 |
| - |
74 |
| - rows = rowvals(A) |
75 |
| - vals = nonzeros(A) |
76 |
| - for col = 1:size(A, 2) |
77 |
| - rr = nzrange(A, col) |
78 |
| - iidx = searchsortedfirst(@view(rows[rr]), i) |
79 |
| - has_i = iidx <= length(rr) && rows[rr[iidx]] == i |
80 |
| - |
81 |
| - jrange = has_i ? (iidx:last(rr)) : rr |
82 |
| - jidx = searchsortedlast(@view(rows[jrange]), j) |
83 |
| - has_j = jidx != 0 && rows[jrange[jidx]] == j |
84 |
| - |
85 |
| - if !has_j && !has_i |
86 |
| - # Has neither row - nothing to do |
87 |
| - continue |
88 |
| - elseif has_i && has_j |
89 |
| - # This column had both i and j rows - swap them |
90 |
| - @swap(vals[rr[iidx]], vals[jrange[jidx]]) |
91 |
| - elseif has_i |
92 |
| - # Update the rowval and then rotate both nonzeros |
93 |
| - # and the remaining rowvals into the correct place |
94 |
| - rows[rr[iidx]] = j |
95 |
| - jidx == 0 && continue |
96 |
| - rotate_range = rr[iidx]:jrange[jidx] |
97 |
| - circshift!(@view(vals[rotate_range]), -1) |
98 |
| - circshift!(@view(rows[rotate_range]), -1) |
99 |
| - else |
100 |
| - # Same as i, but in the opposite direction |
101 |
| - @assert has_j |
102 |
| - rows[jrange[jidx]] = i |
103 |
| - iidx > length(rr) && continue |
104 |
| - rotate_range = rr[iidx]:jrange[jidx] |
105 |
| - circshift!(@view(vals[rotate_range]), 1) |
106 |
| - circshift!(@view(rows[rotate_range]), 1) |
| 74 | + function swaprows!(A::AbstractSparseMatrixCSC, i, j) |
| 75 | + # For simplicitly, let i denote the smaller of the two rows |
| 76 | + j < i && @swap(i, j) |
| 77 | + |
| 78 | + rows = rowvals(A) |
| 79 | + vals = nonzeros(A) |
| 80 | + for col = 1:size(A, 2) |
| 81 | + rr = nzrange(A, col) |
| 82 | + iidx = searchsortedfirst(@view(rows[rr]), i) |
| 83 | + has_i = iidx <= length(rr) && rows[rr[iidx]] == i |
| 84 | + |
| 85 | + jrange = has_i ? (iidx:last(rr)) : rr |
| 86 | + jidx = searchsortedlast(@view(rows[jrange]), j) |
| 87 | + has_j = jidx != 0 && rows[jrange[jidx]] == j |
| 88 | + |
| 89 | + if !has_j && !has_i |
| 90 | + # Has neither row - nothing to do |
| 91 | + continue |
| 92 | + elseif has_i && has_j |
| 93 | + # This column had both i and j rows - swap them |
| 94 | + @swap(vals[rr[iidx]], vals[jrange[jidx]]) |
| 95 | + elseif has_i |
| 96 | + # Update the rowval and then rotate both nonzeros |
| 97 | + # and the remaining rowvals into the correct place |
| 98 | + rows[rr[iidx]] = j |
| 99 | + jidx == 0 && continue |
| 100 | + rotate_range = rr[iidx]:jrange[jidx] |
| 101 | + circshift!(@view(vals[rotate_range]), -1) |
| 102 | + circshift!(@view(rows[rotate_range]), -1) |
| 103 | + else |
| 104 | + # Same as i, but in the opposite direction |
| 105 | + @assert has_j |
| 106 | + rows[jrange[jidx]] = i |
| 107 | + iidx > length(rr) && continue |
| 108 | + rotate_range = rr[iidx]:jrange[jidx] |
| 109 | + circshift!(@view(vals[rotate_range]), 1) |
| 110 | + circshift!(@view(rows[rotate_range]), 1) |
| 111 | + end |
107 | 112 | end
|
| 113 | + return nothing |
108 | 114 | end
|
109 |
| - return nothing |
110 | 115 | end
|
111 | 116 |
|
112 | 117 | function bareiss_update!(zero!, M::StridedMatrix, k, swapto, pivot, prev_pivot)
|
|
0 commit comments