Skip to content

Commit 654c5fb

Browse files
Revert "bucketing compile time improve (pytorch#168122)"
This reverts commit 1328a02. Reverted pytorch#168122 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#168122 (comment)))
1 parent 5ff187d commit 654c5fb

File tree

4 files changed

+56
-59
lines changed

4 files changed

+56
-59
lines changed

test/distributed/test_overlap_bucketing_unit.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,28 @@ def build_collective_info(graph, hiding_annotations):
9393
return collective_info
9494

9595

96+
def compute_ancestors(graph):
97+
"""Compute ancestor sets for all nodes in the graph."""
98+
node_ancestors = {}
99+
100+
for node in graph.nodes:
101+
ancestors = OrderedSet()
102+
stack = list(node.all_input_nodes)
103+
visited = set()
104+
105+
while stack:
106+
current = stack.pop()
107+
if current in visited:
108+
continue
109+
visited.add(current)
110+
ancestors.add(current)
111+
stack.extend(current.all_input_nodes)
112+
113+
node_ancestors[node] = ancestors
114+
115+
return node_ancestors
116+
117+
96118
@requires_accelerator_dist_backend()
97119
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
98120
@instantiate_parametrized_tests
@@ -168,8 +190,9 @@ def func(a, b):
168190
ag2: mm2, # mm2 hides ag2
169191
}
170192

171-
# Build collective info and scheduled
193+
# Build collective info and ancestors
172194
collective_info = build_collective_info(traced.graph, hiding_annotations)
195+
node_ancestors = compute_ancestors(traced.graph)
173196
scheduled = OrderedSet(traced.graph.nodes)
174197

175198
# Run bucketing
@@ -180,6 +203,7 @@ def func(a, b):
180203
bucketer = OverlapPreservingBucketer(
181204
traced.graph,
182205
collective_info,
206+
node_ancestors,
183207
scheduled,
184208
)
185209
bucketer.bucket_collectives()
@@ -254,8 +278,9 @@ def func(a, b):
254278
ag2: mm2, # mm2 hides ag2
255279
}
256280

257-
# Build collective info and scheduled
281+
# Build collective info and ancestors
258282
collective_info = build_collective_info(traced.graph, hiding_annotations)
283+
node_ancestors = compute_ancestors(traced.graph)
259284
scheduled = OrderedSet(traced.graph.nodes)
260285

261286
# Run bucketing
@@ -266,6 +291,7 @@ def func(a, b):
266291
bucketer = OverlapPreservingBucketer(
267292
traced.graph,
268293
collective_info,
294+
node_ancestors,
269295
scheduled,
270296
)
271297
bucketer.bucket_collectives()
@@ -355,8 +381,9 @@ def func(a, b, c):
355381
if final_mm_hidden:
356382
hiding_annotations[rs] = mm2
357383

358-
# Build collective info and scheduled
384+
# Build collective info and ancestors
359385
collective_info = build_collective_info(traced.graph, hiding_annotations)
386+
node_ancestors = compute_ancestors(traced.graph)
360387
scheduled = OrderedSet(traced.graph.nodes)
361388

362389
# Run bucketing logic to find buckets (without applying them, which would require process groups)
@@ -367,6 +394,7 @@ def func(a, b, c):
367394
bucketer = OverlapPreservingBucketer(
368395
traced.graph,
369396
collective_info,
397+
node_ancestors,
370398
scheduled,
371399
)
372400

@@ -439,6 +467,7 @@ def func(a, b):
439467

440468
# Build collective info
441469
collective_info = build_collective_info(traced.graph, hiding_annotations)
470+
node_ancestors = compute_ancestors(traced.graph)
442471
scheduled = OrderedSet(traced.graph.nodes)
443472

444473
# Run bucketing
@@ -449,6 +478,7 @@ def func(a, b):
449478
bucketer = OverlapPreservingBucketer(
450479
traced.graph,
451480
collective_info,
481+
node_ancestors,
452482
scheduled,
453483
)
454484
bucketer.bucket_collectives()
@@ -520,8 +550,9 @@ def func(a, b):
520550
ag2: mm2, # mm2 hides ag2
521551
}
522552

523-
# Build collective info and scheduled
553+
# Build collective info and ancestors
524554
collective_info = build_collective_info(traced.graph, hiding_annotations)
555+
node_ancestors = compute_ancestors(traced.graph)
525556
scheduled = OrderedSet(traced.graph.nodes)
526557

527558
# Run bucketing with multidtype mode
@@ -532,6 +563,7 @@ def func(a, b):
532563
bucketer = OverlapPreservingBucketer(
533564
traced.graph,
534565
collective_info,
566+
node_ancestors,
535567
scheduled,
536568
bucket_mode="custom_ops_multidtype",
537569
)
@@ -603,8 +635,9 @@ def func(a, b):
603635
ag2: [mm2, mm3], # ag2 is hidden by mm2 and mm3
604636
}
605637

606-
# Build collective info and scheduled
638+
# Build collective info and ancestors
607639
collective_info = build_collective_info(traced.graph, hiding_annotations)
640+
node_ancestors = compute_ancestors(traced.graph)
608641
scheduled = OrderedSet(traced.graph.nodes)
609642

610643
# Verify hiding_nodes are correctly set
@@ -623,6 +656,7 @@ def func(a, b):
623656
bucketer = OverlapPreservingBucketer(
624657
traced.graph,
625658
collective_info,
659+
node_ancestors,
626660
scheduled,
627661
)
628662
bucketer.bucket_collectives()
@@ -695,8 +729,9 @@ def func(a, b, c):
695729
ag3: mm,
696730
}
697731

698-
# Build collective info and scheduled
732+
# Build collective info and ancestors
699733
collective_info = build_collective_info(traced.graph, hiding_annotations)
734+
node_ancestors = compute_ancestors(traced.graph)
700735
scheduled = OrderedSet(traced.graph.nodes)
701736

702737
# Run bucketing
@@ -707,6 +742,7 @@ def func(a, b, c):
707742
bucketer = OverlapPreservingBucketer(
708743
traced.graph,
709744
collective_info,
745+
node_ancestors,
710746
scheduled,
711747
)
712748
bucketer.bucket_collectives()

torch/_inductor/fx_passes/overlap_manual_scheduling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def __init__(
182182
self.bucketer = ManualOverlapPreservingBucketer(
183183
graph=self.graph,
184184
collective_info=self.collective_info,
185+
node_ancestors=self.node_ancestors,
185186
node_users=self.node_users,
186187
scheduled=OrderedSet(self.graph.nodes),
187188
)

torch/_inductor/fx_passes/overlap_preserving_bucketer.py

Lines changed: 12 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import itertools
21
import logging
32
from collections import defaultdict
43
from dataclasses import dataclass
@@ -131,6 +130,7 @@ def __init__(
131130
self,
132131
graph: fx.Graph,
133132
collective_info: dict[fx.Node, CollectiveInfo],
133+
node_ancestors: dict[fx.Node, OrderedSet[fx.Node]],
134134
scheduled: OrderedSet[fx.Node],
135135
max_bucket_memory_gb: float = 1.0,
136136
max_coll_distance: int = 1000,
@@ -139,45 +139,18 @@ def __init__(
139139
):
140140
self.graph = graph
141141
self.collective_info = collective_info
142+
self.node_ancestors = node_ancestors
142143
self.scheduled = scheduled
143144
self.max_bucket_memory_gb = max_bucket_memory_gb
144145
self.node_idx = {n: i for i, n in enumerate(scheduled)}
146+
self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors)
145147
self.max_coll_distance = max_coll_distance
146148
self.insert_overlap_deps = insert_overlap_deps
147149
self.bucket_mode = bucket_mode
148150
self.node_to_event: dict[fx.Node, PGEvent] = {}
149-
150-
# Compute ancestors including original graph edges and hiding interval dependencies
151-
self.node_ancestors = self._compute_node_ancestors()
152-
self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors)
153-
154-
# Build timelines and add constraints to aug_graph
155151
self.pg_to_timeline_head: dict[str, Optional[PGEvent]] = self.build_timelines()
156-
self._add_hiding_interval_constraints()
157-
158-
def _compute_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
159-
"""
160-
Compute ancestor sets for all nodes including:
161-
1. Original graph edges
162-
2. Hiding interval deps: collective_start -> hiding_node -> wait
163-
"""
164-
augmented_inputs: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
165-
for start, info in self.collective_info.items():
166-
if info.is_exposed:
167-
continue
168-
for hiding_node in info.hiding_nodes:
169-
augmented_inputs[hiding_node].add(start)
170-
augmented_inputs[info.wait_node].add(hiding_node)
171152

172-
node_ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
173-
for node in self.scheduled:
174-
for input_node in itertools.chain(
175-
augmented_inputs[node], node.all_input_nodes
176-
):
177-
node_ancestors[node].add(input_node)
178-
node_ancestors[node] |= node_ancestors[input_node]
179-
180-
return node_ancestors
153+
self._add_hiding_interval_constraints()
181154

182155
def build_timelines(self) -> dict[str, Optional[PGEvent]]:
183156
"Construct each process groups ordered series of event"
@@ -364,30 +337,21 @@ def _find_buckets(
364337
)
365338
processed.add(start_node)
366339

367-
# Greedy optimization: stop after consecutive failures
368-
consecutive_failures = 0
369-
max_consecutive_failures = 20
370-
371340
# Check candidates in sorted order, break when beyond max distance
372341
for candidate in sorted_collectives[i + 1 : i + 1 + self.max_coll_distance]:
342+
if candidate in processed:
343+
continue
344+
373345
candidate_bytes = self.collective_info[candidate].size_bytes
374346
# proxy on memory use, if we see a too large bucket,
375347
# dont look for another, later bucket
376348
if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes:
377349
break
378350

379-
if candidate in processed:
380-
continue
381-
382351
if self._can_add_to_bucket(bucket_info, candidate):
383352
bucket_info.collectives.append(candidate)
384353
bucket_info.total_bytes += candidate_bytes
385354
processed.add(candidate)
386-
consecutive_failures = 0 # Reset on success
387-
else:
388-
consecutive_failures += 1
389-
if consecutive_failures >= max_consecutive_failures:
390-
break
391355

392356
if len(bucket_info.collectives) > 1:
393357
buckets.append(bucket_info)
@@ -692,28 +656,23 @@ def _has_ancestor_conflicts(
692656
candidate_wait = candidate_info.wait_node
693657

694658
for coll in bucket_info.collectives:
695-
if (
696-
coll in self.node_ancestors[candidate]
697-
or candidate in self.node_ancestors[coll]
698-
):
659+
# Check if collectives are ancestors of each other
660+
if self._ancestor_dep(coll, candidate):
699661
return True
700662

701663
# Check if waits are ancestors of each other
702664
coll_wait = self.collective_info[coll].wait_node
703-
if (
704-
coll_wait in self.node_ancestors[candidate_wait]
705-
or candidate_wait in self.node_ancestors[coll_wait]
706-
):
665+
if self._ancestor_dep(candidate_wait, coll_wait):
707666
return True
708667

709668
# Check if existing hiding node conflicts with candidate wait
710669
for old_hiding_node in self.collective_info[coll].hiding_nodes:
711-
if candidate_wait in self.node_ancestors[old_hiding_node]:
670+
if self._ancestor_dep(old_hiding_node, candidate_wait):
712671
return True
713672

714673
# Check if candidate hiding node conflicts with existing wait
715674
for new_hiding_node in candidate_info.hiding_nodes:
716-
if coll_wait in self.node_ancestors[new_hiding_node]:
675+
if self._ancestor_dep(new_hiding_node, coll_wait):
717676
return True
718677

719678
return False

torch/_inductor/fx_passes/overlap_scheduling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,7 @@ def _bucket_collectives(self) -> None:
11251125
bucketer = OverlapPreservingBucketer(
11261126
graph=self.graph,
11271127
collective_info=self.collective_info,
1128+
node_ancestors=self.node_ancestors,
11281129
scheduled=self.scheduled,
11291130
max_bucket_memory_gb=2.0, # Could make this configurable
11301131
max_coll_distance=self.max_node_distance,

0 commit comments

Comments
 (0)