Skip to content

Commit 6cfb1af

Browse files
authored
Merge pull request #1380 from Keno/kf/bareissbug
Fix bug in bareiss algorithm
2 parents 9cc6f21 + b7307a6 commit 6cfb1af

File tree

3 files changed

+44
-12
lines changed

3 files changed

+44
-12
lines changed

src/systems/alias_elimination.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,30 @@ using SymbolicUtils: Rewriters
22

33
const KEEP = typemin(Int)
44

5-
function alias_elimination(sys)
5+
function alias_eliminate_graph(sys::AbstractSystem)
66
sys = initialize_system_structure(sys; quick_cancel=true)
77
s = structure(sys)
88

99
mm = linear_subsys_adjmat(sys)
10-
size(mm, 1) == 0 && return sys # No linear subsystems
10+
size(mm, 1) == 0 && return sys, nothing, mm # No linear subsystems
1111

1212
ag, mm = alias_eliminate_graph!(s.graph, complete(s.var_to_diff), mm)
13+
return sys, ag, mm
14+
end
15+
16+
# For debug purposes
17+
function aag_bareiss(sys::AbstractSystem)
18+
sys = initialize_system_structure(sys; quick_cancel=true)
19+
s = structure(sys)
20+
mm = linear_subsys_adjmat(sys)
21+
return aag_bareiss!(s.graph, complete(s.var_to_diff), mm)
22+
end
23+
24+
function alias_elimination(sys)
25+
sys, ag, mm = alias_eliminate_graph(sys)
26+
ag === nothing && return sys
1327

28+
s = structure(sys)
1429
@unpack fullvars, graph = s
1530

1631
subs = OrderedDict()
@@ -193,9 +208,7 @@ count_nonzeros(a::AbstractArray) = count(!iszero, a)
193208
# Here we have a guarantee that they won't, so we can make this identification
194209
count_nonzeros(a::SparseVector) = nnz(a)
195210

196-
function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
197-
diff_to_var = invview(var_to_diff)
198-
211+
function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
199212
mm = copy(mm_orig)
200213
is_linear_equations = falses(size(AsSubMatrix(mm_orig), 1))
201214
for e in mm_orig.nzrows
@@ -245,10 +258,10 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
245258
(rank1, rank2, rank3, pivots)
246259
end
247260

248-
# mm2 = Array(copy(mm))
249-
# @show do_bareiss!(mm2)
250-
# display(mm2)
261+
return mm, solvable_variables, do_bareiss!(mm, mm_orig)
262+
end
251263

264+
function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
252265
# Step 1: Perform bareiss factorization on the adjacency matrix of the linear
253266
# subsystem of the system we're interested in.
254267
#
@@ -261,7 +274,8 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
261274
# -------------------|------------------------
262275
# rank3 | [ 0 0 | M₃₃ M₃₄ ] [v₃] = [0]
263276
# [ 0 0 | 0 0 ] [v₄] = [0]
264-
(rank1, rank2, rank3, pivots) = do_bareiss!(mm, mm_orig)
277+
mm, solvable_variables, (rank1, rank2, rank3, pivots) =
278+
aag_bareiss!(graph, var_to_diff, mm_orig)
265279

266280
# Step 2: Simplify the system using the Bareiss factorization
267281
ag = AliasGraph(size(mm, 2))
@@ -272,6 +286,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
272286
# Kind of like the backward substitution, but we don't actually rely on it
273287
# being lower triangular. We eliminate a variable if there are at most 2
274288
# variables left after the substitution.
289+
diff_to_var = invview(var_to_diff)
275290
function lss!(ei::Integer)
276291
vi = pivots[ei]
277292
# the lowest differentiated variable can be eliminated

src/systems/sparsematrixclil.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ function swaprows!(S::SparseMatrixCLIL, i, j)
3232
swap!(S.row_vals, i, j)
3333
end
3434

35+
function SparseMatrixCLIL(mm::AbstractMatrix)
36+
nrows, ncols = size(mm)
37+
row_cols = [findall(!iszero, row) for row in eachrow(mm)]
38+
row_vals = [row[cols] for (row, cols) in zip(eachrow(mm), row_cols)]
39+
SparseMatrixCLIL(nrows, ncols, Int[1:length(row_cols);], row_cols, row_vals)
40+
end
41+
3542
struct CLILVector{T, Ti} <: AbstractSparseVector{T, Ti}
3643
vec::SparseVector{T, Ti}
3744
end
@@ -166,7 +173,9 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
166173

167174
tmp_incidence = similar(eadj[ei], 0)
168175
tmp_coeffs = similar(old_cadj[ei], 0)
169-
vars = union(ivars, kvars)
176+
# TODO: We know both ivars and kvars are sorted, we could just write
177+
# a quick iterator here that does this without allocation/faster.
178+
vars = sort(union(ivars, kvars))
170179

171180
for v in vars
172181
v == vpivot && continue
@@ -197,7 +206,7 @@ struct AsSubMatrix{T, Ti<:Integer} <: AbstractSparseMatrix{T, Ti}
197206
end
198207
Base.size(S::AsSubMatrix) = (S.M.nparentrows, S.M.ncols)
199208

200-
function Base.getindex(S::SparseMatrixCLIL{T}, i1, i2) where {T}
209+
function Base.getindex(S::SparseMatrixCLIL{T}, i1::Integer, i2::Integer) where {T}
201210
checkbounds(S, i1, i2)
202211

203212
col = S.row_cols[i1]
@@ -207,7 +216,7 @@ function Base.getindex(S::SparseMatrixCLIL{T}, i1, i2) where {T}
207216
return S.row_vals[i1][nncol]
208217
end
209218

210-
function Base.getindex(S::AsSubMatrix{T}, i1, i2) where {T}
219+
function Base.getindex(S::AsSubMatrix{T}, i1::Integer, i2::Integer) where {T}
211220
checkbounds(S, i1, i2)
212221
S = S.M
213222

test/structural_transformation/tearing.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,14 @@ eqs = [
139139
0 ~ x + z,
140140
]
141141
@named nlsys = NonlinearSystem(eqs, [x, y, z], [])
142+
let (mm, _, _) = ModelingToolkit.aag_bareiss(nlsys)
143+
@test mm == [
144+
-1 1 0;
145+
0 -1 -1;
146+
0 0 0
147+
]
148+
end
149+
142150
newsys = tearing(nlsys)
143151
@test length(equations(newsys)) == 1
144152

0 commit comments

Comments
 (0)