
Commit 1d51e69

pytorchbot and lw authored
[async-TP] Turn asserts back into silent skips (pytorch#158736)
[async-TP] Turn asserts back into silent skips (pytorch#158572)

pytorch#149946 modified some checks that verify whether async-TP is "applicable" to a given collective operation in a graph. Before, the pattern-matching + replacement would simply be skipped; now these checks are asserts that fail and raise. This is causing concrete issues in some graphs that use 2-dimensional device meshes (e.g., TP + CP) but have symm-mem enabled on only one dimension. See pytorch#158569.

This PR turns these asserts back into harmless early exits. Note that this only needed to be done for reduce-scatters, as it was already the case for all-gathers.

Pull Request resolved: pytorch#158572
Approved by: https://github.com/danielvegamyhre, https://github.com/atalman

(cherry picked from commit fac0be7)
Co-authored-by: Luca Wehrstedt <[email protected]>
1 parent 06152d9 commit 1d51e69
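For context, a minimal sketch of the situation the commit message describes. The mesh shape, dimension names, and setup below are assumptions for illustration (not taken from pytorch#158569): a 2-dimensional device mesh where symmetric memory is enabled for only one dimension, so the async-TP pass should now skip fusion on the other dimension instead of raising.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

# Assumes a process group has already been initialized (e.g., world size 4).
# Hypothetical 2D mesh: tensor parallel x context parallel.
mesh = DeviceMesh(
    "cuda",
    torch.arange(dist.get_world_size()).view(2, -1),
    mesh_dim_names=("tp", "cp"),
)

# Symmetric memory is enabled only for the "tp" dimension; the "cp"
# dimension's process group has no symm-mem workspace.
enable_symm_mem_for_group(mesh["tp"].get_group().group_name)

# Before this PR, compiling a graph containing a reduce-scatter over the
# "cp" group with torch._inductor.config._micro_pipeline_tp = True tripped
# the assert in fuse_matmul_reduce_scatter; after this PR the fusion is
# simply skipped for that collective.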

3 files changed, +66 -8 lines

test/distributed/tensor/parallel/test_micro_pipeline_tp.py

Lines changed: 50 additions & 0 deletions
@@ -494,5 +494,55 @@ def test_dtensor_seq_par(self, shard_dim: int):
         self.assertNotIn("reduce_scatter_tensor", code)


+@instantiate_parametrized_tests
+class MicroPipelineTP4GPUTest(TestCase):
+    def setUp(self):
+        torch._inductor.config._micro_pipeline_tp = True
+
+        self.rank = 0
+        self.world_size = 4
+        torch.cuda.set_device("cuda:0")
+
+        store = FakeStore()
+        dist.init_process_group(
+            backend="fake",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+
+    def tearDown(self):
+        dist.destroy_process_group()
+
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @fresh_cache()
+    def test_extra_collectives(self):
+        device_mesh = DeviceMesh(
+            "cuda",
+            torch.arange(0, self.world_size).view(2, -1),
+            mesh_dim_names=("tp", "other"),
+        )
+
+        def func(inp: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
+            hidden = all_gather_tensor(inp, 0, (device_mesh, 0)) @ w1.t()
+            full_hidden = all_gather_tensor(hidden, 0, (device_mesh, 1))
+            full_hidden /= full_hidden.pow(2).sum().sqrt()
+            hidden = reduce_scatter_tensor(full_hidden, "avg", 0, (device_mesh, 1))
+            return reduce_scatter_tensor(hidden @ w2.t(), "avg", 0, (device_mesh, 0))
+
+        inp = torch.rand(8, 10, device="cuda")
+        w1 = torch.rand(7, 10, device="cuda")
+        w2 = torch.rand(10, 7, device="cuda")
+
+        with _test_mode(group_names={device_mesh["tp"].get_group().group_name}):
+            compiled = torch.compile(func)
+            code = run_and_get_triton_code(compiled, inp, w1, w2)
+
+        self.assertIn("fused_all_gather_matmul", code)
+        self.assertIn("all_gather_into_tensor", code)
+        self.assertIn("fused_matmul_reduce_scatter", code)
+        self.assertIn("reduce_scatter_tensor", code)
+
+
 if __name__ == "__main__":
     run_tests()

torch/_inductor/fx_passes/micro_pipeline_tp.py

Lines changed: 7 additions & 6 deletions
@@ -850,9 +850,11 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:

     Returns boolean indicating if fusion was successful or not.
     """
-    assert torch.distributed.is_available() and torch.distributed.is_nccl_available(), (
-        "torch.distributed and NCCL must be available to use async tensor parallelism"
-    )
+    if (
+        not torch.distributed.is_available()
+        or not torch.distributed.is_nccl_available()
+    ):
+        return

     from torch.distributed._symmetric_memory import (
         is_symm_mem_enabled_for_group,
@@ -875,9 +877,8 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
         reduce_scatter.group_name,
     )

-    assert is_symm_mem_enabled_for_group(group_name), (
-        f"symmetric memory is not enabled for process group {group_name}, this is required for async TP"
-    )
+    if not is_symm_mem_enabled_for_group(group_name):
+        return

     # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
     # so we can't apply the fusion if the matmul result is used by multiple

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 9 additions & 2 deletions
@@ -47,23 +47,28 @@ def enable_symm_mem_for_group(group_name: str) -> None:


 _is_test_mode: bool = False
+_mocked_group_names: Optional[set[str]] = None


 @contextmanager
-def _test_mode() -> Generator[None, None, None]:
+def _test_mode(group_names: Optional[set[str]] = None) -> Generator[None, None, None]:
     """
     Forces ``is_symm_mem_enabled_for_group()`` to return ``True`` and the ops
     defined in the ``symm_mem`` namespace to use fallback implementations.

     The context manager is not thread safe.
     """
     global _is_test_mode
+    global _mocked_group_names
     prev = _is_test_mode
+    prev_group_names = _mocked_group_names
     try:
         _is_test_mode = True
+        _mocked_group_names = group_names
         yield
     finally:
         _is_test_mode = prev
+        _mocked_group_names = prev_group_names


 def is_symm_mem_enabled_for_group(group_name: str) -> bool:
@@ -73,7 +78,9 @@ def is_symm_mem_enabled_for_group(group_name: str) -> bool:
     Args:
         group_name (str): the name of the process group.
     """
-    return _is_test_mode or group_name in _group_name_to_store
+    if _is_test_mode:
+        return _mocked_group_names is None or group_name in _mocked_group_names
+    return group_name in _group_name_to_store


 _group_name_to_workspace_tensor: dict[str, Optional[torch.Tensor]] = {}
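As a quick illustration of the new group_names argument (the group names below are made up, not real process groups), _test_mode now mocks symm-mem enablement only for the listed groups, while passing nothing keeps the previous behavior of treating every group as enabled:

from torch.distributed._symmetric_memory import (
    _test_mode,
    is_symm_mem_enabled_for_group,
)

# Only the groups listed in group_names report symm-mem as enabled.
with _test_mode(group_names={"tp_group"}):
    assert is_symm_mem_enabled_for_group("tp_group")
    assert not is_symm_mem_enabled_for_group("cp_group")

# With no argument, every group is treated as enabled (previous behavior).
with _test_mode():
    assert is_symm_mem_enabled_for_group("cp_group")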
