Skip to content

Commit 020899f

Browse files
committed
Clean up
1 parent 2c7fba5 commit 020899f

File tree

1 file changed

+33
-54
lines changed

1 file changed

+33
-54
lines changed

src/systems/alias_elimination.jl

Lines changed: 33 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true)
3333
level === nothing ? v : (v => level)
3434
end
3535

36-
function alias_elimination(sys)
36+
function alias_elimination(sys; debug = false)
3737
state = TearingState(sys; quick_cancel = true)
38-
Main._state[] = state
3938
ag, mm = alias_eliminate_graph!(state)
4039
ag === nothing && return sys
4140

@@ -72,7 +71,6 @@ function alias_elimination(sys)
7271
iszero(coeff) && continue
7372
add_edge!(invag, alias, v)
7473
end
75-
Main._a[] = ag, invag
7674
processed = falses(nvars)
7775
#iag = InducedAliasGraph(ag, invag, var_to_diff, processed)
7876
iag = InducedAliasGraph(ag, invag, var_to_diff)
@@ -82,33 +80,33 @@ function alias_elimination(sys)
8280
(dv === nothing && diff_to_var[v] === nothing) && continue
8381

8482
r, _ = find_root!(iag, v)
85-
let
83+
if debug
8684
sv = fullvars[v]
8785
root = fullvars[r]
8886
@info "Found root $r" sv=>root
8987
end
9088
level_to_var = Int[]
9189
extreme_var(var_to_diff, r, nothing, Val(false), callback = Base.Fix1(push!, level_to_var))
9290
nlevels = length(level_to_var)
93-
current_level = Ref(0)
94-
add_alias! = let current_level = current_level, level_to_var = level_to_var, newag = newag, processed = processed
91+
current_coeff_level = Ref((0, 0))
92+
add_alias! = let current_coeff_level = current_coeff_level, level_to_var = level_to_var, newag = newag, processed = processed
9593
v -> begin
96-
level = current_level[]
94+
coeff, level = current_coeff_level[]
9795
if level + 1 <= length(level_to_var)
9896
# TODO: make sure the coefficient is 1
9997
av = level_to_var[level + 1]
10098
if v != av # if the level_to_var isn't from the root branch
101-
newag[v] = 1 => av
99+
newag[v] = coeff => av
102100
end
103101
else
104102
@assert length(level_to_var) == level
105103
push!(level_to_var, v)
106104
end
107105
processed[v] = true
108-
current_level[] += 1
106+
current_coeff_level[] = (coeff, level + 1)
109107
end
110108
end
111-
for (lv, t) in StatefulBFS(RootedAliasTree(iag, r))
109+
for (coeff, lv, t) in StatefulAliasBFS(RootedAliasTree(iag, r))
112110
v = nodevalue(t)
113111
processed[v] = true
114112
v == r && continue
@@ -117,7 +115,7 @@ function alias_elimination(sys)
117115
continue
118116
end
119117
end
120-
current_level[] = lv
118+
current_coeff_level[] = coeff, lv
121119
extreme_var(var_to_diff, v, nothing, Val(false), callback = add_alias!)
122120
end
123121
if nlevels < (new_nlevels = length(level_to_var))
@@ -128,17 +126,7 @@ function alias_elimination(sys)
128126
end
129127
end
130128
end
131-
#=
132-
for (v, (c, a)) in ag
133-
va = iszero(a) ? a : fullvars[a]
134-
@warn "old alias" fullvars[v] => (c, va)
135-
end
136-
for (v, (c, a)) in newag
137-
va = iszero(a) ? a : fullvars[a]
138-
@warn "new alias" fullvars[v] => (c, va)
139-
end
140-
=#
141-
println("================")
129+
142130
newkeys = keys(newag)
143131
for (v, (c, a)) in ag
144132
(v in newkeys || a in newkeys) && continue
@@ -149,9 +137,10 @@ function alias_elimination(sys)
149137
end
150138
end
151139
ag = newag
152-
for (v, (c, a)) in ag
140+
141+
debug && for (v, (c, a)) in ag
153142
va = iszero(a) ? a : fullvars[a]
154-
@warn "new alias" fullvars[v] => (c, va)
143+
@info "new alias" fullvars[v] => (c, va)
155144
end
156145

157146
subs = Dict()
@@ -363,7 +352,6 @@ struct IAGNeighbors
363352
end
364353

365354
function Base.iterate(it::IAGNeighbors, state = nothing)
366-
Main._a[] = it, state
367355
@unpack ag, invag, var_to_diff, visited = it.iag
368356
callback! = let visited = visited
369357
var -> visited[var] = true
@@ -444,6 +432,22 @@ AbstractTrees.nodevalue(rat::RootedAliasTree) = rat.root
444432
AbstractTrees.shouldprintkeys(rat::RootedAliasTree) = false
445433
has_fast_reverse(::Type{<:AbstractSimpleTreeIter{<:RootedAliasTree}}) = false
446434

435+
struct StatefulAliasBFS{T} <: AbstractSimpleTreeIter{T}
436+
t::T
437+
end
438+
# alias coefficient, depth, children
439+
Base.eltype(::Type{<:StatefulAliasBFS{T}}) where T = Tuple{Int, Int, childtype(T)}
440+
function Base.iterate(it::StatefulAliasBFS, queue = (eltype(it)[(1, 0, it.t)]))
441+
isempty(queue) && return nothing
442+
coeff, lv, t = popfirst!(queue)
443+
nextlv = lv + 1
444+
for (coeff′, c) in children(t)
445+
# -1 <= coeff <= 1
446+
push!(queue, (coeff * coeff′, nextlv, c))
447+
end
448+
return (coeff, lv, t), queue
449+
end
450+
447451
struct RootedAliasChildren
448452
t::RootedAliasTree
449453
end
@@ -462,18 +466,19 @@ function Base.iterate(c::RootedAliasChildren, s = nothing)
462466
(stage, it) = s
463467
if stage == 1 # root
464468
stage += 1
465-
return root, (stage, it)
469+
return (1, root), (stage, it)
466470
elseif stage == 2 # ag
467471
stage += 1
468472
cv = get(ag, root, nothing)
469473
if cv !== nothing
470-
return RootedAliasTree(iag, cv[2]), (stage, it)
474+
return (cv[1], RootedAliasTree(iag, cv[2])), (stage, it)
471475
end
472476
end
473477
# invag (stage 3)
474478
it === nothing && return nothing
475479
e, ns = it
476-
return RootedAliasTree(iag, e), (stage, iterate(invag, ns))
480+
# c * a = b <=> a = c * b when -1 <= c <= 1
481+
return (ag[e], RootedAliasTree(iag, e)), (stage, iterate(invag, ns))
477482
end
478483

479484
count_nonzeros(a::AbstractArray) = count(!iszero, a)
@@ -656,32 +661,6 @@ function locally_structure_simplify!(adj_row, pivot_var, ag, var_to_diff)
656661

657662
if alias_candidate isa Pair
658663
alias_val, alias_var = alias_candidate
659-
#preferred_var = pivot_var
660-
#=
661-
switch = false # we prefer `alias_var` by default, unless we switch
662-
diff_to_var = invview(var_to_diff)
663-
pivot_var′′::Union{Nothing, Int} = pivot_var′::Int = pivot_var
664-
alias_var′′::Union{Nothing, Int} = alias_var′::Int = alias_var
665-
# We prefer the higher differenitated variable. Note that `{⋅}′′` vars
666-
# could be `nothing` while `{⋅}′` vars are always `Int`.
667-
while (pivot_var′′ = diff_to_var[pivot_var′]) !== nothing
668-
pivot_var′ = pivot_var′′
669-
if (alias_var′′ = diff_to_var[alias_var′]) === nothing
670-
switch = true
671-
break
672-
end
673-
pivot_var′ = pivot_var′′
674-
end
675-
# If we have a tie, then we prefer the lower variable.
676-
if alias_var′′ === pivot_var′′ === nothing
677-
@assert pivot_var′ != alias_var′
678-
switch = pivot_var′ < alias_var′
679-
end
680-
if switch
681-
pivot_var, alias_var = alias_var, pivot_var
682-
pivot_val, alias_val = alias_val, pivot_val
683-
end
684-
=#
685664

686665
# `p` is the pivot variable, `a` is the alias variable, `v` and `c` are
687666
# their coefficients.

0 commit comments

Comments
 (0)