@@ -33,9 +33,8 @@ function extreme_var(var_to_diff, v, level = nothing, ::Val{descend} = Val(true)
33
33
level === nothing ? v : (v => level)
34
34
end
35
35
36
- function alias_elimination (sys)
36
+ function alias_elimination (sys; debug = false )
37
37
state = TearingState (sys; quick_cancel = true )
38
- Main. _state[] = state
39
38
ag, mm = alias_eliminate_graph! (state)
40
39
ag === nothing && return sys
41
40
@@ -72,7 +71,6 @@ function alias_elimination(sys)
72
71
iszero (coeff) && continue
73
72
add_edge! (invag, alias, v)
74
73
end
75
- Main. _a[] = ag, invag
76
74
processed = falses (nvars)
77
75
# iag = InducedAliasGraph(ag, invag, var_to_diff, processed)
78
76
iag = InducedAliasGraph (ag, invag, var_to_diff)
@@ -82,33 +80,33 @@ function alias_elimination(sys)
82
80
(dv === nothing && diff_to_var[v] === nothing ) && continue
83
81
84
82
r, _ = find_root! (iag, v)
85
- let
83
+ if debug
86
84
sv = fullvars[v]
87
85
root = fullvars[r]
88
86
@info " Found root $r " sv=> root
89
87
end
90
88
level_to_var = Int[]
91
89
extreme_var (var_to_diff, r, nothing , Val (false ), callback = Base. Fix1 (push!, level_to_var))
92
90
nlevels = length (level_to_var)
93
- current_level = Ref (0 )
94
- add_alias! = let current_level = current_level , level_to_var = level_to_var, newag = newag, processed = processed
91
+ current_coeff_level = Ref (( 0 , 0 ) )
92
+ add_alias! = let current_coeff_level = current_coeff_level , level_to_var = level_to_var, newag = newag, processed = processed
95
93
v -> begin
96
- level = current_level []
94
+ coeff, level = current_coeff_level []
97
95
if level + 1 <= length (level_to_var)
98
96
# TODO : make sure the coefficient is 1
99
97
av = level_to_var[level + 1 ]
100
98
if v != av # if the level_to_var isn't from the root branch
101
- newag[v] = 1 => av
99
+ newag[v] = coeff => av
102
100
end
103
101
else
104
102
@assert length (level_to_var) == level
105
103
push! (level_to_var, v)
106
104
end
107
105
processed[v] = true
108
- current_level [] += 1
106
+ current_coeff_level [] = (coeff, level + 1 )
109
107
end
110
108
end
111
- for (lv, t) in StatefulBFS (RootedAliasTree (iag, r))
109
+ for (coeff, lv, t) in StatefulAliasBFS (RootedAliasTree (iag, r))
112
110
v = nodevalue (t)
113
111
processed[v] = true
114
112
v == r && continue
@@ -117,7 +115,7 @@ function alias_elimination(sys)
117
115
continue
118
116
end
119
117
end
120
- current_level [] = lv
118
+ current_coeff_level [] = coeff, lv
121
119
extreme_var (var_to_diff, v, nothing , Val (false ), callback = add_alias!)
122
120
end
123
121
if nlevels < (new_nlevels = length (level_to_var))
@@ -128,17 +126,7 @@ function alias_elimination(sys)
128
126
end
129
127
end
130
128
end
131
- #=
132
- for (v, (c, a)) in ag
133
- va = iszero(a) ? a : fullvars[a]
134
- @warn "old alias" fullvars[v] => (c, va)
135
- end
136
- for (v, (c, a)) in newag
137
- va = iszero(a) ? a : fullvars[a]
138
- @warn "new alias" fullvars[v] => (c, va)
139
- end
140
- =#
141
- println (" ================" )
129
+
142
130
newkeys = keys (newag)
143
131
for (v, (c, a)) in ag
144
132
(v in newkeys || a in newkeys) && continue
@@ -149,9 +137,10 @@ function alias_elimination(sys)
149
137
end
150
138
end
151
139
ag = newag
152
- for (v, (c, a)) in ag
140
+
141
+ debug && for (v, (c, a)) in ag
153
142
va = iszero (a) ? a : fullvars[a]
154
- @warn " new alias" fullvars[v] => (c, va)
143
+ @info " new alias" fullvars[v] => (c, va)
155
144
end
156
145
157
146
subs = Dict ()
@@ -363,7 +352,6 @@ struct IAGNeighbors
363
352
end
364
353
365
354
function Base. iterate (it:: IAGNeighbors , state = nothing )
366
- Main. _a[] = it, state
367
355
@unpack ag, invag, var_to_diff, visited = it. iag
368
356
callback! = let visited = visited
369
357
var -> visited[var] = true
@@ -444,6 +432,22 @@ AbstractTrees.nodevalue(rat::RootedAliasTree) = rat.root
444
432
AbstractTrees. shouldprintkeys (rat:: RootedAliasTree ) = false
445
433
has_fast_reverse (:: Type{<:AbstractSimpleTreeIter{<:RootedAliasTree}} ) = false
446
434
435
+ struct StatefulAliasBFS{T} <: AbstractSimpleTreeIter{T}
436
+ t:: T
437
+ end
438
+ # alias coefficient, depth, children
439
+ Base. eltype (:: Type{<:StatefulAliasBFS{T}} ) where T = Tuple{Int, Int, childtype (T)}
440
+ function Base. iterate (it:: StatefulAliasBFS , queue = (eltype (it)[(1 , 0 , it. t)]))
441
+ isempty (queue) && return nothing
442
+ coeff, lv, t = popfirst! (queue)
443
+ nextlv = lv + 1
444
+ for (coeff′, c) in children (t)
445
+ # -1 <= coeff <= 1
446
+ push! (queue, (coeff * coeff′, nextlv, c))
447
+ end
448
+ return (coeff, lv, t), queue
449
+ end
450
+
447
451
struct RootedAliasChildren
448
452
t:: RootedAliasTree
449
453
end
@@ -462,18 +466,19 @@ function Base.iterate(c::RootedAliasChildren, s = nothing)
462
466
(stage, it) = s
463
467
if stage == 1 # root
464
468
stage += 1
465
- return root, (stage, it)
469
+ return ( 1 , root) , (stage, it)
466
470
elseif stage == 2 # ag
467
471
stage += 1
468
472
cv = get (ag, root, nothing )
469
473
if cv != = nothing
470
- return RootedAliasTree (iag, cv[2 ]), (stage, it)
474
+ return (cv[ 1 ], RootedAliasTree (iag, cv[2 ]) ), (stage, it)
471
475
end
472
476
end
473
477
# invag (stage 3)
474
478
it === nothing && return nothing
475
479
e, ns = it
476
- return RootedAliasTree (iag, e), (stage, iterate (invag, ns))
480
+ # c * a = b <=> a = c * b when -1 <= c <= 1
481
+ return (ag[e], RootedAliasTree (iag, e)), (stage, iterate (invag, ns))
477
482
end
478
483
479
484
count_nonzeros (a:: AbstractArray ) = count (! iszero, a)
@@ -656,32 +661,6 @@ function locally_structure_simplify!(adj_row, pivot_var, ag, var_to_diff)
656
661
657
662
if alias_candidate isa Pair
658
663
alias_val, alias_var = alias_candidate
659
- # preferred_var = pivot_var
660
- #=
661
- switch = false # we prefer `alias_var` by default, unless we switch
662
- diff_to_var = invview(var_to_diff)
663
- pivot_var′′::Union{Nothing, Int} = pivot_var′::Int = pivot_var
664
- alias_var′′::Union{Nothing, Int} = alias_var′::Int = alias_var
665
- # We prefer the higher differenitated variable. Note that `{⋅}′′` vars
666
- # could be `nothing` while `{⋅}′` vars are always `Int`.
667
- while (pivot_var′′ = diff_to_var[pivot_var′]) !== nothing
668
- pivot_var′ = pivot_var′′
669
- if (alias_var′′ = diff_to_var[alias_var′]) === nothing
670
- switch = true
671
- break
672
- end
673
- pivot_var′ = pivot_var′′
674
- end
675
- # If we have a tie, then we prefer the lower variable.
676
- if alias_var′′ === pivot_var′′ === nothing
677
- @assert pivot_var′ != alias_var′
678
- switch = pivot_var′ < alias_var′
679
- end
680
- if switch
681
- pivot_var, alias_var = alias_var, pivot_var
682
- pivot_val, alias_val = alias_val, pivot_val
683
- end
684
- =#
685
664
686
665
# `p` is the pivot variable, `a` is the alias variable, `v` and `c` are
687
666
# their coefficients.
0 commit comments