Skip to content

Commit b7307a6

Browse files
committed
Fix bug in bareiss algorithm
We need to maintain the invariant that the rows of the sparse matrix are sorted, otherwise the bareiss algorithm will be incorrect. Without this, for appropriate choice of pivots, our bareiss implementation turned ``` [1 1 0 1 0 1 0 1 1] ``` into ``` [1 1 0 0 0 -1 0 0 0] ``` rather than the correct ``` [1 1 0 0 1 -1 0 0 0] ``` Fix that and refactor a bit to make this testable.
1 parent 6b8c7ab commit b7307a6

File tree

3 files changed

+44
-12
lines changed

3 files changed

+44
-12
lines changed

src/systems/alias_elimination.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,30 @@ using SymbolicUtils: Rewriters
22

33
const KEEP = typemin(Int)
44

5-
function alias_elimination(sys)
5+
function alias_eliminate_graph(sys::AbstractSystem)
66
sys = initialize_system_structure(sys; quick_cancel=true)
77
s = structure(sys)
88

99
mm = linear_subsys_adjmat(sys)
10-
size(mm, 1) == 0 && return sys # No linear subsystems
10+
size(mm, 1) == 0 && return sys, nothing, mm # No linear subsystems
1111

1212
ag, mm = alias_eliminate_graph!(s.graph, complete(s.var_to_diff), mm)
13+
return sys, ag, mm
14+
end
15+
16+
# For debug purposes
17+
function aag_bareiss(sys::AbstractSystem)
18+
sys = initialize_system_structure(sys; quick_cancel=true)
19+
s = structure(sys)
20+
mm = linear_subsys_adjmat(sys)
21+
return aag_bareiss!(s.graph, complete(s.var_to_diff), mm)
22+
end
23+
24+
function alias_elimination(sys)
25+
sys, ag, mm = alias_eliminate_graph(sys)
26+
ag === nothing && return sys
1327

28+
s = structure(sys)
1429
@unpack fullvars, graph = s
1530

1631
subs = OrderedDict()
@@ -193,9 +208,7 @@ count_nonzeros(a::AbstractArray) = count(!iszero, a)
193208
# Here we have a guarantee that they won't, so we can make this identification
194209
count_nonzeros(a::SparseVector) = nnz(a)
195210

196-
function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
197-
diff_to_var = invview(var_to_diff)
198-
211+
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
199212
mm = copy(mm_orig)
200213
is_linear_equations = falses(size(AsSubMatrix(mm_orig), 1))
201214
for e in mm_orig.nzrows
@@ -245,10 +258,10 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
245258
(rank1, rank2, rank3, pivots)
246259
end
247260

248-
# mm2 = Array(copy(mm))
249-
# @show do_bareiss!(mm2)
250-
# display(mm2)
261+
return mm, solvable_variables, do_bareiss!(mm, mm_orig)
262+
end
251263

264+
function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
252265
# Step 1: Perform bareiss factorization on the adjacency matrix of the linear
253266
# subsystem of the system we're interested in.
254267
#
@@ -261,7 +274,8 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
261274
# -------------------|------------------------
262275
# rank3 | [ 0 0 | M₃₃ M₃₄ ] [v₃] = [0]
263276
# [ 0 0 | 0 0 ] [v₄] = [0]
264-
(rank1, rank2, rank3, pivots) = do_bareiss!(mm, mm_orig)
277+
mm, solvable_variables, (rank1, rank2, rank3, pivots) =
278+
aag_bareiss!(graph, var_to_diff, mm_orig)
265279

266280
# Step 2: Simplify the system using the Bareiss factorization
267281
ag = AliasGraph(size(mm, 2))
@@ -272,6 +286,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
272286
# Kind of like the backward substitution, but we don't actually rely on it
273287
# being lower triangular. We eliminate a variable if there are at most 2
274288
# variables left after the substitution.
289+
diff_to_var = invview(var_to_diff)
275290
function lss!(ei::Integer)
276291
vi = pivots[ei]
277292
# the lowest differentiated variable can be eliminated

src/systems/sparsematrixclil.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ function swaprows!(S::SparseMatrixCLIL, i, j)
3232
swap!(S.row_vals, i, j)
3333
end
3434

35+
function SparseMatrixCLIL(mm::AbstractMatrix)
36+
nrows, ncols = size(mm)
37+
row_cols = [findall(!iszero, row) for row in eachrow(mm)]
38+
row_vals = [row[cols] for (row, cols) in zip(eachrow(mm), row_cols)]
39+
SparseMatrixCLIL(nrows, ncols, Int[1:length(row_cols);], row_cols, row_vals)
40+
end
41+
3542
struct CLILVector{T, Ti} <: AbstractSparseVector{T, Ti}
3643
vec::SparseVector{T, Ti}
3744
end
@@ -166,7 +173,9 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
166173

167174
tmp_incidence = similar(eadj[ei], 0)
168175
tmp_coeffs = similar(old_cadj[ei], 0)
169-
vars = union(ivars, kvars)
176+
# TODO: We know both ivars and kvars are sorted, we could just write
177+
# a quick iterator here that does this without allocation/faster.
178+
vars = sort(union(ivars, kvars))
170179

171180
for v in vars
172181
v == vpivot && continue
@@ -197,7 +206,7 @@ struct AsSubMatrix{T, Ti<:Integer} <: AbstractSparseMatrix{T, Ti}
197206
end
198207
Base.size(S::AsSubMatrix) = (S.M.nparentrows, S.M.ncols)
199208

200-
function Base.getindex(S::SparseMatrixCLIL{T}, i1, i2) where {T}
209+
function Base.getindex(S::SparseMatrixCLIL{T}, i1::Integer, i2::Integer) where {T}
201210
checkbounds(S, i1, i2)
202211

203212
col = S.row_cols[i1]
@@ -207,7 +216,7 @@ function Base.getindex(S::SparseMatrixCLIL{T}, i1, i2) where {T}
207216
return S.row_vals[i1][nncol]
208217
end
209218

210-
function Base.getindex(S::AsSubMatrix{T}, i1, i2) where {T}
219+
function Base.getindex(S::AsSubMatrix{T}, i1::Integer, i2::Integer) where {T}
211220
checkbounds(S, i1, i2)
212221
S = S.M
213222

test/structural_transformation/tearing.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,14 @@ eqs = [
139139
0 ~ x + z,
140140
]
141141
@named nlsys = NonlinearSystem(eqs, [x, y, z], [])
142+
let (mm, _, _) = ModelingToolkit.aag_bareiss(nlsys)
143+
@test mm == [
144+
-1 1 0;
145+
0 -1 -1;
146+
0 0 0
147+
]
148+
end
149+
142150
newsys = tearing(nlsys)
143151
@test length(equations(newsys)) == 1
144152

0 commit comments

Comments
 (0)