Commit 6f826eb

[https://nvbugs/5791900][fix] Fix HelixCpMnnvlMemory init with PP
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
1 parent c5d5af9 commit 6f826eb

4 files changed: +25 -5 lines changed

tensorrt_llm/_mnnvl_utils.py

Lines changed: 20 additions & 3 deletions
@@ -370,15 +370,32 @@ def get_comm(cls, mapping: Mapping):
         if cls.comm is not None:
             return cls.comm
         comm = mpi_comm().Split(
-            mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size
-            + mapping.tp_rank * mapping.moe_tp_size
-            + mapping.moe_tp_rank,
+            mapping.pp_rank * mapping.tp_size + mapping.tp_rank,
             mapping.cp_rank,
         )
         cls.comm = comm
         return comm
 
 
+def init_helix_cp_comm(mapping: Mapping) -> None:
+    """Pre-initialize the Helix CP communicator.
+
+    This function MUST be called during model initialization when all ranks
+    are synchronized (before any PP pipeline divergence). The MPI Split operation
+    is collective and requires all ranks in the communicator to participate.
+
+    In PP (pipeline parallel) mode, different PP stages execute different parts
+    of the model at different times. If the communicator is initialized lazily
+    during the first forward pass, ranks in different PP stages may not reach
+    the Split operation at the same time, causing a deadlock.
+
+    Args:
+        mapping: The mapping object containing parallelism configuration.
+    """
+    if mapping.has_cp_helix() and not mapping.cp_config.get("use_nccl_for_alltoall", True):
+        HelixCpMnnvlMemory.get_comm(mapping)
+
+
 @dataclass
 class MoEAlltoallInfo:
     local_gather_indices: torch.Tensor

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 except Exception:
     MPI = None  # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
 
+from tensorrt_llm._mnnvl_utils import init_helix_cp_comm
 from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
                                  mpi_disabled, mpi_isend, mpi_isend_object,
                                  mpi_recv, mpi_recv_object, mpi_send,
@@ -888,6 +889,7 @@ def init_pp_comm(mapping):
         _pp_comm = PPCommTorch(mapping)
     else:
         _pp_comm = PPCommNCCL(mapping)
+    init_helix_cp_comm(mapping)
 
 
 @TorchDist.log_op
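
For context on why the call is made from init_pp_comm rather than on first use, here is a contrast between lazy and eager creation of a sub-communicator; this is an illustrative sketch over mpi4py with made-up names, not the production code path:

from mpi4py import MPI  # sketch assumption: mpi4py is available

_comm = None

def get_comm_lazy(color: int, key: int):
    # Hazard: Split here is collective over COMM_WORLD. If only the ranks of one
    # PP stage reach this call during their forward pass, they block until every
    # other rank also calls Split, which may never happen under PP -> deadlock.
    global _comm
    if _comm is None:
        _comm = MPI.COMM_WORLD.Split(color, key)
    return _comm

def init_comm_eager(color: int, key: int) -> None:
    # The pattern used by this fix: trigger the same Split once during
    # initialization, while all ranks are synchronized; later lazy lookups
    # then just return the cached communicator.
    get_comm_lazy(color, key)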

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 1 addition & 0 deletions
@@ -871,6 +871,7 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @skip_pre_blackwell
     @pytest.mark.skip_less_device(8)
     @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2),
                                                       (2, 1, 2)],

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 2 additions & 2 deletions
@@ -71,7 +71,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
@@ -101,7 +101,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
