Skip to content

Commit 1a0a198

Browse files
eellison authored and pytorchmergebot committed
Add multiple hiding nodes (pytorch#167847)
With smaller aten nodes, we might want to overlap a single collective with multiple nodes. This updates the overlapping and bucketing code so that a collective can be hidden by multiple nodes.

Pull Request resolved: pytorch#167847
Approved by: https://github.com/fmassa
1 parent 39f5e0e commit 1a0a198

File tree

4 files changed

+210
-47
lines changed

4 files changed

+210
-47
lines changed

test/distributed/test_aten_comm_compute_reordering.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,63 @@ def func(a, b, c):
10611061
correct = func(a, b, c)
10621062
self.assertTrue(same(out, correct))
10631063

1064+
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
1065+
@torch._inductor.config.patch(get_bucket_patches())
1066+
def test_multiple_hiding_nodes_bucketing(self):
1067+
"""Test that collectives hidden by multiple compute ops can bucket together."""
1068+
1069+
# Use 0.5 compute multiplier so each collective needs 2 matmuls to be fully hidden
1070+
def estimate_with_half_compute(fx_node, override_size=None):
1071+
return estimate_aten_runtime(fx_node, compute_multiplier=0.5)
1072+
1073+
def func(a, b, *, ranks):
1074+
# Two all_gathers that will be hidden by multiple compute operations
1075+
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
1076+
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
1077+
1078+
# Multiple compute operations that can hide the collectives
1079+
# With 0.5 multiplier: mm1 and mm2 together hide ag1, mm2 and mm3 together hide ag2
1080+
mm1 = torch.matmul(a, a.T)
1081+
mm2 = torch.matmul(b, b.T)
1082+
mm3 = torch.matmul(a + b, (a + b).T)
1083+
1084+
return ag1.sum() + ag2.sum() + mm1.sum() + mm2.sum() + mm3.sum()
1085+
1086+
with _dynamo_dist_per_rank_init(
1087+
self.rank,
1088+
self.world_size,
1089+
self.backend(device_type),
1090+
fake_pg=not at_least_x_gpu(2),
1091+
):
1092+
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
1093+
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
1094+
ranks = list(range(self.world_size))
1095+
1096+
func_c = functools.partial(func, ranks=ranks)
1097+
1098+
# Patch with custom estimation that uses 0.5 multiplier
1099+
with torch._inductor.config.patch(
1100+
{
1101+
"aten_distributed_optimizations.custom_runtime_estimation": estimate_with_half_compute
1102+
}
1103+
):
1104+
compiled = torch.compile(func_c)
1105+
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b)
1106+
1107+
# Should have 1 bucketed all_gather (both ag1 and ag2 bucketed together)
1108+
FileCheck().check_count(
1109+
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
1110+
).run(aten_graph_str)
1111+
1112+
# Verify bucketed collective is scheduled before all matmuls
1113+
FileCheck().check("functional.all_gather_into_tensor").check(
1114+
"aten.mm"
1115+
).check("aten.mm").check("aten.mm").check("wait_tensor").run(aten_graph_str)
1116+
1117+
# Verify correctness
1118+
correct = func(a, b, ranks=ranks)
1119+
self.assertTrue(same(out, correct))
1120+
10641121

10651122
def get_toy_model(device_type: str):
10661123
"""

test/distributed/test_overlap_bucketing_unit.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def build_collective_info(graph, hiding_annotations):
4949
"""
5050
Build CollectiveInfo dict from manual hiding annotations.
5151
52-
hiding_annotations: dict mapping collective_start -> hiding_compute_node
52+
hiding_annotations: dict mapping collective_start -> hiding_compute_node(s)
53+
Can be a single node or a list/OrderedSet of nodes
5354
"""
5455
from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
5556

@@ -65,20 +66,28 @@ def build_collective_info(graph, hiding_annotations):
6566

6667
# Build CollectiveInfo for each collective
6768
for start_node, wait_node in start_to_wait.items():
68-
hiding_node = hiding_annotations.get(start_node)
69+
hiding_annotation = hiding_annotations.get(start_node)
70+
71+
# Convert to OrderedSet
72+
hiding_nodes = OrderedSet()
73+
if hiding_annotation is not None:
74+
if isinstance(hiding_annotation, list | OrderedSet):
75+
hiding_nodes = OrderedSet(hiding_annotation)
76+
else:
77+
hiding_nodes = OrderedSet([hiding_annotation])
6978

7079
# Estimate size and time
7180
size_bytes = 16 * 4 # 4x4 tensor of floats
7281
estimated_time_ms = 1.0 # Dummy time
73-
exposed_time_ms = 0.0 if hiding_node else 1.0 # Hidden if has hiding_node
82+
exposed_time_ms = 0.0 if hiding_nodes else 1.0 # Hidden if has hiding_nodes
7483

7584
collective_info[start_node] = CollectiveInfo(
7685
start_node=start_node,
7786
wait_node=wait_node,
7887
size_bytes=size_bytes,
7988
estimated_time_ms=estimated_time_ms,
8089
exposed_time_ms=exposed_time_ms,
81-
hiding_node=hiding_node,
90+
hiding_nodes=hiding_nodes,
8291
)
8392

8493
return collective_info
@@ -567,6 +576,97 @@ def func(a, b):
567576
graph_str
568577
)
569578

579+
def test_can_bucket_with_multiple_hiding_nodes(self):
580+
"""
581+
Test that collectives with multiple hiding nodes CAN bucket.
582+
583+
Graph structure:
584+
ag1_start -> ag2_start -> mm1 -> mm2 -> mm3 -> ag1_wait -> ag2_wait
585+
586+
Where:
587+
- ag1 is hidden by mm1 and mm2
588+
- ag2 is hidden by mm2 and mm3
589+
- Both collectives share mm2 as a hiding node
590+
"""
591+
592+
def func(a, b):
593+
group_name = "0"
594+
group_size = 1
595+
596+
# Start both collectives
597+
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
598+
a, group_size, group_name
599+
)
600+
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
601+
b, group_size, group_name
602+
)
603+
604+
# Three compute operations that hide the collectives
605+
mm1 = torch.mm(a, a)
606+
mm2 = torch.mm(b, b)
607+
mm3 = torch.mm(a + b, a + b)
608+
609+
# Wait for both
610+
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
611+
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
612+
613+
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum() + mm3.sum()
614+
615+
# Use fake mode to trace without executing
616+
with FakeTensorMode():
617+
a = torch.ones(4, 4, device=self.device)
618+
b = torch.ones(4, 4, device=self.device) * 2
619+
620+
# Trace with make_fx
621+
traced = make_fx(func)(a, b)
622+
623+
# Find nodes using find_nodes
624+
ag1, ag2 = traced.graph.find_nodes(
625+
op="call_function",
626+
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
627+
)
628+
mm1, mm2, mm3 = traced.graph.find_nodes(
629+
op="call_function", target=torch.ops.aten.mm.default
630+
)
631+
632+
# Manually annotate hiding relationships with multiple hiding nodes
633+
hiding_annotations = {
634+
ag1: [mm1, mm2], # ag1 is hidden by mm1 and mm2
635+
ag2: [mm2, mm3], # ag2 is hidden by mm2 and mm3
636+
}
637+
638+
# Build collective info and ancestors
639+
collective_info = build_collective_info(traced.graph, hiding_annotations)
640+
node_ancestors = compute_ancestors(traced.graph)
641+
scheduled = OrderedSet(traced.graph.nodes)
642+
643+
# Verify hiding_nodes are correctly set
644+
self.assertEqual(len(collective_info[ag1].hiding_nodes), 2)
645+
self.assertIn(mm1, collective_info[ag1].hiding_nodes)
646+
self.assertIn(mm2, collective_info[ag1].hiding_nodes)
647+
self.assertEqual(len(collective_info[ag2].hiding_nodes), 2)
648+
self.assertIn(mm2, collective_info[ag2].hiding_nodes)
649+
self.assertIn(mm3, collective_info[ag2].hiding_nodes)
650+
651+
# Run bucketing
652+
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
653+
OverlapPreservingBucketer,
654+
)
655+
656+
bucketer = OverlapPreservingBucketer(
657+
traced.graph,
658+
collective_info,
659+
node_ancestors,
660+
scheduled,
661+
)
662+
bucketer.bucket_collectives()
663+
664+
FileCheck().check_count(
665+
"all_gather_into_tensor_out", 1, exactly=False
666+
).check_count("torch.ops.aten.mm.default", 3, exactly=True).run(
667+
str(traced.graph)
668+
)
669+
570670

571671
if __name__ == "__main__":
572672
run_tests()

torch/_inductor/fx_passes/overlap_preserving_bucketer.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -176,18 +176,20 @@ def build_timeline(self, pg: str) -> Optional[PGEvent]:
176176
head = None
177177
prev_event = None
178178
position = 0
179+
hiding_nodes = OrderedSet()
179180

180181
for node in self.scheduled:
181182
node_type = None
182183

183184
# Determine if this node is relevant for this PG
184185
if node in self.collective_info and get_group_name(node) == pg:
185186
node_type = "starts"
187+
hiding_nodes |= self.collective_info[node].hiding_nodes
186188
elif is_wait_tensor(node):
187189
wait_input = node.args[0]
188190
if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg:
189191
node_type = "waits"
190-
elif is_compute_node(node):
192+
elif is_compute_node(node) or node in hiding_nodes:
191193
node_type = "compute"
192194

193195
if node_type is None:
@@ -205,7 +207,6 @@ def build_timeline(self, pg: str) -> Optional[PGEvent]:
205207

206208
prev_event = event
207209
position += 1
208-
209210
return head
210211

211212
def _populate_node_to_event(self, pg: str) -> None:
@@ -222,10 +223,12 @@ def _add_hiding_interval_constraints(self) -> None:
222223
Add hiding interval constraints: start -> compute -> wait.
223224
"""
224225
for start, info in self.collective_info.items():
225-
if info.hiding_node and not info.is_exposed:
226+
if info.is_exposed:
227+
continue
228+
for hn in info.hiding_nodes:
226229
# Enforce: start -> compute -> wait
227-
self.aug_graph.add_extra_dep(n=info.hiding_node, dep=start)
228-
self.aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)
230+
self.aug_graph.add_extra_dep(n=hn, dep=start)
231+
self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn)
229232

230233
def bucket_collectives(self) -> None:
231234
"""Main entry point for bucketing collectives."""
@@ -358,13 +361,13 @@ def _ancestor_dep(self, n1: fx.Node, n2: fx.Node) -> bool:
358361

359362
def _get_intervals(
360363
self, event: PGEvent
361-
) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
362-
"""Get (execution_interval, hiding_interval) for a collective event.
364+
) -> tuple[Optional[tuple[int, int]], list[tuple[int, int]]]:
365+
"""Get (execution_interval, hiding_intervals) for a collective event.
363366
364367
Returns:
365-
(execution_interval, hiding_interval) where:
368+
(execution_interval, hiding_intervals) where:
366369
- execution_interval is (start_pos, wait_pos) or None
367-
- hiding_interval is (start_pos, compute_pos) or None if no hiding node
370+
- hiding_intervals is a list of (start_pos, compute_pos) tuples, one for each hiding node
368371
369372
Works for both start and wait events by looking up the collective info.
370373
"""
@@ -375,28 +378,31 @@ def _get_intervals(
375378
elif event.is_wait:
376379
wait_input = event.node.args[0]
377380
if not isinstance(wait_input, fx.Node):
378-
return None, None
381+
return None, []
379382
coll = wait_input
380383
else:
381-
return None, None
384+
return None, []
382385

383386
if coll not in self.collective_info:
384-
return None, None
387+
return None, []
385388

386389
info = self.collective_info[coll]
387390
start_event = self.node_to_event[coll]
388391
wait_event = self.node_to_event[info.wait_node]
389392

390393
execution_interval = (start_event.position, wait_event.position)
391394

392-
hiding_interval = None
393-
if info.hiding_node:
394-
hiding_interval = (
395-
start_event.position,
396-
self.node_to_event[info.hiding_node].position,
397-
)
395+
hiding_intervals = []
396+
if info.hiding_nodes:
397+
for hiding_node in info.hiding_nodes:
398+
hiding_intervals.append(
399+
(
400+
start_event.position,
401+
self.node_to_event[hiding_node].position,
402+
)
403+
)
398404

399-
return execution_interval, hiding_interval
405+
return execution_interval, hiding_intervals
400406

401407
def _preserves_hiding_intervals(
402408
self,
@@ -424,9 +430,9 @@ def _preserves_hiding_intervals(
424430
# Collect hiding compute positions for the bucket
425431
bucket_hiding_compute_positions = []
426432
for coll in all_bucketed_colls:
427-
if hiding_node := self.collective_info[coll].hiding_node:
433+
for coll_hiding_node in self.collective_info[coll].hiding_nodes:
428434
bucket_hiding_compute_positions.append(
429-
self.node_to_event[hiding_node].position
435+
self.node_to_event[coll_hiding_node].position
430436
)
431437

432438
# Get new positions
@@ -478,11 +484,10 @@ def get_pos(n: fx.Node) -> int:
478484
curr_event.node not in all_bucketed_colls
479485
and curr_event.node not in all_bucketed_waits
480486
):
481-
exec_interval, hiding_interval = self._get_intervals(curr_event)
487+
exec_interval, hiding_interval_list = self._get_intervals(curr_event)
482488
if exec_interval:
483489
execution_intervals.append(exec_interval)
484-
if hiding_interval:
485-
hiding_intervals.append(hiding_interval)
490+
hiding_intervals.extend(hiding_interval_list)
486491
curr_event = curr_event.next
487492

488493
curr_event = new_wait_event.prev
@@ -491,11 +496,10 @@ def get_pos(n: fx.Node) -> int:
491496
curr_event.node not in all_bucketed_colls
492497
and curr_event.node not in all_bucketed_waits
493498
):
494-
exec_interval, hiding_interval = self._get_intervals(curr_event)
499+
exec_interval, hiding_interval_list = self._get_intervals(curr_event)
495500
if exec_interval:
496501
execution_intervals.append(exec_interval)
497-
if hiding_interval:
498-
hiding_intervals.append(hiding_interval)
502+
hiding_intervals.extend(hiding_interval_list)
499503
curr_event = curr_event.prev
500504

501505
# Check: no hiding interval should be enclosed by any execution interval
@@ -659,12 +663,12 @@ def _has_ancestor_conflicts(
659663
return True
660664

661665
# Check if existing hiding node conflicts with candidate wait
662-
if hiding_node := self.collective_info[coll].hiding_node:
663-
if self._ancestor_dep(hiding_node, candidate_wait):
666+
for old_hiding_node in self.collective_info[coll].hiding_nodes:
667+
if self._ancestor_dep(old_hiding_node, candidate_wait):
664668
return True
665669

666670
# Check if candidate hiding node conflicts with existing wait
667-
if new_hiding_node := candidate_info.hiding_node:
671+
for new_hiding_node in candidate_info.hiding_nodes:
668672
if self._ancestor_dep(new_hiding_node, coll_wait):
669673
return True
670674

0 commit comments

Comments
 (0)