1
1
using SymbolicUtils: Rewriters
2
+ using SimpleWeightedGraphs
3
+ using Graphs. Experimental. Traversals
2
4
3
5
const KEEP = typemin (Int)
4
6
@@ -288,35 +290,22 @@ end
288
290
289
291
function tograph (ag:: AliasGraph , var_to_diff:: DiffGraph )
290
292
g = SimpleDiGraph {Int} (length (var_to_diff))
293
+ eqg = SimpleWeightedGraph {Int, Int} (length (var_to_diff))
291
294
zero_vars = Int[]
292
- for (v, (_ , a)) in ag
295
+ for (v, (c , a)) in ag
293
296
if iszero (a)
294
297
push! (zero_vars, v)
295
298
continue
296
299
end
297
300
add_edge! (g, v, a)
298
301
add_edge! (g, a, v)
302
+
303
+ add_edge! (eqg, v, a, c)
304
+ add_edge! (eqg, a, v, c)
299
305
end
300
306
transitiveclosure! (g)
301
- #=
302
- # Compute the largest transitive closure that doesn't include any diff
303
- # edges.
304
- og = g
305
- newg = SimpleDiGraph{Int}(length(var_to_diff))
306
- for e in Graphs.edges(og)
307
- s, d = src(e), dst(e)
308
- (var_to_diff[s] == d || var_to_diff[d] == s) && continue
309
- oldg = copy(newg)
310
- add_edge!(newg, s, d)
311
- add_edge!(newg, d, s)
312
- transitiveclosure!(newg)
313
- if any(e->(var_to_diff[src(e)] == dst(e) || var_to_diff[dst(e)] == src(e)), edges(newg))
314
- newg = oldg
315
- end
316
- end
317
- g = newg
318
- =#
319
- eqg = copy (g)
307
+ Main. _a[] = copy (eqg)
308
+ weighted_transitiveclosure! (eqg)
320
309
321
310
c = " green"
322
311
edge_styles = Dict {Tuple{Int, Int}, String} ()
@@ -329,7 +318,17 @@ function tograph(ag::AliasGraph, var_to_diff::DiffGraph)
329
318
g, eqg, zero_vars, edge_styles
330
319
end
331
320
332
- using Graphs. Experimental. Traversals
321
+ function weighted_transitiveclosure! (g)
322
+ cps = connected_components (g)
323
+ for cp in cps
324
+ for k in cp, i in cp, j in cp
325
+ (has_edge (g, i, k) && has_edge (g, k, j)) || continue
326
+ add_edge! (g, i, j, get_weight (g, i, k) * get_weight (g, k, j))
327
+ end
328
+ end
329
+ return g
330
+ end
331
+
333
332
struct DiffLevelState <: Traversals.AbstractTraversalState
334
333
dists:: Vector{Int}
335
334
var_to_diff:: DiffGraph
@@ -763,6 +762,10 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
763
762
ag = AliasGraph (nvars)
764
763
mm, echelon_mm = simple_aliases! (ag, graph, var_to_diff, mm_orig)
765
764
fullvars = Main. _state[]. fullvars
765
+ for (v, (c, a)) in ag
766
+ a = iszero (a) ? 0 : c * fullvars[a]
767
+ @info " ag" fullvars[v] => a
768
+ end
766
769
767
770
# Step 3: Handle differentiated variables
768
771
# At this point, `var_to_diff` and `ag` form a tree structure like the
@@ -802,17 +805,14 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
802
805
(dv === nothing && diff_to_var[v] === nothing ) && continue
803
806
r = find_root! (dls, g, v)
804
807
@show fullvars[r]
805
- level_to_var = Int[]
806
- extreme_var (var_to_diff, r, nothing , Val (false ),
807
- callback = Base. Fix1 (push!, level_to_var))
808
- nlevels = length (level_to_var)
809
808
prev_r = - 1
810
809
stem = Int[]
810
+ stem_set = BitSet ()
811
811
for _ in 1 : 10_000 # just to make sure that we don't stuck in an infinite loop
812
812
reach₌ = Pair{Int, Int}[]
813
- r === nothing || for n in neighbors (g , r)
813
+ r === nothing || for n in neighbors (eqg , r)
814
814
(n == r || is_diff_edge (r, n)) && continue
815
- c = 1
815
+ c = get_weight (eqg, r, n)
816
816
push! (reach₌, c => n)
817
817
end
818
818
if (n = length (diff_aliases)) >= 1
@@ -823,45 +823,80 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
823
823
push! (reach₌, c => da)
824
824
end
825
825
end
826
- for (c, a) in reach₌
827
- @info fullvars[r] => c * fullvars[a]
828
- end
829
826
if r === nothing
830
- @warn " hi"
831
- # TODO : updated_diff_vars check
832
827
isempty (reach₌) && break
833
- dr = first (reach₌)
828
+ idx = findfirst (x-> x[1 ] == 1 , reach₌)
829
+ if idx === nothing
830
+ c, dr = reach₌[1 ]
831
+ @assert c == - 1
832
+ dag[dr] = (c, dr)
833
+ else
834
+ c, dr = reach₌[idx]
835
+ @assert c == 1
836
+ end
834
837
var_to_diff[prev_r] = dr
835
838
push! (updated_diff_vars, prev_r)
836
839
prev_r = dr
837
840
else
838
- @warn " " fullvars[r]
839
841
prev_r = r
840
842
r = var_to_diff[r]
841
843
end
842
- for (c, v) in reach₌
844
+ for (c, a) in reach₌
845
+ if r === nothing
846
+ var_to_diff[prev_r] === nothing && continue
847
+ @info fullvars[var_to_diff[prev_r]] => c * fullvars[a]
848
+ else
849
+ @info fullvars[r] => c * fullvars[a]
850
+ end
851
+ end
852
+ prev_r in stem_set && break
853
+ push! (stem_set, prev_r)
854
+ push! (stem, prev_r)
855
+ push! (diff_aliases, reach₌)
856
+ for (_, v) in reach₌
843
857
v == prev_r && continue
844
858
add_edge! (eqg, v, prev_r)
845
- push! (stem, prev_r)
846
- dag[v] = c => prev_r
847
859
end
848
- push! (diff_aliases, reach₌)
849
860
end
861
+
862
+ @show fullvars[updated_diff_vars]
863
+ @info " " fullvars
864
+ @show stem
865
+ @show diff_aliases
850
866
@info " " fullvars[stem]
851
- transitiveclosure! (eqg)
867
+ display (diff_aliases)
868
+ @assert length (stem) == length (diff_aliases)
869
+ for i in eachindex (stem)
870
+ a = stem[i]
871
+ for (c, v) in diff_aliases[i]
872
+ # alias edges that coincide with diff edges are handled later
873
+ v in stem_set && continue
874
+ dag[v] = c => a
875
+ end
876
+ end
877
+ # Obtain transitive closure after completing the alias edges from diff
878
+ # edges.
879
+ weighted_transitiveclosure! (eqg)
880
+ # Canonicalize by preferring the lower differentiated variable
852
881
for i in 1 : length (stem) - 1
853
- r, dr = stem[i], stem[i+ 1 ]
854
- if has_edge (eqg, r, dr)
855
- c = 1
856
- dag[dr] = c => r
882
+ r = stem[i]
883
+ for dr in @view stem[i+ 1 : end ]
884
+ if has_edge (eqg, r, dr)
885
+ c = get_weight (eqg, r, dr)
886
+ dag[dr] = c => r
887
+ end
857
888
end
858
889
end
859
- for v in zero_vars, a in outneighbors (g, v)
860
- dag[a] = 0
890
+ for v in zero_vars
891
+ for a in Iterators. flatten ((v, outneighbors (eqg, v)))
892
+ while true
893
+ dag[a] = 0
894
+ da = var_to_diff[a]
895
+ da === nothing && break
896
+ a = da
897
+ end
898
+ end
861
899
end
862
- @show nlevels
863
- display (diff_aliases)
864
- @assert length (diff_aliases) == nlevels
865
900
866
901
# clean up
867
902
for v in dls. visited
@@ -871,6 +906,7 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
871
906
empty! (dls. visited)
872
907
empty! (diff_aliases)
873
908
end
909
+ @show dag
874
910
for k in keys (dag)
875
911
dag[k]
876
912
end
0 commit comments