Skip to content

Commit ad25906

Browse files
add tests for bareiss, fix precompile (#1455)
Co-authored-by: Christopher Rackauckas <[email protected]>
1 parent f74cbcc commit ad25906

File tree

4 files changed

+129
-96
lines changed

4 files changed

+129
-96
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ Get the set of parameters variables for the given system.
107107
"""
108108
function parameters end
109109

110+
# this has to be included early to deal with depency issues
111+
include("structural_transformation/bareiss.jl")
110112
include("bipartite_graph.jl")
111113
using .BipartiteGraphs
112114

@@ -116,9 +118,6 @@ include("parameters.jl")
116118
include("utils.jl")
117119
include("domains.jl")
118120

119-
# Code that should eventually go elsewhere, but is here for fow
120-
include("structural_transformation/bareiss.jl")
121-
122121
include("systems/abstractsystem.jl")
123122
include("systems/connectors.jl")
124123

src/structural_transformation/bareiss.jl

Lines changed: 98 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -2,111 +2,116 @@
22

33
using LinearAlgebra
44
using SparseArrays
5-
using SparseArrays: AbstractSparseMatrixCSC
5+
using SparseArrays: AbstractSparseMatrixCSC, getcolptr
66

77
macro swap(a, b)
88
esc(:(($a, $b) = ($b, $a)))
99
end
1010

11-
function swaprows!(a::AbstractMatrix, i, j)
12-
i == j && return
13-
rows = axes(a,1)
14-
@boundscheck i in rows || throw(BoundsError(a, (:,i)))
15-
@boundscheck j in rows || throw(BoundsError(a, (:,j)))
16-
for k in axes(a,2)
17-
@inbounds a[i,k],a[j,k] = a[j,k],a[i,k]
11+
# https://github.com/JuliaLang/julia/pull/42678
12+
@static if VERSION > v"1.8.0-DEV.762"
13+
import Base: swaprows!
14+
else
15+
function swaprows!(a::AbstractMatrix, i, j)
16+
i == j && return
17+
rows = axes(a,1)
18+
@boundscheck i in rows || throw(BoundsError(a, (:,i)))
19+
@boundscheck j in rows || throw(BoundsError(a, (:,j)))
20+
for k in axes(a,2)
21+
@inbounds a[i,k],a[j,k] = a[j,k],a[i,k]
22+
end
1823
end
19-
end
20-
function Base.circshift!(a::AbstractVector, shift::Integer)
21-
n = length(a)
22-
n == 0 && return
23-
shift = mod(shift, n)
24-
shift == 0 && return
25-
reverse!(a, 1, shift)
26-
reverse!(a, shift+1, length(a))
27-
reverse!(a)
28-
return a
29-
end
30-
function Base.swapcols!(A::AbstractSparseMatrixCSC, i, j)
31-
i == j && return
32-
33-
# For simplicitly, let i denote the smaller of the two columns
34-
j < i && @swap(i, j)
35-
36-
colptr = getcolptr(A)
37-
irow = colptr[i]:(colptr[i+1]-1)
38-
jrow = colptr[j]:(colptr[j+1]-1)
39-
40-
function rangeexchange!(arr, irow, jrow)
41-
if length(irow) == length(jrow)
42-
for (a, b) in zip(irow, jrow)
43-
@inbounds @swap(arr[i], arr[j])
24+
function Base.circshift!(a::AbstractVector, shift::Integer)
25+
n = length(a)
26+
n == 0 && return
27+
shift = mod(shift, n)
28+
shift == 0 && return
29+
reverse!(a, 1, shift)
30+
reverse!(a, shift+1, length(a))
31+
reverse!(a)
32+
return a
33+
end
34+
function Base.swapcols!(A::AbstractSparseMatrixCSC, i, j)
35+
i == j && return
36+
37+
# For simplicitly, let i denote the smaller of the two columns
38+
j < i && @swap(i, j)
39+
40+
colptr = getcolptr(A)
41+
irow = colptr[i]:(colptr[i+1]-1)
42+
jrow = colptr[j]:(colptr[j+1]-1)
43+
44+
function rangeexchange!(arr, irow, jrow)
45+
if length(irow) == length(jrow)
46+
for (a, b) in zip(irow, jrow)
47+
@inbounds @swap(arr[i], arr[j])
48+
end
49+
return
4450
end
45-
return
51+
# This is similar to the triple-reverse tricks for
52+
# circshift!, except that we have three ranges here,
53+
# so it ends up being 4 reverse calls (but still
54+
# 2 overall reversals for the memory range). Like
55+
# circshift!, there's also a cycle chasing algorithm
56+
# with optimal memory complexity, but the performance
57+
# tradeoffs against this implementation are non-trivial,
58+
# so let's just do this simple thing for now.
59+
# See https://github.com/JuliaLang/julia/pull/42676 for
60+
# discussion of circshift!-like algorithms.
61+
reverse!(@view arr[irow])
62+
reverse!(@view arr[jrow])
63+
reverse!(@view arr[(last(irow)+1):(first(jrow)-1)])
64+
reverse!(@view arr[first(irow):last(jrow)])
4665
end
47-
# This is similar to the triple-reverse tricks for
48-
# circshift!, except that we have three ranges here,
49-
# so it ends up being 4 reverse calls (but still
50-
# 2 overall reversals for the memory range). Like
51-
# circshift!, there's also a cycle chasing algorithm
52-
# with optimal memory complexity, but the performance
53-
# tradeoffs against this implementation are non-trivial,
54-
# so let's just do this simple thing for now.
55-
# See https://github.com/JuliaLang/julia/pull/42676 for
56-
# discussion of circshift!-like algorithms.
57-
reverse!(@view arr[irow])
58-
reverse!(@view arr[jrow])
59-
reverse!(@view arr[(last(irow)+1):(first(jrow)-1)])
60-
reverse!(@view arr[first(irow):last(jrow)])
61-
end
62-
rangeexchange!(rowvals(A), irow, jrow)
63-
rangeexchange!(nonzeros(A), irow, jrow)
66+
rangeexchange!(rowvals(A), irow, jrow)
67+
rangeexchange!(nonzeros(A), irow, jrow)
6468

65-
if length(irow) != length(jrow)
66-
@inbounds colptr[i+1:j] .+= length(jrow) - length(irow)
69+
if length(irow) != length(jrow)
70+
@inbounds colptr[i+1:j] .+= length(jrow) - length(irow)
71+
end
72+
return nothing
6773
end
68-
return nothing
69-
end
70-
function swaprows!(A::AbstractSparseMatrixCSC, i, j)
71-
# For simplicitly, let i denote the smaller of the two rows
72-
j < i && @swap(i, j)
73-
74-
rows = rowvals(A)
75-
vals = nonzeros(A)
76-
for col = 1:size(A, 2)
77-
rr = nzrange(A, col)
78-
iidx = searchsortedfirst(@view(rows[rr]), i)
79-
has_i = iidx <= length(rr) && rows[rr[iidx]] == i
80-
81-
jrange = has_i ? (iidx:last(rr)) : rr
82-
jidx = searchsortedlast(@view(rows[jrange]), j)
83-
has_j = jidx != 0 && rows[jrange[jidx]] == j
84-
85-
if !has_j && !has_i
86-
# Has neither row - nothing to do
87-
continue
88-
elseif has_i && has_j
89-
# This column had both i and j rows - swap them
90-
@swap(vals[rr[iidx]], vals[jrange[jidx]])
91-
elseif has_i
92-
# Update the rowval and then rotate both nonzeros
93-
# and the remaining rowvals into the correct place
94-
rows[rr[iidx]] = j
95-
jidx == 0 && continue
96-
rotate_range = rr[iidx]:jrange[jidx]
97-
circshift!(@view(vals[rotate_range]), -1)
98-
circshift!(@view(rows[rotate_range]), -1)
99-
else
100-
# Same as i, but in the opposite direction
101-
@assert has_j
102-
rows[jrange[jidx]] = i
103-
iidx > length(rr) && continue
104-
rotate_range = rr[iidx]:jrange[jidx]
105-
circshift!(@view(vals[rotate_range]), 1)
106-
circshift!(@view(rows[rotate_range]), 1)
74+
function swaprows!(A::AbstractSparseMatrixCSC, i, j)
75+
# For simplicitly, let i denote the smaller of the two rows
76+
j < i && @swap(i, j)
77+
78+
rows = rowvals(A)
79+
vals = nonzeros(A)
80+
for col = 1:size(A, 2)
81+
rr = nzrange(A, col)
82+
iidx = searchsortedfirst(@view(rows[rr]), i)
83+
has_i = iidx <= length(rr) && rows[rr[iidx]] == i
84+
85+
jrange = has_i ? (iidx:last(rr)) : rr
86+
jidx = searchsortedlast(@view(rows[jrange]), j)
87+
has_j = jidx != 0 && rows[jrange[jidx]] == j
88+
89+
if !has_j && !has_i
90+
# Has neither row - nothing to do
91+
continue
92+
elseif has_i && has_j
93+
# This column had both i and j rows - swap them
94+
@swap(vals[rr[iidx]], vals[jrange[jidx]])
95+
elseif has_i
96+
# Update the rowval and then rotate both nonzeros
97+
# and the remaining rowvals into the correct place
98+
rows[rr[iidx]] = j
99+
jidx == 0 && continue
100+
rotate_range = rr[iidx]:jrange[jidx]
101+
circshift!(@view(vals[rotate_range]), -1)
102+
circshift!(@view(rows[rotate_range]), -1)
103+
else
104+
# Same as i, but in the opposite direction
105+
@assert has_j
106+
rows[jrange[jidx]] = i
107+
iidx > length(rr) && continue
108+
rotate_range = rr[iidx]:jrange[jidx]
109+
circshift!(@view(vals[rotate_range]), 1)
110+
circshift!(@view(rows[rotate_range]), 1)
111+
end
107112
end
113+
return nothing
108114
end
109-
return nothing
110115
end
111116

112117
function bareiss_update!(zero!, M::StridedMatrix, k, swapto, pivot, prev_pivot)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using SparseArrays
2+
using ModelingToolkit
3+
import ModelingToolkit: bareiss!, find_pivot_col, bareiss_update!, swaprows!
4+
import Base: swapcols!
5+
6+
function det_bareiss!(M)
7+
parity = 1
8+
_swaprows!(M, i, j) = (i != j && (parity = -parity); swaprows!(M, i, j))
9+
_swapcols!(M, i, j) = (i != j && (parity = -parity); swapcols!(M, i, j))
10+
# We only look at the last entry, so we don't care that the sub-diagonals are
11+
# garbage.
12+
zero!(M, i, j) = nothing
13+
rank = bareiss!(M, (_swapcols!, _swaprows!, bareiss_update!, zero!);
14+
find_pivot=find_pivot_col)
15+
return parity * M[end,end]
16+
end
17+
18+
@testset "bareiss tests" begin
19+
# copy gives a dense matrix
20+
@testset "bareiss tests: $T" for T in (copy, sparse)
21+
# matrix determinent pairs
22+
for (M, d) in ((BigInt[9 1 8 0; 0 0 8 7; 7 6 8 3; 2 9 7 7], -1),
23+
(BigInt[1 big(2)^65+1; 3 4], 4-3*(big(2)^65+1)))
24+
# test that the determinent was correctly computed
25+
@test det_bareiss!(T(M)) == d
26+
end
27+
end
28+
end

test/structural_transformation/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ using SafeTestsets
33
@safetestset "Utilities" begin include("utils.jl") end
44
@safetestset "Index Reduction & SCC" begin include("index_reduction.jl") end
55
@safetestset "Tearing" begin include("tearing.jl") end
6+
@safetestset "Bareiss" begin include("bareiss.jl") end

0 commit comments

Comments
 (0)