Skip to content

Commit 0444f7d

Browse files
committed
Add InducedAliasGraph
1 parent 1a07aba commit 0444f7d

File tree

2 files changed

+93
-35
lines changed

2 files changed

+93
-35
lines changed

src/systems/alias_elimination.jl

Lines changed: 78 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,57 +19,35 @@ function aag_bareiss(sys::AbstractSystem)
1919
return aag_bareiss!(state.structure.graph, complete(state.structure.var_to_diff), mm)
2020
end
2121

22-
function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true)) where descend
22+
function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true); callback = _ -> nothing) where descend
2323
g = descend ? invview(var_to_diff) : var_to_diff
24+
callback(v)
2425
while (v′ = g[v]) !== nothing
25-
v = v′
26+
v::Int = v′
27+
callback(v)
2628
if level !== nothing
2729
descend ? (level -= 1) : (level += 1)
2830
end
2931
end
3032
level === nothing ? v : (v => level)
3133
end
3234

33-
function neighbor_branches!(visited, (ag, invag), var_to_diff, v, level = 0)
34-
ns = Pair{Int, Int}[]
35-
visited[v] && return ns
36-
v′::Union{Nothing, Int} = v
37-
diff_to_var = invview(var_to_diff)
38-
while (v′ = diff_to_var[v]) !== nothing
39-
v = v′
40-
level -= 1
41-
end
42-
while true
43-
if (_n = get(ag, v, nothing)) !== nothing
44-
n = _n[2]
45-
visited[n] || push!(ns, n => level)
46-
end
47-
for n in neighbors(invag, v)
48-
visited[n] || push!(ns, n => level)
49-
end
50-
visited[v] = true
51-
(v′ = var_to_diff[v]) === nothing && break
52-
v = v′
53-
level += 1
54-
end
55-
ns
56-
end
57-
58-
function walk_to_root!(visited, ags, var_to_diff, v::Integer, level = 0)
59-
brs = neighbor_branches!(visited, ags, var_to_diff, v, level)
35+
function walk_to_root!(iag, v::Integer, level = 0)
36+
brs = neighbors(iag, v)
6037
min_var_level = v => level
61-
isempty(brs) && return extreme_var(var_to_diff, min_var_level...)
62-
for (x, lv) in brs
63-
x, lv = walk_to_root!(visited, ags, var_to_diff, x, lv)
38+
for (x, lv′) in brs
39+
lv = lv′ + level
40+
x, lv = walk_to_root!(iag, x, lv)
6441
if min_var_level[2] > lv
6542
min_var_level = x => lv
6643
end
6744
end
68-
return extreme_var(var_to_diff, min_var_level...)
45+
return extreme_var(iag.var_to_diff, min_var_level...)
6946
end
7047

7148
function alias_elimination(sys)
7249
state = TearingState(sys; quick_cancel = true)
50+
Main._state[] = state
7351
ag, mm = alias_eliminate_graph!(state)
7452
ag === nothing && return sys
7553

@@ -108,15 +86,16 @@ function alias_elimination(sys)
10886
end
10987
Main._a[] = ag, invag
11088
processed = falses(nvars)
111-
visited = falses(nvars)
89+
iag = InducedAliasGraph(ag, invag, var_to_diff, processed)
11290
newag = AliasGraph(nvars)
11391
for (v, dv) in enumerate(var_to_diff)
11492
processed[v] && continue
11593
(dv === nothing && diff_to_var[v] === nothing) && continue
11694

11795
# TODO: use an iterator, and get a relative level vector for `processed`
11896
# variabels.
119-
r, lv = walk_to_root!(processed, (ag, invag), var_to_diff, v)
97+
r, lv = walk_to_root!(iag, v)
98+
fill!(processed, false)
12099
#lv = extreme_var(var_to_diff, v, -lv, Val(false))
121100
lv′ = extreme_var(var_to_diff, v, 0, Val(false))[2]
122101
let
@@ -330,6 +309,70 @@ function Base.in(i::Int, agk::AliasGraphKeySet)
330309
1 <= i <= length(aliasto) && aliasto[i] !== nothing
331310
end
332311

312+
struct InducedAliasGraph
313+
ag::AliasGraph
314+
invag::SimpleDiGraph{Int}
315+
var_to_diff::DiffGraph
316+
visited::BitVector
317+
end
318+
319+
InducedAliasGraph(ag, invag, var_to_diff) = InducedAliasGraph(ag, invag, var_to_diff, falses(nv(invag)))
320+
321+
struct IAGNeighbors
322+
iag::InducedAliasGraph
323+
v::Int
324+
end
325+
326+
function Base.iterate(it::IAGNeighbors, state = nothing)
327+
@unpack ag, invag, var_to_diff, visited = it.iag
328+
callback! = let visited = visited
329+
var -> visited[var] = true
330+
end
331+
if state === nothing
332+
v, lv = extreme_var(var_to_diff, it.v, 0)
333+
used_ag = false
334+
nb = neighbors(invag, v)
335+
nit = iterate(nb)
336+
state = (v, lv, used_ag, nb, nit)
337+
end
338+
339+
v, level, used_ag, nb, nit = state
340+
visited[v] && return nothing
341+
while true
342+
@label TRYAGIN
343+
if used_ag
344+
if nit !== nothing
345+
n, ns = nit
346+
if !visited[n]
347+
n, lv = extreme_var(var_to_diff, n, level)
348+
extreme_var(var_to_diff, n, nothing, Val(false), callback = callback!)
349+
nit = iterate(nb, ns)
350+
return n => lv, (v, level, used_ag, nb, nit)
351+
end
352+
end
353+
else
354+
used_ag = true
355+
if (_n = get(ag, v, nothing)) !== nothing
356+
n = _n[2]
357+
if !visited[n]
358+
n, lv = extreme_var(var_to_diff, n, level)
359+
extreme_var(var_to_diff, n, nothing, Val(false), callback = callback!)
360+
return n => lv, (v, level, used_ag, nb, nit)
361+
end
362+
else
363+
@goto TRYAGIN
364+
end
365+
end
366+
visited[v] = true
367+
(v′ = var_to_diff[v]) === nothing && return nothing
368+
v::Int = v′
369+
level += 1
370+
used_ag = false
371+
end
372+
end
373+
374+
Graphs.neighbors(iag::InducedAliasGraph, v::Integer) = IAGNeighbors(iag, v)
375+
333376
count_nonzeros(a::AbstractArray) = count(!iszero, a)
334377

335378
# N.B.: Ordinarily sparse vectors allow zero stored elements.

src/systems/systemstructure.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,21 @@ function invview(dg::DiffGraph)
142142
return DiffGraph(dg.diff_to_primal, dg.primal_to_diff)
143143
end
144144

145+
struct DiffChainIterator{Descend}
146+
var_to_diff::DiffGraph
147+
v::Int
148+
end
149+
150+
function Base.iterate(di::DiffChainIterator{Descend}, v=nothing) where Descend
151+
if v === nothing
152+
vv = di.v
153+
return (vv, vv)
154+
end
155+
g = Descend ? invview(di.var_to_diff) : di.var_to_diff
156+
v′ = g[v]
157+
v′ === nothing ? nothing : (v′, v′)
158+
end
159+
145160
abstract type TransformationState{T} end
146161
abstract type AbstractTearingState{T} <: TransformationState{T} end
147162

0 commit comments

Comments
 (0)