Skip to content

Commit 2c7fba5

Browse files
committed
Only add aliases for variables in the reachable set
1 parent 5fc9cd9 commit 2c7fba5

File tree

2 files changed

+37
-37
lines changed

2 files changed

+37
-37
lines changed

src/systems/alias_elimination.jl

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,6 @@ function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true)
3333
level === nothing ? v : (v => level)
3434
end
3535

36-
function walk_to_root!(relative_level, iag, v::Integer, level = 0)
37-
brs = neighbors(iag, v)
38-
min_var_level = v => level
39-
for (x, lv′) in brs
40-
lv = lv′ + level
41-
x, lv = walk_to_root!(relative_level, iag, x, lv)
42-
relative_level[x] = lv
43-
if min_var_level[2] > lv
44-
min_var_level = x => lv
45-
end
46-
end
47-
x, lv = extreme_var(iag.var_to_diff, min_var_level...)
48-
relative_level[x] = lv
49-
return x => lv
50-
end
51-
5236
function alias_elimination(sys)
5337
state = TearingState(sys; quick_cancel = true)
5438
Main._state[] = state
@@ -92,39 +76,29 @@ function alias_elimination(sys)
9276
processed = falses(nvars)
9377
#iag = InducedAliasGraph(ag, invag, var_to_diff, processed)
9478
iag = InducedAliasGraph(ag, invag, var_to_diff)
95-
relative_level = BitDict(nvars)
9679
newag = AliasGraph(nvars)
9780
for (v, dv) in enumerate(var_to_diff)
9881
processed[v] && continue
9982
(dv === nothing && diff_to_var[v] === nothing) && continue
10083

101-
# TODO: use an iterator, and get a relative level vector for `processed`
102-
# variabels.
103-
# Note that `rootlv` is non-positive
104-
r, rootlv = walk_to_root!(relative_level, iag, v)
105-
fill!(iag.visited, false)
84+
r, _ = find_root!(iag, v)
10685
let
10786
sv = fullvars[v]
10887
root = fullvars[r]
109-
@info "Found root $r" sv=>root level=rootlv
110-
for vv in relative_level
111-
@show fullvars[vv[1]]
112-
end
88+
@info "Found root $r" sv=>root
11389
end
11490
level_to_var = Int[]
11591
extreme_var(var_to_diff, r, nothing, Val(false), callback = Base.Fix1(push!, level_to_var))
11692
nlevels = length(level_to_var)
11793
current_level = Ref(0)
118-
add_alias! = let current_level = current_level, level_to_var = level_to_var, newag = newag
94+
add_alias! = let current_level = current_level, level_to_var = level_to_var, newag = newag, processed = processed
11995
v -> begin
12096
level = current_level[]
121-
# FIXME: only alias variables in the reachable set
12297
if level + 1 <= length(level_to_var)
12398
# TODO: make sure the coefficient is 1
12499
av = level_to_var[level + 1]
125100
if v != av # if the level_to_var isn't from the root branch
126101
newag[v] = 1 => av
127-
#@info "create alias" fullvars[v] => fullvars[level_to_var[level + 1]]
128102
end
129103
else
130104
@assert length(level_to_var) == level
@@ -134,13 +108,18 @@ function alias_elimination(sys)
134108
current_level[] += 1
135109
end
136110
end
137-
for (v, rl) in pairs(relative_level)
138-
@assert diff_to_var[v] === nothing
111+
for (lv, t) in StatefulBFS(RootedAliasTree(iag, r))
112+
v = nodevalue(t)
113+
processed[v] = true
139114
v == r && continue
140-
current_level[] = rl - rootlv
115+
if lv < length(level_to_var)
116+
if level_to_var[lv + 1] == v
117+
continue
118+
end
119+
end
120+
current_level[] = lv
141121
extreme_var(var_to_diff, v, nothing, Val(false), callback = add_alias!)
142122
end
143-
empty!(relative_level)
144123
if nlevels < (new_nlevels = length(level_to_var))
145124
@assert !(D isa Nothing)
146125
for i in (nlevels + 1):new_nlevels
@@ -384,6 +363,7 @@ struct IAGNeighbors
384363
end
385364

386365
function Base.iterate(it::IAGNeighbors, state = nothing)
366+
Main._a[] = it, state
387367
@unpack ag, invag, var_to_diff, visited = it.iag
388368
callback! = let visited = visited
389369
var -> visited[var] = true
@@ -424,22 +404,42 @@ function Base.iterate(it::IAGNeighbors, state = nothing)
424404
end
425405
end
426406
visited[v] = true
427-
(v′ = var_to_diff[v]) === nothing && return nothing
428-
v::Int = v′
407+
(v = var_to_diff[v]) === nothing && return nothing
429408
level += 1
430409
used_ag = false
431410
end
432411
end
433412

434413
Graphs.neighbors(iag::InducedAliasGraph, v::Integer) = IAGNeighbors(iag, v)
435414

415+
function _find_root!(iag::InducedAliasGraph, v::Integer, level = 0)
416+
brs = neighbors(iag, v)
417+
min_var_level = v => level
418+
for (x, lv′) in brs
419+
lv = lv′ + level
420+
x, lv = _find_root!(iag, x, lv)
421+
if min_var_level[2] > lv
422+
min_var_level = x => lv
423+
end
424+
end
425+
x, lv = extreme_var(iag.var_to_diff, min_var_level...)
426+
return x => lv
427+
end
428+
429+
function find_root!(iag::InducedAliasGraph, v::Integer)
430+
ret = _find_root!(iag, v)
431+
fill!(iag.visited, false)
432+
ret
433+
end
434+
436435
struct RootedAliasTree
437436
iag::InducedAliasGraph
438437
root::Int
439438
end
440439

441440
AbstractTrees.childtype(::Type{<:RootedAliasTree}) = Union{RootedAliasTree, Int}
442441
AbstractTrees.children(rat::RootedAliasTree) = RootedAliasChildren(rat)
442+
AbstractTrees.nodetype(::Type{<:RootedAliasTree}) = Int
443443
AbstractTrees.nodevalue(rat::RootedAliasTree) = rat.root
444444
AbstractTrees.shouldprintkeys(rat::RootedAliasTree) = false
445445
has_fast_reverse(::Type{<:AbstractSimpleTreeIter{<:RootedAliasTree}}) = false

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -685,9 +685,9 @@ Base.eltype(::Type{<:StatefulBFS{T}}) where T = Tuple{Int, childtype(T)}
685685
function Base.iterate(it::StatefulBFS, queue = (eltype(it)[(0, it.t)]))
686686
isempty(queue) && return nothing
687687
lv, t = popfirst!(queue)
688-
lv += 1
688+
nextlv = lv + 1
689689
for c in children(t)
690-
push!(queue, (lv, c))
690+
push!(queue, (nextlv, c))
691691
end
692692
return (lv, t), queue
693693
end

0 commit comments

Comments
 (0)