Commit 4a1b742

Authored by Balaram Buddharaju
[TRTLLM-9467][fix] Fix PP+CP combination with helix parallelism (#10312)
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 5845951 commit 4a1b742

File tree: 10 files changed, +532 -26 lines changed

tensorrt_llm/_torch/autotuner.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1506,7 +1506,7 @@ def _broadcast_cache_data(
         """Broadcast tactics from root rank to all other ranks."""
         cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
         root = 0
-        cache_data = self._dist.tp_broadcast(obj=cache_data, root=root)
+        cache_data = self._dist.tp_cp_broadcast(obj=cache_data, root=root)
 
         self.profiling_cache.merge_cache_data(cache_data)
```
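Why the one-line autotuner change matters: under helix parallelism the first PP stage holds both TP and CP ranks, and a plain tp_broadcast from the root reaches only the root's own TP group, leaving the CP peers with an unpopulated tuning cache. A minimal sketch of the coverage difference, assuming a hypothetical TP-major rank layout within a stage (the real layout comes from tensorrt_llm's Mapping):

```python
# Coverage of tp_broadcast vs. tp_cp_broadcast from root rank 0.
# Hypothetical layout: global_rank = cp_rank * tp_size + tp_rank.
tp_size, cp_size = 2, 2

def tp_group(cp_rank):
    # Ranks that share a CP coordinate form one TP group.
    return [cp_rank * tp_size + t for t in range(tp_size)]

def cp_group(tp_rank):
    # Ranks that share a TP coordinate form one CP group.
    return [c * tp_size + tp_rank for c in range(cp_size)]

# Old behavior: only the root's TP group receives the cache.
covered_old = set(tp_group(0))                                   # {0, 1}
# New behavior: TP broadcast first, then each TP rank forwards
# along its CP group.
covered_new = set().union(*(cp_group(t) for t in tp_group(0)))   # {0, 1, 2, 3}
print(sorted(covered_old), sorted(covered_new))
```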

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 48 additions & 2 deletions
```diff
@@ -116,6 +116,26 @@ def broadcast(self, obj, root=0):
     def allgather(self, obj, root=0):
         pass
 
+    @abstractmethod
+    def tp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    @abstractmethod
+    def cp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    def tp_cp_broadcast(self, obj, root=0, **kwargs):
+        """Broadcast object across both TP and CP groups.
+
+        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
+        First broadcasts within the TP group, then within the CP group.
+        """
+        if self.tp_size > 1:
+            obj = self.tp_broadcast(obj, root=root, **kwargs)
+        if self.cp_size > 1:
+            obj = self.cp_broadcast(obj, root=root, **kwargs)
+        return obj
+
 
 def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
     """
@@ -407,14 +427,26 @@ def create_cp_comm(self):
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)
 
+    def cp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
+        comm = self.cp_comm
+        return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
+
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
 
     def tp_gather(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = self.tp_comm
         return safe_gather(comm, obj, root=root, chunk_size=chunk_size)
 
-    def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def tp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.tp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
@@ -699,7 +731,7 @@ def tp_gather(self, obj, dst=0):
         return output_list
 
     @log_op
-    def tp_broadcast(self, obj, root=0):
+    def tp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
             return obj
@@ -712,6 +744,20 @@ def tp_broadcast(self, obj, root=0):
                 device=torch.device("cpu"))
             return ret[0]
 
+    @log_op
+    def cp_broadcast(self, obj, root=0, **kwargs):
+        if isinstance(obj, torch.Tensor):
+            dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
+            return obj
+        else:
+            ret = [obj]
+            torch.distributed.broadcast_object_list(
+                ret,
+                src=root,
+                group=self.mapping.cp_group_pg,
+                device=torch.device("cpu"))
+            return ret[0]
+
     @log_op
     def pp_allgather(self, obj):
         if isinstance(obj, torch.Tensor):
```
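tp_cp_broadcast lives on the Distributed base class and degrades to the old behavior when either group has size 1, so non-helix runs are untouched; the **kwargs added to the abstract signatures let the base method call the MPI variants (which take chunk_size) and the torch variants (which ignore it) uniformly. A runnable mock of the call order, where FakeDist is a hypothetical stand-in for the real Distributed subclasses:

```python
# Hypothetical stand-in for the Distributed subclasses; only group sizes
# and broadcast order matter for this sketch.
class FakeDist:
    def __init__(self, tp_size, cp_size):
        self.tp_size, self.cp_size = tp_size, cp_size
        self.calls = []

    def tp_broadcast(self, obj, root=0, **kwargs):
        self.calls.append("tp")
        return obj

    def cp_broadcast(self, obj, root=0, **kwargs):
        self.calls.append("cp")
        return obj

    # Mirrors the base-class method added in this commit.
    def tp_cp_broadcast(self, obj, root=0, **kwargs):
        if self.tp_size > 1:
            obj = self.tp_broadcast(obj, root=root, **kwargs)
        if self.cp_size > 1:
            obj = self.cp_broadcast(obj, root=root, **kwargs)
        return obj

helix = FakeDist(tp_size=2, cp_size=2)
helix.tp_cp_broadcast({"tactic": 3})
assert helix.calls == ["tp", "cp"]   # TP first, then CP

plain = FakeDist(tp_size=4, cp_size=1)
plain.tp_cp_broadcast({"tactic": 3})
assert plain.calls == ["tp"]         # unchanged for non-helix runs
```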

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -587,9 +587,10 @@ def _broadcast_new_requests(
         if not self.dist.has_pp:
             return self.dist.broadcast(payloads, root=0)
 
-        # Broadcast within first tp group before send/recv chain to other tp groups
-        if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
-            payloads = self.dist.tp_broadcast(payloads, root=0)
+        # Broadcast within first PP stage before send/recv chain to other PP stages.
+        # This needs to cover both TP and CP ranks within the first PP stage.
+        if self.dist.is_first_pp_rank:
+            payloads = self.dist.tp_cp_broadcast(payloads, root=0)
 
         # Tag for communication
         tag = self.dist.pp_size  # Use pp_size as tag to avoid conflicts
```
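The queue fix applies the same idea to request distribution: payloads arrive on global rank 0, must fan out to every TP and CP rank of the first PP stage, and only then travel down the pipeline via the send/recv chain. A simplified simulation of that ordering, with hypothetical stage bookkeeping in place of the real Mapping:

```python
# Simplified propagation order for _broadcast_new_requests under PP.
# Hypothetical layout: pp_size stages, each with tp_size * cp_size ranks.
pp_size, ranks_per_stage = 2, 2   # e.g. tp_size=1, cp_size=2 (pp2tp1cp2)

steps = []
# Step 1: intra-stage broadcast on the first PP stage (tp_cp_broadcast).
steps.append(("stage0 broadcast", list(range(ranks_per_stage))))
# Step 2: send/recv chain forwards the payload stage by stage.
for stage in range(1, pp_size):
    src = [(stage - 1) * ranks_per_stage + r for r in range(ranks_per_stage)]
    dst = [stage * ranks_per_stage + r for r in range(ranks_per_stage)]
    steps.append((f"stage{stage - 1} -> stage{stage}", list(zip(src, dst))))

for name, detail in steps:
    print(name, detail)

# The old guard `tp_size > 1 and is_first_pp_rank` skipped step 1 entirely
# when tp_size == 1, so with cp_size == 2 the second CP rank of stage 0
# never received the payload -- consistent with the pp2tp1cp2 case the
# new test adds.
```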

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 15 additions & 8 deletions
```diff
@@ -871,7 +871,9 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
-    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_less_device(8)
+    @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)],
+                             ids=["pp1tp2cp2", "pp2tp1cp2"])
     @pytest.mark.parametrize("cuda_graph_config", [
         None,
         {
@@ -888,8 +890,10 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         "cudagraph:with_padding"
     ])
     @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
-    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
+    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
+                                   gen_pp, gen_tp, gen_cp):
         use_nccl_for_alltoall = comms_medium == "nccl"
+        gen_ep = gen_tp * gen_cp
         kv_cache_config = {
             "free_gpu_memory_fraction": 0.5,
             "enable_block_reuse": False,
@@ -898,20 +902,22 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
         }
         ctx_server_config = {
             "pipeline_parallel_size": 1,
-            "tensor_parallel_size": 2,
+            "tensor_parallel_size": 4,
             "context_parallel_size": 1,
             "disable_overlap_scheduler": True,
             "kv_cache_config": kv_cache_config,
             "enable_chunked_prefill": False,
             "cuda_graph_config": None,
             "cache_transceiver_config": {
-                "backend": "UCX"
+                "backend": "UCX",
+                "max_tokens_in_buffer": 8192,
             },
         }
         gen_server_config = {
-            "tensor_parallel_size": 1,
-            "pipeline_parallel_size": 1,
-            "context_parallel_size": 2,
+            "tensor_parallel_size": gen_tp,
+            "pipeline_parallel_size": gen_pp,
+            "context_parallel_size": gen_cp,
+            "moe_expert_parallel_size": gen_ep,
             "cp_config": {
                 "cp_type": "HELIX",
                 "tokens_per_block": 32,
@@ -922,7 +928,8 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
             "enable_chunked_prefill": False,
             "cuda_graph_config": cuda_graph_config,
             "cache_transceiver_config": {
-                "backend": "UCX"
+                "backend": "UCX",
+                "max_tokens_in_buffer": 8192,
             },
         }
         disaggregated_server_config = {
```
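The test now requires 8 GPUs because the context server grew to tp4 and each new generation parametrization occupies 4 GPUs, with gen_ep = gen_tp * gen_cp sizing MoE expert parallelism to span the full generation world. A quick check of the arithmetic, assuming context and generation servers sit on disjoint GPUs:

```python
# GPU accounting for the two new helix parametrizations.
ctx_gpus = 1 * 4 * 1                  # ctx server: pp1 * tp4 * cp1

for gen_pp, gen_tp, gen_cp in [(1, 2, 2), (2, 1, 2)]:
    gen_ep = gen_tp * gen_cp          # moe_expert_parallel_size in the test
    gen_gpus = gen_pp * gen_tp * gen_cp
    print(f"pp{gen_pp}tp{gen_tp}cp{gen_cp}: gen_ep={gen_ep}, "
          f"total={ctx_gpus + gen_gpus}")
# pp1tp2cp2: gen_ep=4, total=8
# pp2tp1cp2: gen_ep=4, total=8   -> matches skip_less_device(8)
```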

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 6 additions & 6 deletions
```diff
@@ -540,14 +540,14 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]
```

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -66,6 +66,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -92,6 +93,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)
```

tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml

Lines changed: 0 additions & 6 deletions
```diff
@@ -72,8 +72,6 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
   - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
 - condition:
     ranges:
@@ -89,10 +87,6 @@ l0_gb200_multi_gpus:
       stage: post_merge
       backend: pytorch
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
```

tests/integration/test_lists/waives.txt

Lines changed: 2 additions & 0 deletions
```diff
@@ -523,6 +523,8 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_sof
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] SKIP (https://nvbugs/5774869)
 triton_server/test_triton.py::test_llava_onevision[llava_onevision] SKIP (https://nvbugs/5775205)
 triton_server/test_triton.py::test_gpt_ib_lad[gpt-ib-lad] SKIP (https://nvbugs/5775223)
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=False] SKIP (https://nvbugs/5777044)
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True] SKIP (https://nvbugs/5777044)
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=True] SKIP (https://nvbugs/5740377)
 unittest/_torch/modules/test_fused_moe.py::test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu[MoEWeightLoadingMode.FUSED_GATE_UP_PROJ-DefaultMoeRoutingMethod-1] SKIP (https://nvbugs/5775256)
 unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[3-2] SKIP (https://nvbugs/5777041)
```
