Skip to content

Commit debb50c

Browse files
oscardssmithoscarddssmith
andauthored
move bareiss (#1449)
Co-authored-by: oscarddssmith <[email protected]>
1 parent beb76c0 commit debb50c

File tree

3 files changed

+178
-182
lines changed

3 files changed

+178
-182
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ include("utils.jl")
117117
include("domains.jl")
118118

119119
# Code that should eventually go elsewhere, but is here for fow
120-
include("compat/bareiss.jl")
120+
include("structural_transformation/bareiss.jl")
121121
if !isdefined(Graphs, :IncrementalCycleTracker)
122122
include("compat/incremental_cycles.jl")
123123
end

src/compat/bareiss.jl

Lines changed: 0 additions & 181 deletions
This file was deleted.
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Keeps compatibility with bariess code movoed to Base/stdlib on older releases
2+
3+
using LinearAlgebra
4+
using SparseArrays
5+
using SparseArrays: AbstractSparseMatrixCSC
6+
7+
macro swap(a, b)
8+
esc(:(($a, $b) = ($b, $a)))
9+
end
10+
11+
function swaprows!(a::AbstractMatrix, i, j)
12+
i == j && return
13+
rows = axes(a,1)
14+
@boundscheck i in rows || throw(BoundsError(a, (:,i)))
15+
@boundscheck j in rows || throw(BoundsError(a, (:,j)))
16+
for k in axes(a,2)
17+
@inbounds a[i,k],a[j,k] = a[j,k],a[i,k]
18+
end
19+
end
20+
function Base.circshift!(a::AbstractVector, shift::Integer)
21+
n = length(a)
22+
n == 0 && return
23+
shift = mod(shift, n)
24+
shift == 0 && return
25+
reverse!(a, 1, shift)
26+
reverse!(a, shift+1, length(a))
27+
reverse!(a)
28+
return a
29+
end
30+
function Base.swapcols!(A::AbstractSparseMatrixCSC, i, j)
31+
i == j && return
32+
33+
# For simplicitly, let i denote the smaller of the two columns
34+
j < i && @swap(i, j)
35+
36+
colptr = getcolptr(A)
37+
irow = colptr[i]:(colptr[i+1]-1)
38+
jrow = colptr[j]:(colptr[j+1]-1)
39+
40+
function rangeexchange!(arr, irow, jrow)
41+
if length(irow) == length(jrow)
42+
for (a, b) in zip(irow, jrow)
43+
@inbounds @swap(arr[i], arr[j])
44+
end
45+
return
46+
end
47+
# This is similar to the triple-reverse tricks for
48+
# circshift!, except that we have three ranges here,
49+
# so it ends up being 4 reverse calls (but still
50+
# 2 overall reversals for the memory range). Like
51+
# circshift!, there's also a cycle chasing algorithm
52+
# with optimal memory complexity, but the performance
53+
# tradeoffs against this implementation are non-trivial,
54+
# so let's just do this simple thing for now.
55+
# See https://github.com/JuliaLang/julia/pull/42676 for
56+
# discussion of circshift!-like algorithms.
57+
reverse!(@view arr[irow])
58+
reverse!(@view arr[jrow])
59+
reverse!(@view arr[(last(irow)+1):(first(jrow)-1)])
60+
reverse!(@view arr[first(irow):last(jrow)])
61+
end
62+
rangeexchange!(rowvals(A), irow, jrow)
63+
rangeexchange!(nonzeros(A), irow, jrow)
64+
65+
if length(irow) != length(jrow)
66+
@inbounds colptr[i+1:j] .+= length(jrow) - length(irow)
67+
end
68+
return nothing
69+
end
70+
function swaprows!(A::AbstractSparseMatrixCSC, i, j)
71+
# For simplicitly, let i denote the smaller of the two rows
72+
j < i && @swap(i, j)
73+
74+
rows = rowvals(A)
75+
vals = nonzeros(A)
76+
for col = 1:size(A, 2)
77+
rr = nzrange(A, col)
78+
iidx = searchsortedfirst(@view(rows[rr]), i)
79+
has_i = iidx <= length(rr) && rows[rr[iidx]] == i
80+
81+
jrange = has_i ? (iidx:last(rr)) : rr
82+
jidx = searchsortedlast(@view(rows[jrange]), j)
83+
has_j = jidx != 0 && rows[jrange[jidx]] == j
84+
85+
if !has_j && !has_i
86+
# Has neither row - nothing to do
87+
continue
88+
elseif has_i && has_j
89+
# This column had both i and j rows - swap them
90+
@swap(vals[rr[iidx]], vals[jrange[jidx]])
91+
elseif has_i
92+
# Update the rowval and then rotate both nonzeros
93+
# and the remaining rowvals into the correct place
94+
rows[rr[iidx]] = j
95+
jidx == 0 && continue
96+
rotate_range = rr[iidx]:jrange[jidx]
97+
circshift!(@view(vals[rotate_range]), -1)
98+
circshift!(@view(rows[rotate_range]), -1)
99+
else
100+
# Same as i, but in the opposite direction
101+
@assert has_j
102+
rows[jrange[jidx]] = i
103+
iidx > length(rr) && continue
104+
rotate_range = rr[iidx]:jrange[jidx]
105+
circshift!(@view(vals[rotate_range]), 1)
106+
circshift!(@view(rows[rotate_range]), 1)
107+
end
108+
end
109+
return nothing
110+
end
111+
112+
function bareiss_update!(zero!, M::StridedMatrix, k, swapto, pivot, prev_pivot)
113+
for i in k+1:size(M, 2), j in k+1:size(M, 1)
114+
M[j,i] = exactdiv(M[j,i]*pivot - M[j,k]*M[k,i], prev_pivot)
115+
end
116+
zero!(M, k+1:size(M, 1), k)
117+
end
118+
119+
@views function bareiss_update!(zero!, M::AbstractMatrix, k, swapto, pivot, prev_pivot)
120+
V = M[k+1:end, k+1:end]
121+
V .= exactdiv.(V .* pivot .- M[k+1:end, k] * M[k, k+1:end]', prev_pivot)
122+
zero!(M, k+1:size(M, 1), k)
123+
end
124+
125+
function bareiss_update_virtual_colswap!(zero!, M::AbstractMatrix, k, swapto, pivot, prev_pivot)
126+
V = @view M[k+1:end, :]
127+
V .= @views exactdiv.(V .* pivot .- M[k+1:end, swapto[2]] * M[k, :]', prev_pivot)
128+
zero!(M, k+1:size(M, 1), swapto[2])
129+
end
130+
131+
bareiss_zero!(M, i, j) = M[i,j] .= zero(eltype(M))
132+
133+
function find_pivot_col(M, i)
134+
p = findfirst(!iszero, @view M[i,i:end])
135+
p === nothing && return nothing
136+
idx = CartesianIndex(i, p + i - 1)
137+
(idx, M[idx])
138+
end
139+
140+
function find_pivot_any(M, i)
141+
p = findfirst(!iszero, @view M[i:end,i:end])
142+
p === nothing && return nothing
143+
idx = p + CartesianIndex(i - 1, i - 1)
144+
(idx, M[idx])
145+
end
146+
147+
const bareiss_colswap = (Base.swapcols!, swaprows!, bareiss_update!, bareiss_zero!)
148+
const bareiss_virtcolswap = ((M,i,j)->nothing, swaprows!, bareiss_update_virtual_colswap!, bareiss_zero!)
149+
150+
"""
151+
bareiss!(M, [swap_strategy])
152+
153+
Perform Bareiss's fraction-free row-reduction algorithm on the matrix `M`.
154+
Optionally, a specific pivoting method may be specified.
155+
156+
swap_strategy is an optional argument that determines how the swapping of rows and coulmns is performed.
157+
bareiss_colswap (the default) swaps the columns and rows normally.
158+
bareiss_virtcolswap pretends to swap the columns which can be faster for sparse matrices.
159+
"""
160+
function bareiss!(M::AbstractMatrix, swap_strategy=bareiss_colswap;
161+
find_pivot=find_pivot_any)
162+
swapcols!, swaprows!, update!, zero! = swap_strategy;
163+
prev = one(eltype(M))
164+
n = size(M, 1)
165+
for k in 1:n
166+
r = find_pivot(M, k)
167+
r === nothing && return k - 1
168+
(swapto, pivot) = r
169+
if CartesianIndex(k, k) != swapto
170+
swapcols!(M, k, swapto[2])
171+
swaprows!(M, k, swapto[1])
172+
end
173+
update!(zero!, M, k, swapto, pivot, prev)
174+
prev = pivot
175+
end
176+
return n
177+
end

0 commit comments

Comments
 (0)