Skip to content

Commit 3fd2871

Browse files
YingboMaKeno
andauthored
Add some comments in alias elimination (#1360)
Co-authored-by: Keno Fischer <[email protected]>
1 parent 5d9021d commit 3fd2871

File tree

1 file changed

+42
-14
lines changed

1 file changed

+42
-14
lines changed

src/systems/alias_elimination.jl

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ function alias_elimination(sys)
2323
for (ei, e) in enumerate(mm.nzrows)
2424
vs = 𝑠neighbors(graph, e)
2525
if isempty(vs)
26+
# remove empty equations
2627
push!(dels, e)
2728
else
2829
rhs = mapfoldl(+, pairs(nonzerosmap(@view mm[ei, :]))) do (var, coeff)
@@ -135,7 +136,7 @@ function Base.getindex(ag::AliasGraph, i::Integer)
135136
coeff, var = (sign(r), abs(r))
136137
if var in keys(ag)
137138
# Amortized lookup. Check if since we last looked this up, our alias was
138-
# itself aliased. If so, adjust adjust the alias table.
139+
# itself aliased. If so, just adjust the alias table.
139140
ac, av = ag[var]
140141
nc = ac * coeff
141142
ag.aliasto[var] = nc > 0 ? av : -av
@@ -201,6 +202,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
201202
is_linear_equations[e] = true
202203
end
203204

205+
# Variables that are highest order differentiated cannot be states of an ODE
204206
is_not_potential_state = isnothing.(var_to_diff)
205207
is_linear_variables = copy(is_not_potential_state)
206208
for i in 𝑠vertices(graph); is_linear_equations[i] && continue
@@ -249,26 +251,43 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
249251

250252
# Step 1: Perform bareiss factorization on the adjacency matrix of the linear
251253
# subsystem of the system we're interested in.
254+
#
255+
# Let `m = the number of linear equations` and `n = the number of
256+
# variables`.
257+
#
258+
# `do_bareiss` conceptually gives us this system:
259+
# rank1 | [ M₁₁ M₁₂ | M₁₃ M₁₄ ] [v₁] = [0]
260+
# rank2 | [ 0 M₂₂ | M₂₃ M₂₄ ] P [v₂] = [0]
261+
# -------------------|------------------------
262+
# rank3 | [ 0 0 | M₃₃ M₃₄ ] [v₃] = [0]
263+
# [ 0 0 | 0 0 ] [v₄] = [0]
252264
(rank1, rank2, rank3, pivots) = do_bareiss!(mm, mm_orig)
253265

254-
# Step 2: Simplify the system using the bareiss factorization
266+
# Step 2: Simplify the system using the Bareiss factorization
255267
ag = AliasGraph(size(mm, 2))
256268
for v in setdiff(solvable_variables, @view pivots[1:rank1])
257269
ag[v] = 0
258270
end
259271

260-
# kind of like the backward substitution
261-
lss!(ei::Integer) = locally_structure_simplify!((@view mm[ei, :]), pivots[ei], ag, isnothing(diff_to_var[pivots[ei]]))
272+
# Kind of like the backward substitution, but we don't actually rely on it
273+
# being lower triangular. We eliminate a variable if there are at most 2
274+
# variables left after the substitution.
275+
function lss!(ei::Integer)
276+
vi = pivots[ei]
277+
# the lowest differentiated variable can be eliminated
278+
islowest = isnothing(diff_to_var[vi])
279+
locally_structure_simplify!((@view mm[ei, :]), vi, ag, islowest)
280+
end
262281

263282
# Step 2.1: Go backwards, collecting eliminated variables and substituting
264283
# alias as we go.
265284
foreach(lss!, reverse(1:rank2))
266285

267-
# Step 2.2: Sometimes bareiss can make the equations more complicated.
286+
# Step 2.2: Sometimes Bareiss can make the equations more complicated.
268287
# Go back and check the original matrix. If this happened,
269288
# Replace the equation by the one from the original system,
270-
# but be sure to also run lss! again, since we only ran that
271-
# on the bareiss'd matrix, not the original one.
289+
# but be sure to also run `lss!` again, since we only ran that
290+
# on the Bareiss'd matrix, not the original one.
272291
reduced = mapreduce(|, 1:rank2; init=false) do ei
273292
if count_nonzeros(@view mm_orig[ei, :]) < count_nonzeros(@view mm[ei, :])
274293
mm[ei, :] = @view mm_orig[ei, :]
@@ -277,14 +296,14 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
277296
return false
278297
end
279298

280-
# Step 2.3: Iterate to convergance.
299+
# Step 2.3: Iterate to convergence.
281300
# N.B.: `lss!` modifies the array.
282301
# TODO: We know exactly what variable we eliminated. Starting over at the
283302
# start is wasteful. We can lookup which equations have this variable
284303
# using the graph.
285304
reduced && while any(lss!, 1:rank2); end
286305

287-
# Step 3: Reflect our update decitions back into the graph
306+
# Step 3: Reflect our update decisions back into the graph
288307
for (ei, e) in enumerate(mm.nzrows)
289308
set_neighbors!(graph, e, mm.row_cols[ei])
290309
end
@@ -326,23 +345,32 @@ function locally_structure_simplify!(adj_row, pivot_col, ag, may_eliminate)
326345
continue
327346
end
328347
(coeff, alias_var) = alias
348+
# `var = coeff * alias_var`, so we eliminate this var.
329349
adj_row[var] = 0
330350
if alias_var != 0
351+
# val * var = val * (coeff * alias_var) = (val * coeff) * alias_var
331352
val *= coeff
353+
# val * var + c * alias_var + ... = (val * coeff + c) * alias_var + ...
332354
new_coeff = (adj_row[alias_var] += val)
333355
if alias_var < var
334-
# If this adds to a coeff that was not previously accounted for, and
335-
# we've already passed it, make sure to count it here. We're
336-
# relying on `var` being produced in sorted order here.
356+
# If this adds to a coeff that was not previously accounted for,
357+
# and we've already passed it, make sure to count it here. We
358+
# need to know if there are at most 2 terms left after this
359+
# loop.
360+
#
361+
# We're relying on `var` being produced in sorted order here.
337362
nirreducible += 1
338363
alias_candidate = new_coeff => alias_var
339364
end
340365
end
341366
end
342367

343368
if may_eliminate && nirreducible <= 1
344-
# There were only one or two terms left in the equation (including the pivot variable).
345-
# We can eliminate the pivot variable.
369+
# There were only one or two terms left in the equation (including the
370+
# pivot variable). We can eliminate the pivot variable.
371+
#
372+
# Note that when `nirreducible <= 1`, `alias_candidate` is uniquely
373+
# determined.
346374
if alias_candidate !== 0
347375
alias_candidate = -exactdiv(alias_candidate[1], pivot_val) => alias_candidate[2]
348376
end

0 commit comments

Comments
 (0)