Commit 96e915c

[TRTLLM-9467][fix] Fix PP+CP combination with helix parallelism
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 13ffe52

File tree

6 files changed (+50, -23 lines)

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 18 additions & 0 deletions

@@ -407,6 +407,10 @@ def create_cp_comm(self):
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)
 
+    def cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+        comm = self.cp_comm
+        return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
+
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)

@@ -723,6 +727,20 @@ def tp_broadcast(self, obj, root=0):
                 device=torch.device("cpu"))
             return ret[0]
 
+    @log_op
+    def cp_broadcast(self, obj, root=0):
+        if isinstance(obj, torch.Tensor):
+            dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
+            return obj
+        else:
+            ret = [obj]
+            torch.distributed.broadcast_object_list(
+                ret,
+                src=root,
+                group=self.mapping.cp_group_pg,
+                device=torch.device("cpu"))
+            return ret[0]
+
     @log_op
     def pp_allgather(self, obj):
         if isinstance(obj, torch.Tensor):
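For context: the MPI path (first hunk) delegates to the existing safe_broadcast helper with chunking, while the torch.distributed path (second hunk) mirrors the tp_broadcast pattern just above it. A minimal standalone sketch of that second pattern, assuming an already-initialized torch.distributed process group; the function name broadcast_any and the group/root wiring are illustrative, not TensorRT-LLM's Mapping:

import torch
import torch.distributed as dist

def broadcast_any(obj, group=None, root=0):
    # Tensors take the collective fast path: dist.broadcast fills the
    # preallocated tensor in place on every non-root rank.
    if isinstance(obj, torch.Tensor):
        dist.broadcast(obj, src=root, group=group)
        return obj
    # Anything picklable is wrapped in a one-element list, pickled on the
    # root, and unpickled on all other ranks via the CPU device.
    ret = [obj]
    dist.broadcast_object_list(ret, src=root, group=group,
                               device=torch.device("cpu"))
    return ret[0]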

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 9 additions & 3 deletions

@@ -587,9 +587,15 @@ def _broadcast_new_requests(
         if not self.dist.has_pp:
            return self.dist.broadcast(payloads, root=0)
 
-        # Broadcast within first tp group before send/recv chain to other tp groups
-        if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
-            payloads = self.dist.tp_broadcast(payloads, root=0)
+        # Broadcast within first PP stage before send/recv chain to other PP stages.
+        # This needs to cover both TP and CP ranks within the first PP stage.
+        if self.dist.is_first_pp_rank:
+            if self.dist.tp_size > 1:
+                payloads = self.dist.tp_broadcast(payloads, root=0)
+            # Also broadcast within CP group when CP is enabled (helix parallelism).
+            # This ensures all CP ranks within the first PP stage receive the requests.
+            if self.dist.cp_size > 1:
+                payloads = self.dist.cp_broadcast(payloads, root=0)
 
         # Tag for communication
         tag = self.dist.pp_size  # Use pp_size as tag to avoid conflicts
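To see why the TP broadcast alone missed ranks, consider this plain-Python rank accounting under an assumed contiguous rank layout (TensorRT-LLM's Mapping class defines the authoritative layout; the sketch only illustrates that a PP stage spans tp_size * cp_size ranks, so rank 0 cannot reach its CP peers through tp_broadcast):

# Hypothetical gen server with pp=2, tp=1, cp=2 (the pp2tp1cp2 case below).
TP, CP, PP = 1, 2, 2

def stage_ranks(pp_idx, tp=TP, cp=CP):
    # All ranks of one PP stage: tp * cp consecutive ranks (assumed layout).
    base = pp_idx * tp * cp
    return list(range(base, base + tp * cp))

print(stage_ranks(0))  # [0, 1] -- rank 1 is rank 0's CP peer; with tp=1,
                       # tp_broadcast never delivers the payloads to it.
print(stage_ranks(1))  # [2, 3] -- reached afterwards via the PP send/recv chain.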

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 15 additions & 8 deletions

@@ -871,7 +871,9 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
-    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_less_device(8)
+    @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)],
+                             ids=["pp1tp2cp2", "pp2tp1cp2"])
     @pytest.mark.parametrize("cuda_graph_config", [
         None,
         {
@@ -888,8 +890,10 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         "cudagraph:with_padding"
     ])
     @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
-    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
+    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
+                                   gen_pp, gen_tp, gen_cp):
         use_nccl_for_alltoall = comms_medium == "nccl"
+        gen_ep = gen_tp * gen_cp
         kv_cache_config = {
             "free_gpu_memory_fraction": 0.5,
             "enable_block_reuse": False,
@@ -898,20 +902,22 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
         }
         ctx_server_config = {
             "pipeline_parallel_size": 1,
-            "tensor_parallel_size": 2,
+            "tensor_parallel_size": 4,
             "context_parallel_size": 1,
             "disable_overlap_scheduler": True,
             "kv_cache_config": kv_cache_config,
             "enable_chunked_prefill": False,
             "cuda_graph_config": None,
             "cache_transceiver_config": {
-                "backend": "UCX"
+                "backend": "UCX",
+                "max_tokens_in_buffer": 8192,
             },
         }
         gen_server_config = {
-            "tensor_parallel_size": 1,
-            "pipeline_parallel_size": 1,
-            "context_parallel_size": 2,
+            "tensor_parallel_size": gen_tp,
+            "pipeline_parallel_size": gen_pp,
+            "context_parallel_size": gen_cp,
+            "moe_expert_parallel_size": gen_ep,
             "cp_config": {
                 "cp_type": "HELIX",
                 "tokens_per_block": 32,
@@ -922,7 +928,8 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
             "enable_chunked_prefill": False,
             "cuda_graph_config": cuda_graph_config,
             "cache_transceiver_config": {
-                "backend": "UCX"
+                "backend": "UCX",
+                "max_tokens_in_buffer": 8192,
             },
         }
         disaggregated_server_config = {
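The device requirement rises from 4 to 8 GPUs because the context server now uses tp=4 and each gen-server parametrization multiplies out to 4 ranks. A quick arithmetic check, with values taken from the diff above:

ctx_gpus = 4  # ctx server: pp=1, tp=4, cp=1
for gen_pp, gen_tp, gen_cp in [(1, 2, 2), (2, 1, 2)]:
    gen_gpus = gen_pp * gen_tp * gen_cp  # 4 in both cases
    gen_ep = gen_tp * gen_cp             # MoE expert parallelism spans tp x cp
    assert ctx_gpus + gen_gpus == 8      # hence skip_less_device(8)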

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 6 additions & 6 deletions

@@ -535,14 +535,14 @@
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 2 additions & 0 deletions

@@ -66,6 +66,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (180)
@@ -92,6 +93,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (180)

tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml

Lines changed: 0 additions & 6 deletions

@@ -69,8 +69,6 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
   - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
 - condition:
     ranges:
@@ -86,10 +84,6 @@ l0_gb200_multi_gpus:
       stage: post_merge
       backend: pytorch
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
