@@ -26,6 +26,7 @@
 - Latent cache communication happens within CP groups
 """

+import os
 import pickle
 import sys
 import time
@@ -38,10 +39,12 @@
 import cloudpickle
 import pytest
 import torch
+import torch.distributed as dist
 from mpi4py import MPI
 from mpi4py.futures import MPIPoolExecutor

 import tensorrt_llm
+from tensorrt_llm._utils import get_free_port
 from tensorrt_llm._torch.attention_backend.interface import (
     AttentionMetadata,
     KVCacheParams,
@@ -1048,19 +1051,58 @@ def _full_test_multi_gpu(



-def _run_single_rank(func, *args, **kwargs):
+def _init_torch_distributed(world_size: int):
+    """Initialize torch.distributed with NCCL backend for helix all-to-all communication."""
+    if dist.is_initialized():
+        return
+
+    rank = tensorrt_llm.mpi_rank()
+    comm = MPI.COMM_WORLD
+
+    # Rank 0 generates a free port and broadcasts it to all ranks
+    if rank == 0:
+        master_port = get_free_port()
+    else:
+        master_port = None
+    master_port = comm.bcast(master_port, root=0)
+
+    # Set up environment variables for torch.distributed
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(master_port)
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["LOCAL_RANK"] = str(rank)
+
+    # Initialize torch.distributed with NCCL backend
+    dist.init_process_group(
+        backend="nccl",
+        init_method=f"tcp://127.0.0.1:{master_port}",
+        rank=rank,
+        world_size=world_size
+    )
+    print(f"[Rank {rank}] torch.distributed initialized with NCCL")
+
+
+def _run_single_rank(func, world_size: int, *args, **kwargs):
     """Wrapper to run a function on a single MPI rank."""
     rank = tensorrt_llm.mpi_rank()
     torch.cuda.set_device(rank)
     print(f"rank {rank} starting")
     try:
+        # Initialize torch.distributed with NCCL for helix communication
+        _init_torch_distributed(world_size)
+
         ret = func(rank, *args, **kwargs)
         print(f"rank {rank} done")
         return ret
     except Exception:
         traceback.print_exc()
         tb = traceback.format_exc()
         raise Exception(f"\n\nError occurred. Original traceback is\n{tb}\n")
+    finally:
+        # Clean up torch.distributed
+        if dist.is_initialized():
+            dist.destroy_process_group()


 # ============================================================================
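The rank-0 free-port broadcast followed by `dist.init_process_group` is a generic way to bootstrap NCCL on top of an MPI launcher. Below is a minimal, self-contained sketch of the same pattern outside the test harness; it is an illustration only, with a `socket`-based port lookup standing in for `tensorrt_llm._utils.get_free_port` and a final all-reduce added purely as a sanity check.

```python
# Standalone sketch (run with something like: mpirun -n 2 python <this_script>.py).
# Assumes a single node with one GPU per rank.
import socket

import torch
import torch.distributed as dist
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, world_size = comm.Get_rank(), comm.Get_size()

# Rank 0 picks a free TCP port; every other rank receives it via MPI broadcast.
if rank == 0:
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
else:
    port = None
port = comm.bcast(port, root=0)

torch.cuda.set_device(rank)
dist.init_process_group(
    backend="nccl",
    init_method=f"tcp://127.0.0.1:{port}",
    rank=rank,
    world_size=world_size,
)

# Sanity check: each rank contributes its rank id; the all-reduced sum must match.
t = torch.full((1,), float(rank), device="cuda")
dist.all_reduce(t)
assert t.item() == sum(range(world_size))

dist.destroy_process_group()
```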
@@ -1107,7 +1149,9 @@ def test_mla_helix_distributed_mixed_tp_cp(
     with MPIPoolExecutor(max_workers=world_size) as executor:
         results = executor.map(
             _run_single_rank,
-            *zip(*[(_full_test_multi_gpu, world_size, tp_size, cp_size, scenario, gen_steps)] * world_size),
+            # First arg is the function, second is world_size for NCCL init,
+            # then the remaining args are passed to the function
+            *zip(*[(_full_test_multi_gpu, world_size, world_size, tp_size, cp_size, scenario, gen_steps)] * world_size),
         )
     if mismatch_ratios is None:
         for ratio_mismatch in results:
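The argument packing in the `executor.map` call above is easy to misread. The snippet below, using placeholder values, shows what `*zip(*[args] * world_size)` expands to and why `world_size` now appears twice: the first copy is consumed by `_run_single_rank` for `_init_torch_distributed`, and the second is forwarded to `_full_test_multi_gpu` along with the other arguments.

```python
# Placeholder illustration of the argument replication (world_size = 2).
args = ("fn", 2, 2, "tp", "cp", "scenario", "steps")  # stand-ins for the real tuple
columns = list(zip(*[args] * 2))
# columns == [("fn", "fn"), (2, 2), (2, 2), ("tp", "tp"), ("cp", "cp"),
#             ("scenario", "scenario"), ("steps", "steps")]
# executor.map(_run_single_rank, *columns) therefore runs, on each of the 2 workers:
#     _run_single_rank("fn", 2, 2, "tp", "cp", "scenario", "steps")
# i.e. func=_full_test_multi_gpu, world_size=2 for the NCCL setup, and the rest as *args.
```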