Skip to content

Commit 5fc9cd9

Browse files
committed
WIP
1 parent 978b869 commit 5fc9cd9

File tree

2 files changed

+139
-6
lines changed

2 files changed

+139
-6
lines changed

src/systems/alias_elimination.jl

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ function alias_elimination(sys)
9090
end
9191
Main._a[] = ag, invag
9292
processed = falses(nvars)
93-
iag = InducedAliasGraph(ag, invag, var_to_diff, processed)
93+
#iag = InducedAliasGraph(ag, invag, var_to_diff, processed)
94+
iag = InducedAliasGraph(ag, invag, var_to_diff)
9495
relative_level = BitDict(nvars)
9596
newag = AliasGraph(nvars)
9697
for (v, dv) in enumerate(var_to_diff)
@@ -101,10 +102,14 @@ function alias_elimination(sys)
101102
# variabels.
102103
# Note that `rootlv` is non-positive
103104
r, rootlv = walk_to_root!(relative_level, iag, v)
105+
fill!(iag.visited, false)
104106
let
105107
sv = fullvars[v]
106108
root = fullvars[r]
107-
@warn "" sv=>root level=rootlv
109+
@info "Found root $r" sv=>root level=rootlv
110+
for vv in relative_level
111+
@show fullvars[vv[1]]
112+
end
108113
end
109114
level_to_var = Int[]
110115
extreme_var(var_to_diff, r, nothing, Val(false), callback = Base.Fix1(push!, level_to_var))
@@ -116,11 +121,16 @@ function alias_elimination(sys)
116121
# FIXME: only alias variables in the reachable set
117122
if level + 1 <= length(level_to_var)
118123
# TODO: make sure the coefficient is 1
119-
newag[v] = 1 => level_to_var[level + 1]
124+
av = level_to_var[level + 1]
125+
if v != av # if the level_to_var isn't from the root branch
126+
newag[v] = 1 => av
127+
#@info "create alias" fullvars[v] => fullvars[level_to_var[level + 1]]
128+
end
120129
else
121130
@assert length(level_to_var) == level
122131
push!(level_to_var, v)
123132
end
133+
processed[v] = true
124134
current_level[] += 1
125135
end
126136
end
@@ -139,14 +149,30 @@ function alias_elimination(sys)
139149
end
140150
end
141151
end
152+
#=
153+
for (v, (c, a)) in ag
154+
va = iszero(a) ? a : fullvars[a]
155+
@warn "old alias" fullvars[v] => (c, va)
156+
end
157+
for (v, (c, a)) in newag
158+
va = iszero(a) ? a : fullvars[a]
159+
@warn "new alias" fullvars[v] => (c, va)
160+
end
161+
=#
162+
println("================")
142163
newkeys = keys(newag)
143164
for (v, (c, a)) in ag
144165
(v in newkeys || a in newkeys) && continue
145-
newag[v] = c => a
166+
if iszero(c)
167+
newag[v] = c
168+
else
169+
newag[v] = c => a
170+
end
146171
end
147172
ag = newag
148173
for (v, (c, a)) in ag
149-
@warn "new alias" fullvars[v] => (c, fullvars[a])
174+
va = iszero(a) ? a : fullvars[a]
175+
@warn "new alias" fullvars[v] => (c, va)
150176
end
151177

152178
subs = Dict()
@@ -179,14 +205,15 @@ function alias_elimination(sys)
179205
newstates = []
180206
for j in eachindex(fullvars)
181207
if j in keys(ag)
208+
#=
182209
_, var = ag[j]
183210
iszero(var) && continue
184211
# Put back equations for alias eliminated dervars
185212
if isdervar(state.structure, var)
186213
has_higher_order = false
187214
v = var
188215
while (v = var_to_diff[v]) !== nothing
189-
if !(v in keys(ag))
216+
if !(v::Int in keys(ag))
190217
has_higher_order = true
191218
break
192219
end
@@ -197,6 +224,7 @@ function alias_elimination(sys)
197224
diff_to_var[j] === nothing && push!(newstates, rhs)
198225
end
199226
end
227+
=#
200228
else
201229
diff_to_var[j] === nothing && push!(newstates, fullvars[j])
202230
end
@@ -405,6 +433,49 @@ end
405433

406434
Graphs.neighbors(iag::InducedAliasGraph, v::Integer) = IAGNeighbors(iag, v)
407435

436+
struct RootedAliasTree
437+
iag::InducedAliasGraph
438+
root::Int
439+
end
440+
441+
AbstractTrees.childtype(::Type{<:RootedAliasTree}) = Union{RootedAliasTree, Int}
442+
AbstractTrees.children(rat::RootedAliasTree) = RootedAliasChildren(rat)
443+
AbstractTrees.nodevalue(rat::RootedAliasTree) = rat.root
444+
AbstractTrees.shouldprintkeys(rat::RootedAliasTree) = false
445+
has_fast_reverse(::Type{<:AbstractSimpleTreeIter{<:RootedAliasTree}}) = false
446+
447+
struct RootedAliasChildren
448+
t::RootedAliasTree
449+
end
450+
451+
function Base.iterate(c::RootedAliasChildren, s = nothing)
452+
rat = c.t
453+
@unpack iag, root = rat
454+
@unpack ag, invag, var_to_diff, visited = iag
455+
(root = var_to_diff[root]) === nothing && return nothing
456+
root::Int
457+
if s === nothing
458+
stage = 1
459+
it = iterate(neighbors(invag, root))
460+
s = (stage, it)
461+
end
462+
(stage, it) = s
463+
if stage == 1 # root
464+
stage += 1
465+
return root, (stage, it)
466+
elseif stage == 2 # ag
467+
stage += 1
468+
cv = get(ag, root, nothing)
469+
if cv !== nothing
470+
return RootedAliasTree(iag, cv[2]), (stage, it)
471+
end
472+
end
473+
# invag (stage 3)
474+
it === nothing && return nothing
475+
e, ns = it
476+
return RootedAliasTree(iag, e), (stage, iterate(invag, ns))
477+
end
478+
408479
count_nonzeros(a::AbstractArray) = count(!iszero, a)
409480

410481
# N.B.: Ordinarily sparse vectors allow zero stored elements.

src/utils.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,3 +629,65 @@ function Base.empty!(d::BitDict)
629629
empty!(d.keys)
630630
d
631631
end
632+
633+
abstract type AbstractSimpleTreeIter{T} end
634+
Base.IteratorSize(::Type{<:AbstractSimpleTreeIter}) = Base.SizeUnknown()
635+
Base.eltype(::Type{<:AbstractSimpleTreeIter{T}}) where T = childtype(T)
636+
has_fast_reverse(::Type{<:AbstractSimpleTreeIter}) = true
637+
has_fast_reverse(::T) where T<:AbstractSimpleTreeIter = has_fast_reverse(T)
638+
reverse_buffer(it::AbstractSimpleTreeIter) = has_fast_reverse(it) ? nothing : eltype(it)[]
639+
reverse_children!(::Nothing, cs) = Iterators.reverse(cs)
640+
function reverse_children!(rev_buff, cs)
641+
Iterators.reverse(cs)
642+
empty!(rev_buff)
643+
for c in cs
644+
push!(rev_buff, c)
645+
end
646+
Iterators.reverse(rev_buff)
647+
end
648+
649+
struct StatefulPreOrderDFS{T} <: AbstractSimpleTreeIter{T}
650+
t::T
651+
end
652+
function Base.iterate(it::StatefulPreOrderDFS, state = (eltype(it)[it.t], reverse_buffer(it)))
653+
stack, rev_buff = state
654+
isempty(stack) && return nothing
655+
t = pop!(stack)
656+
for c in reverse_children!(rev_buff, children(t))
657+
push!(stack, c)
658+
end
659+
return t, state
660+
end
661+
struct StatefulPostOrderDFS{T} <: AbstractSimpleTreeIter{T}
662+
t::T
663+
end
664+
function Base.iterate(it::StatefulPostOrderDFS, state = (eltype(it)[it.t], falses(1), reverse_buffer(it)))
665+
isempty(state[2]) && return nothing
666+
vstack, sstack, rev_buff = state
667+
while true
668+
t = pop!(vstack)
669+
isresume = pop!(sstack)
670+
isresume && return t, state
671+
push!(vstack, t)
672+
push!(sstack, true)
673+
for c in reverse_children!(rev_buff, children(t))
674+
push!(vstack, c)
675+
push!(sstack, false)
676+
end
677+
end
678+
end
679+
680+
# Note that StatefulBFS also returns the depth.
681+
struct StatefulBFS{T} <: AbstractSimpleTreeIter{T}
682+
t::T
683+
end
684+
Base.eltype(::Type{<:StatefulBFS{T}}) where T = Tuple{Int, childtype(T)}
685+
function Base.iterate(it::StatefulBFS, queue = (eltype(it)[(0, it.t)]))
686+
isempty(queue) && return nothing
687+
lv, t = popfirst!(queue)
688+
lv += 1
689+
for c in children(t)
690+
push!(queue, (lv, c))
691+
end
692+
return (lv, t), queue
693+
end

0 commit comments

Comments
 (0)