
Commit 93f9992

test_mla_helix: Initialize torch.distributed with NCCL for helix alltoall
The alltoall_helix operation requires NCCL to be initialized. Previously, the test used MPI for process spawning but did not initialize NCCL.

Changes:
- Add _init_torch_distributed() to initialize torch.distributed with the NCCL backend, coordinating the master port via MPI broadcast
- Modify _run_single_rank() to take world_size and initialize NCCL before running the test function
- Clean up torch.distributed in a finally block
- Update the executor.map() call to pass world_size twice (once for NCCL init, once for the test function)
1 parent b06ba6d commit 93f9992
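
As background for the first sentence of the message (not part of this commit): collectives issued on torch.distributed's default process group fail unless init_process_group has been called on every rank, which is what the helper added below guarantees before the helix all-to-all runs. A minimal sketch, using torch.distributed.all_to_all_single as a stand-in for the project's alltoall_helix (both the stand-in and the function name here are illustrative assumptions):

    import torch
    import torch.distributed as dist

    def helix_style_alltoall(local_chunk: torch.Tensor) -> torch.Tensor:
        # Without a prior dist.init_process_group(backend="nccl", ...) on this
        # rank, this call raises a "Default process group has not been
        # initialized" error -- the failure mode the commit addresses.
        output = torch.empty_like(local_chunk)
        dist.all_to_all_single(output, local_chunk)
        return output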

File tree

1 file changed: +46 −2 lines

tests/unittest/_torch/modules/test_mla_helix.py

Lines changed: 46 additions & 2 deletions
@@ -26,6 +26,7 @@
 - Latent cache communication happens within CP groups
 """
 
+import os
 import pickle
 import sys
 import time
@@ -38,10 +39,12 @@
 import cloudpickle
 import pytest
 import torch
+import torch.distributed as dist
 from mpi4py import MPI
 from mpi4py.futures import MPIPoolExecutor
 
 import tensorrt_llm
+from tensorrt_llm._utils import get_free_port
 from tensorrt_llm._torch.attention_backend.interface import (
     AttentionMetadata,
     KVCacheParams,
@@ -1048,19 +1051,58 @@ def _full_test_multi_gpu(
 
 
 
-def _run_single_rank(func, *args, **kwargs):
+def _init_torch_distributed(world_size: int):
+    """Initialize torch.distributed with NCCL backend for helix all-to-all communication."""
+    if dist.is_initialized():
+        return
+
+    rank = tensorrt_llm.mpi_rank()
+    comm = MPI.COMM_WORLD
+
+    # Rank 0 generates a free port and broadcasts to all ranks
+    if rank == 0:
+        master_port = get_free_port()
+    else:
+        master_port = None
+    master_port = comm.bcast(master_port, root=0)
+
+    # Set up environment variables for torch.distributed
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(master_port)
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["LOCAL_RANK"] = str(rank)
+
+    # Initialize torch.distributed with NCCL backend
+    dist.init_process_group(
+        backend="nccl",
+        init_method=f"tcp://127.0.0.1:{master_port}",
+        rank=rank,
+        world_size=world_size
+    )
+    print(f"[Rank {rank}] torch.distributed initialized with NCCL")
+
+
+def _run_single_rank(func, world_size: int, *args, **kwargs):
     """Wrapper to run a function on a single MPI rank."""
     rank = tensorrt_llm.mpi_rank()
     torch.cuda.set_device(rank)
     print(f"rank {rank} starting")
     try:
+        # Initialize torch.distributed with NCCL for helix communication
+        _init_torch_distributed(world_size)
+
        ret = func(rank, *args, **kwargs)
        print(f"rank {rank} done")
        return ret
     except Exception:
         traceback.print_exc()
         tb = traceback.format_exc()
         raise Exception(f"\n\nError occurred. Original traceback is\n{tb}\n")
+    finally:
+        # Cleanup torch.distributed
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
 
 # ============================================================================
@@ -1107,7 +1149,9 @@ def test_mla_helix_distributed_mixed_tp_cp(
     with MPIPoolExecutor(max_workers=world_size) as executor:
         results = executor.map(
             _run_single_rank,
-            *zip(*[(_full_test_multi_gpu, world_size, tp_size, cp_size, scenario, gen_steps)] * world_size),
+            # First arg is the function, second is world_size for NCCL init,
+            # then the remaining args are passed to the function
+            *zip(*[(_full_test_multi_gpu, world_size, world_size, tp_size, cp_size, scenario, gen_steps)] * world_size),
         )
         if mismatch_ratios is None:
             for ratio_mismatch in results:
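
One note on the argument packing in the last hunk, since the idiom is easy to misread: zip(*[t] * world_size) transposes world_size identical copies of the argument tuple into one iterable per positional parameter, which is the shape executor.map(fn, *iterables) expects, so every rank calls _run_single_rank with the same argument list. world_size appears twice in the tuple because _run_single_rank consumes the first copy for NCCL initialization and forwards the second to _full_test_multi_gpu. A small standalone sketch (the concrete values are made up purely for illustration):

    # Hypothetical values, just to show how the tuple gets transposed.
    world_size, tp_size, cp_size, scenario, gen_steps = 4, 2, 2, "dummy", 8
    args = ("func_placeholder", world_size, world_size,
            tp_size, cp_size, scenario, gen_steps)

    # zip(*[args] * world_size) yields one tuple per positional parameter,
    # each containing world_size identical entries -- the per-argument
    # iterables that executor.map(fn, *iterables) consumes.
    columns = list(zip(*[args] * world_size))
    assert len(columns) == len(args)
    assert columns[1] == (world_size,) * world_size  # copy used for NCCL init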
