1
1
using SymbolicUtils: Rewriters
2
- using SimpleWeightedGraphs
3
2
using Graphs. Experimental. Traversals
4
3
5
4
const KEEP = typemin (Int)
@@ -153,7 +152,8 @@ function alias_elimination!(state::TearingState; kwargs...)
153
152
end
154
153
end
155
154
for ieq in eqs_to_update
156
- eqs[ieq] = substitute (eqs[ieq], subs)
155
+ eq = eqs[ieq]
156
+ eqs[ieq] = fast_substitute (eq, subs)
157
157
end
158
158
159
159
for old_ieq in to_expand
@@ -365,9 +365,33 @@ function Base.in(i::Int, agk::AliasGraphKeySet)
365
365
1 <= i <= length (aliasto) && aliasto[i] != = nothing
366
366
end
367
367
368
+ canonicalize (a, b) = a <= b ? (a, b) : (b, a)
369
+ struct WeightedGraph{T, W} <: AbstractGraph{T}
370
+ graph:: SimpleGraph{T}
371
+ dict:: Dict{Tuple{T, T}, W}
372
+ end
373
+ function WeightedGraph {T, W} (n) where {T, W}
374
+ WeightedGraph {T, W} (SimpleGraph {T} (n), Dict {Tuple{T, T}, W} ())
375
+ end
376
+
377
+ function Graphs. add_edge! (g:: WeightedGraph , u, v, w)
378
+ r = add_edge! (g. graph, u, v)
379
+ r && (g. dict[canonicalize (u, v)] = w)
380
+ r
381
+ end
382
+ Graphs. has_edge (g:: WeightedGraph , u, v) = has_edge (g. graph, u, v)
383
+ Graphs. ne (g:: WeightedGraph ) = ne (g. graph)
384
+ Graphs. nv (g:: WeightedGraph ) = nv (g. graph)
385
+ get_weight (g:: WeightedGraph , u, v) = g. dict[canonicalize (u, v)]
386
+ Graphs. is_directed (:: Type{<:WeightedGraph} ) = false
387
+ Graphs. inneighbors (g:: WeightedGraph , v) = inneighbors (g. graph, v)
388
+ Graphs. outneighbors (g:: WeightedGraph , v) = outneighbors (g. graph, v)
389
+ Graphs. vertices (g:: WeightedGraph ) = vertices (g. graph)
390
+ Graphs. edges (g:: WeightedGraph ) = vertices (g. graph)
391
+
368
392
function equality_diff_graph (ag:: AliasGraph , var_to_diff:: DiffGraph )
369
393
g = SimpleDiGraph {Int} (length (var_to_diff))
370
- eqg = SimpleWeightedGraph {Int, Int} (length (var_to_diff))
394
+ eqg = WeightedGraph {Int, Int} (length (var_to_diff))
371
395
zero_vars = Int[]
372
396
for (v, (c, a)) in ag
373
397
if iszero (a)
@@ -378,7 +402,6 @@ function equality_diff_graph(ag::AliasGraph, var_to_diff::DiffGraph)
378
402
add_edge! (g, a, v)
379
403
380
404
add_edge! (eqg, v, a, c)
381
- add_edge! (eqg, a, v, c)
382
405
end
383
406
transitiveclosure! (g)
384
407
weighted_transitiveclosure! (eqg)
394
417
function weighted_transitiveclosure! (g)
395
418
cps = connected_components (g)
396
419
for cp in cps
397
- for k in cp, i in cp, j in cp
398
- (has_edge (g, i, k) && has_edge (g, k, j)) || continue
399
- add_edge! (g, i, j, get_weight (g, i, k) * get_weight (g, k, j))
420
+ n = length (cp)
421
+ for k in cp
422
+ for i′ in 1 : n, j′ in (i′ + 1 ): n
423
+ i = cp[i′]
424
+ j = cp[j′]
425
+ (has_edge (g, i, k) && has_edge (g, k, j)) || continue
426
+ add_edge! (g, i, j, get_weight (g, i, k) * get_weight (g, k, j))
427
+ end
400
428
end
401
429
end
402
430
return g
@@ -670,11 +698,12 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
670
698
(v, w) -> var_to_diff[v] == w || var_to_diff[w] == v
671
699
end
672
700
diff_aliases = Vector{Pair{Int, Int}}[]
673
- stem = Int[]
701
+ stems = Vector{ Int} []
674
702
stem_set = BitSet ()
675
703
for (v, dv) in enumerate (var_to_diff)
676
704
processed[v] && continue
677
705
(dv === nothing && diff_to_var[v] === nothing ) && continue
706
+ stem = Int[]
678
707
r = find_root! (dls, g, v)
679
708
prev_r = - 1
680
709
for _ in 1 : 10_000 # just to make sure that we don't stuck in an infinite loop
@@ -714,9 +743,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
714
743
push! (stem_set, prev_r)
715
744
push! (stem, prev_r)
716
745
push! (diff_aliases, reach₌)
717
- for (_ , v) in reach₌
746
+ for (c , v) in reach₌
718
747
v == prev_r && continue
719
- add_edge! (eqg, v, prev_r)
748
+ add_edge! (eqg, v, prev_r, c )
720
749
end
721
750
end
722
751
@@ -729,9 +758,24 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
729
758
dag[v] = c => a
730
759
end
731
760
end
732
- # Obtain transitive closure after completing the alias edges from diff
733
- # edges.
734
- weighted_transitiveclosure! (eqg)
761
+ push! (stems, stem)
762
+
763
+ # clean up
764
+ for v in dls. visited
765
+ dls. dists[v] = typemax (Int)
766
+ processed[v] = true
767
+ end
768
+ empty! (dls. visited)
769
+ empty! (diff_aliases)
770
+ empty! (stem_set)
771
+ end
772
+
773
+ # Obtain transitive closure after completing the alias edges from diff
774
+ # edges. As a performance optimization, we only compute the transitive
775
+ # closure once at the very end.
776
+ weighted_transitiveclosure! (eqg)
777
+ zero_vars_set = BitSet ()
778
+ for stem in stems
735
779
# Canonicalize by preferring the lower differentiated variable
736
780
# If we have the system
737
781
# ```
@@ -780,7 +824,6 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
780
824
# x := 0
781
825
# y := 0
782
826
# ```
783
- zero_vars_set = BitSet ()
784
827
for v in zero_vars
785
828
for a in Iterators. flatten ((v, outneighbors (eqg, v)))
786
829
while true
@@ -803,17 +846,9 @@ function alias_eliminate_graph!(graph, var_to_diff, mm_orig::SparseMatrixCLIL)
803
846
dag[v] = 0
804
847
end
805
848
end
806
-
807
- # clean up
808
- for v in dls. visited
809
- dls. dists[v] = typemax (Int)
810
- processed[v] = true
811
- end
812
- empty! (dls. visited)
813
- empty! (diff_aliases)
814
- empty! (stem)
815
- empty! (stem_set)
849
+ empty! (zero_vars_set)
816
850
end
851
+
817
852
# update `dag`
818
853
for k in keys (dag)
819
854
dag[k]
0 commit comments