Skip to content

Commit 2400c58

Browse files
committed
Use visited to avoid cycles
1 parent 37596c4 commit 2400c58

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

src/systems/alias_elimination.jl

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ end
231231
function reduce!(mm::SparseMatrixCLIL, ag::AliasGraph)
232232
dels = Int[]
233233
for (i, rs) in enumerate(mm.row_cols)
234-
p = i == 7
235234
rvals = mm.row_vals[i]
236235
j = 1
237236
while j <= length(rs)
@@ -244,9 +243,9 @@ function reduce!(mm::SparseMatrixCLIL, ag::AliasGraph)
244243
inc = coeff * rvals[j]
245244
i = searchsortedfirst(rs, alias)
246245
if i > length(rs) || rs[i] != alias
247-
if i <= j
248-
j += 1
249-
end
246+
# if we add a variable to what we already visited, make sure
247+
# to bump the cursor.
248+
j += i <= j
250249
insert!(rs, i, alias)
251250
insert!(rvals, i, inc)
252251
else
@@ -272,11 +271,11 @@ struct InducedAliasGraph
272271
ag::AliasGraph
273272
invag::SimpleDiGraph{Int}
274273
var_to_diff::DiffGraph
275-
visited::BitVector
274+
visited::BitSet
276275
end
277276

278277
function InducedAliasGraph(ag, invag, var_to_diff)
279-
InducedAliasGraph(ag, invag, var_to_diff, falses(nv(invag)))
278+
InducedAliasGraph(ag, invag, var_to_diff, BitSet())
280279
end
281280

282281
struct IAGNeighbors
@@ -286,9 +285,7 @@ end
286285

287286
function Base.iterate(it::IAGNeighbors, state = nothing)
288287
@unpack ag, invag, var_to_diff, visited = it.iag
289-
callback! = let visited = visited
290-
var -> visited[var] = true
291-
end
288+
callback! = Base.Fix1(push!, visited)
292289
if state === nothing
293290
v, lv = extreme_var(var_to_diff, it.v, 0)
294291
used_ag = false
@@ -298,13 +295,13 @@ function Base.iterate(it::IAGNeighbors, state = nothing)
298295
end
299296

300297
v, level, used_ag, nb, nit = state
301-
visited[v] && return nothing
298+
v in visited && return nothing
302299
while true
303300
@label TRYAGIN
304301
if used_ag
305302
if nit !== nothing
306303
n, ns = nit
307-
if !visited[n]
304+
if !(n in visited)
308305
n, lv = extreme_var(var_to_diff, n, level)
309306
extreme_var(var_to_diff, n, nothing, Val(false), callback = callback!)
310307
nit = iterate(nb, ns)
@@ -316,7 +313,7 @@ function Base.iterate(it::IAGNeighbors, state = nothing)
316313
# We don't care about the alising value because we only use this to
317314
# find the root of the tree.
318315
if (_n = get(ag, v, nothing)) !== nothing && (n = _n[2]) > 0
319-
if !visited[n]
316+
if !(n in visited)
320317
n, lv = extreme_var(var_to_diff, n, level)
321318
extreme_var(var_to_diff, n, nothing, Val(false), callback = callback!)
322319
return n => lv, (v, level, used_ag, nb, nit)
@@ -325,7 +322,7 @@ function Base.iterate(it::IAGNeighbors, state = nothing)
325322
@goto TRYAGIN
326323
end
327324
end
328-
visited[v] = true
325+
push!(visited, v)
329326
(v = var_to_diff[v]) === nothing && return nothing
330327
level += 1
331328
used_ag = false
@@ -349,11 +346,12 @@ function _find_root!(iag::InducedAliasGraph, v::Integer, level = 0)
349346
end
350347

351348
function find_root!(iag::InducedAliasGraph, v::Integer)
352-
ret = _find_root!(iag, v)
353-
fill!(iag.visited, false)
354-
ret
349+
clear_visited!(iag)
350+
_find_root!(iag, v)
355351
end
356352

353+
clear_visited!(iag::InducedAliasGraph) = (empty!(iag.visited); iag)
354+
357355
struct RootedAliasTree
358356
iag::InducedAliasGraph
359357
root::Int
@@ -376,9 +374,6 @@ function Base.iterate(it::StatefulAliasBFS, queue = (eltype(it)[(1, 0, it.t)]))
376374
coeff, lv, t = popfirst!(queue)
377375
nextlv = lv + 1
378376
for (coeff′, c) in children(t)
379-
# FIXME: use the visited cache!
380-
# A "cycle" might occur when we have `x ~ D(x)`.
381-
nodevalue(c) == t.root && continue
382377
# -1 <= coeff <= 1
383378
push!(queue, (coeff * coeff′, nextlv, c))
384379
end
@@ -392,7 +387,26 @@ end
392387
function Base.iterate(c::RootedAliasChildren, s = nothing)
393388
rat = c.t
394389
@unpack iag, root = rat
395-
@unpack ag, invag, var_to_diff, visited = iag
390+
@unpack visited = iag
391+
push!(visited, root)
392+
it = _iterate(c, s)
393+
it === nothing && return nothing
394+
while true
395+
node = nodevalue(it[1][2])
396+
if node in visited
397+
it = _iterate(c, it[2])
398+
it === nothing && return nothing
399+
else
400+
push!(visited, node)
401+
return it
402+
end
403+
end
404+
end
405+
406+
@inline function _iterate(c::RootedAliasChildren, s = nothing)
407+
rat = c.t
408+
@unpack iag, root = rat
409+
@unpack ag, invag, var_to_diff = iag
396410
(iszero(root) || (root = var_to_diff[root]) === nothing) && return nothing
397411
if s === nothing
398412
stage = 1
@@ -516,7 +530,6 @@ function simple_aliases!(ag, graph, var_to_diff, mm_orig, only_algebraic, irredu
516530
irreducibles)
517531

518532
# Step 2: Simplify the system using the Bareiss factorization
519-
ks = keys(ag)
520533
for v in setdiff(solvable_variables, @view pivots[1:rank1])
521534
ag[v] = 0
522535
end
@@ -624,6 +637,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
624637
end
625638
end
626639
max_lv = 0
640+
clear_visited!(iag)
627641
for (coeff, lv, t) in StatefulAliasBFS(RootedAliasTree(iag, r))
628642
max_lv = max(max_lv, lv)
629643
v = nodevalue(t)
@@ -692,7 +706,7 @@ function locally_structure_simplify!(adj_row, pivot_var, ag)
692706
iszero(pivot_val) && return false
693707

694708
nirreducible = 0
695-
alias_candidate::Union{Int, Pair{Int, Int}} = 0
709+
alias_candidate::Pair{Int, Int} = 0 => 0
696710

697711
# N.B.: Assumes that the non-zeros iterator is robust to modification
698712
# of the underlying array datastructure.

0 commit comments

Comments
 (0)