Commit e9af2f3 (1 parent: 0d2e271)

[TRTLLM-9467][fix] Fix PP+CP combination with helix parallelism

Signed-off-by: Balaram Buddharaju <[email protected]>

File tree: 8 files changed, +424 -23 lines changed
tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 18 additions & 0 deletions

@@ -407,6 +407,10 @@ def create_cp_comm(self):
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)
 
+    def cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+        comm = self.cp_comm
+        return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
+
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
 
@@ -712,6 +716,20 @@ def tp_broadcast(self, obj, root=0):
                 device=torch.device("cpu"))
             return ret[0]
 
+    @log_op
+    def cp_broadcast(self, obj, root=0):
+        if isinstance(obj, torch.Tensor):
+            dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
+            return obj
+        else:
+            ret = [obj]
+            torch.distributed.broadcast_object_list(
+                ret,
+                src=root,
+                group=self.mapping.cp_group_pg,
+                device=torch.device("cpu"))
+            return ret[0]
+
     @log_op
     def pp_allgather(self, obj):
         if isinstance(obj, torch.Tensor):
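The MPI path above delegates to safe_broadcast, a helper defined elsewhere in communicator.py; its chunk_size parameter suggests the serialized payload is sent in bounded pieces. As a rough illustration of that pattern only (a minimal sketch, not the TensorRT-LLM implementation; sketch_safe_broadcast and the pickle-based serialization are assumptions here), a chunked object broadcast over an mpi4py communicator can look like this:

    import pickle

    from mpi4py import MPI


    def sketch_safe_broadcast(comm, obj, root=0, chunk_size=4 * 1024 * 1024):
        """Broadcast a picklable object in bounded chunks (illustrative only)."""
        rank = comm.Get_rank()
        if rank == root:
            payload = pickle.dumps(obj)
            total = len(payload)
        else:
            payload, total = None, None
        # Share the payload size first so receivers can allocate a buffer.
        total = comm.bcast(total, root=root)
        buf = bytearray(payload) if rank == root else bytearray(total)
        # Broadcast fixed-size slices to bound the per-call message size.
        view = memoryview(buf)
        for start in range(0, total, chunk_size):
            end = min(start + chunk_size, total)
            comm.Bcast([view[start:end], MPI.BYTE], root=root)
        return obj if rank == root else pickle.loads(bytes(buf))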

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 9 additions & 3 deletions

@@ -587,9 +587,15 @@ def _broadcast_new_requests(
         if not self.dist.has_pp:
             return self.dist.broadcast(payloads, root=0)
 
-        # Broadcast within first tp group before send/recv chain to other tp groups
-        if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
-            payloads = self.dist.tp_broadcast(payloads, root=0)
+        # Broadcast within first PP stage before send/recv chain to other PP stages.
+        # This needs to cover both TP and CP ranks within the first PP stage.
+        if self.dist.is_first_pp_rank:
+            if self.dist.tp_size > 1:
+                payloads = self.dist.tp_broadcast(payloads, root=0)
+            # Also broadcast within CP group when CP is enabled (helix parallelism).
+            # This ensures all CP ranks within the first PP stage receive the requests.
+            if self.dist.cp_size > 1:
+                payloads = self.dist.cp_broadcast(payloads, root=0)
 
         # Tag for communication
         tag = self.dist.pp_size  # Use pp_size as tag to avoid conflicts
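This hunk is the core of the fix: the old code only ran a TP broadcast on the first PP stage, so with helix parallelism (cp_size > 1) the non-root CP ranks of that stage never received new requests before the PP send/recv chain started. A quick way to see the rank layout for the failing pp2tp1cp2 configuration (a sketch for illustration; it assumes Mapping exposes pp_rank and tp_rank alongside the cp_rank attribute the new unit test uses):

    from tensorrt_llm.mapping import Mapping

    # Enumerate rank coordinates for world_size=4, pp=2, tp=1, cp=2.
    for rank in range(4):
        m = Mapping(world_size=4, rank=rank, tp_size=1, pp_size=2, cp_size=2)
        print(f"rank={rank}: pp_rank={m.pp_rank} tp_rank={m.tp_rank} cp_rank={m.cp_rank}")

    # With tp_size=1 the first PP stage holds two CP ranks; only the new
    # cp_broadcast delivers payloads to the second one before the PP chain runs.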

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 15 additions & 8 deletions

@@ -871,7 +871,9 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
-    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_less_device(8)
+    @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)],
+                             ids=["pp1tp2cp2", "pp2tp1cp2"])
     @pytest.mark.parametrize("cuda_graph_config", [
         None,
         {
@@ -888,8 +890,10 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         "cudagraph:with_padding"
     ])
     @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
-    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
+    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
+                                   gen_pp, gen_tp, gen_cp):
         use_nccl_for_alltoall = comms_medium == "nccl"
+        gen_ep = gen_tp * gen_cp
         kv_cache_config = {
             "free_gpu_memory_fraction": 0.5,
             "enable_block_reuse": False,
@@ -898,20 +902,22 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
         }
         ctx_server_config = {
             "pipeline_parallel_size": 1,
-            "tensor_parallel_size": 2,
+            "tensor_parallel_size": 4,
             "context_parallel_size": 1,
             "disable_overlap_scheduler": True,
             "kv_cache_config": kv_cache_config,
             "enable_chunked_prefill": False,
             "cuda_graph_config": None,
             "cache_transceiver_config": {
-                "backend": "UCX"
+                "backend": "UCX",
+                "max_tokens_in_buffer": 8192,
             },
         }
         gen_server_config = {
-            "tensor_parallel_size": 1,
-            "pipeline_parallel_size": 1,
-            "context_parallel_size": 2,
+            "tensor_parallel_size": gen_tp,
+            "pipeline_parallel_size": gen_pp,
+            "context_parallel_size": gen_cp,
+            "moe_expert_parallel_size": gen_ep,
             "cp_config": {
                 "cp_type": "HELIX",
                 "tokens_per_block": 32,
@@ -922,7 +928,8 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
             "enable_chunked_prefill": False,
             "cuda_graph_config": cuda_graph_config,
             "cache_transceiver_config": {
-                "backend": "UCX"
+                "backend": "UCX",
+                "max_tokens_in_buffer": 8192,
             },
         }
         disaggregated_server_config = {
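The device requirement rises from 4 to 8 because the context server now uses 4 GPUs (tp=4) and each generation layout uses another 4 (pp * tp * cp = 4 in both parametrizations). A quick check of that arithmetic, assuming the context and generation servers occupy disjoint GPUs as in the disaggregated setup:

    # Sanity check of the GPU budget behind @pytest.mark.skip_less_device(8).
    ctx_gpus = 1 * 4 * 1  # ctx server: pp=1, tp=4, cp=1
    for gen_pp, gen_tp, gen_cp in [(1, 2, 2), (2, 1, 2)]:
        gen_gpus = gen_pp * gen_tp * gen_cp
        gen_ep = gen_tp * gen_cp  # mirrors gen_ep = gen_tp * gen_cp in the test
        assert ctx_gpus + gen_gpus == 8
        print(f"pp{gen_pp}tp{gen_tp}cp{gen_cp}: gen GPUs={gen_gpus}, gen_ep={gen_ep}")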

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 6 additions & 6 deletions

@@ -540,14 +540,14 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 2 additions & 0 deletions

@@ -66,6 +66,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
       tests:
+      - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -92,6 +93,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
       tests:
+      - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)

tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml

Lines changed: 0 additions & 6 deletions

@@ -72,8 +72,6 @@ l0_gb200_multi_gpus:
       - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
       - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
       - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
-      - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-      - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
       - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
 - condition:
     ranges:
@@ -89,10 +87,6 @@ l0_gb200_multi_gpus:
       stage: post_merge
       backend: pytorch
       tests:
-      - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-      - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-      - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-      - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
       - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
       - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
tests/unittest/_torch/distributed/test_cp_broadcast.py (new file)

Lines changed: 210 additions & 0 deletions

"""
Tests for cp_broadcast functionality in both MPIDist and TorchDist.

This module tests the context parallelism broadcast operation which is used
when CP (context parallelism) is enabled (e.g., in Helix parallelism).

For MPIDist tests, run with mpirun:
    mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py -v

For TorchDist tests, see test_ops.py which uses Ray for distributed testing.
"""

import numpy as np
import pytest

from tensorrt_llm._torch.distributed import MPIDist
from tensorrt_llm.mapping import Mapping


def get_mpi_info():
    """Get MPI rank and world size, returns (0, 1) if MPI is not available."""
    try:
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        return comm.Get_rank(), comm.Get_size()
    except ImportError:
        return 0, 1


def skip_if_not_mpi():
    """Skip test if not running under MPI with sufficient ranks."""
    rank, world_size = get_mpi_info()
    if world_size < 2:
        pytest.skip("Test requires at least 2 MPI ranks (run with mpirun -n 2)")


class TestMPIDistCpBroadcast:
    """Tests for MPIDist.cp_broadcast functionality."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Set up MPI environment and mapping for each test."""
        skip_if_not_mpi()
        self.rank, self.world_size = get_mpi_info()

        # Set up mapping with CP enabled (cp_size = world_size, tp_size = 1)
        self.mapping = Mapping(
            world_size=self.world_size,
            rank=self.rank,
            tp_size=1,
            cp_size=self.world_size,
            pp_size=1,
        )
        self.dist = MPIDist(mapping=self.mapping)

    def test_broadcast_numpy_array(self):
        """Test broadcasting a numpy array via cp_broadcast."""
        root = 0
        shape = (64, 128)

        if self.mapping.cp_rank == root:
            # Root rank creates the data to broadcast
            data = np.random.randn(*shape).astype(np.float32)
        else:
            # Non-root ranks have empty/zero data
            data = np.zeros(shape, dtype=np.float32)

        # Store original data from root for verification
        from mpi4py import MPI

        expected = np.zeros(shape, dtype=np.float32)
        MPI.COMM_WORLD.Bcast(data if self.mapping.cp_rank == root else expected, root=root)
        if self.mapping.cp_rank == root:
            expected = data.copy()

        # Perform cp_broadcast
        result = self.dist.cp_broadcast(data, root=root)

        # Verify all ranks have the same data
        np.testing.assert_array_almost_equal(result, expected)

    def test_broadcast_python_dict(self):
        """Test broadcasting a Python dictionary via cp_broadcast."""
        root = 0

        if self.mapping.cp_rank == root:
            obj = {
                "model_name": "llama",
                "batch_size": 32,
                "tokens": [1, 2, 3, 4, 5],
                "config": {"hidden_size": 4096, "num_layers": 32},
            }
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        # Verify all ranks received the correct object
        assert result["model_name"] == "llama"
        assert result["batch_size"] == 32
        assert result["tokens"] == [1, 2, 3, 4, 5]
        assert result["config"]["hidden_size"] == 4096
        assert result["config"]["num_layers"] == 32

    def test_broadcast_python_list(self):
        """Test broadcasting a Python list via cp_broadcast."""
        root = 0

        if self.mapping.cp_rank == root:
            obj = ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        assert result == ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]

    def test_broadcast_from_non_zero_root(self):
        """Test broadcasting from a non-zero root rank."""
        if self.world_size < 2:
            pytest.skip("Need at least 2 ranks to test non-zero root")

        root = 1  # Broadcast from rank 1

        if self.mapping.cp_rank == root:
            obj = {"source": "rank1", "value": 42}
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        assert result["source"] == "rank1"
        assert result["value"] == 42

    def test_broadcast_large_object(self):
        """Test broadcasting a large object that may require chunking."""
        root = 0
        # Create a large list to test chunking behavior
        large_size = 100000

        if self.mapping.cp_rank == root:
            obj = list(range(large_size))
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        assert len(result) == large_size
        assert result[0] == 0
        assert result[-1] == large_size - 1

    def test_broadcast_string(self):
        """Test broadcasting a simple string via cp_broadcast."""
        root = 0

        if self.mapping.cp_rank == root:
            obj = "Hello from root rank!"
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        assert result == "Hello from root rank!"


# Additional integration-style test that can be run standalone
def test_mpi_cp_broadcast_integration():
    """
    Integration test for MPIDist cp_broadcast.
    """
    rank, world_size = get_mpi_info()
    if world_size < 2:
        pytest.skip("Test requires at least 2 MPI ranks")

    # Create mapping with CP enabled
    mapping = Mapping(
        world_size=world_size,
        rank=rank,
        tp_size=1,
        cp_size=world_size,
        pp_size=1,
    )
    dist = MPIDist(mapping=mapping)

    # Test 1: Broadcast dict
    if mapping.cp_rank == 0:
        payload = {"requests": [{"id": i} for i in range(10)]}
    else:
        payload = None

    result = dist.cp_broadcast(payload, root=0)
    assert len(result["requests"]) == 10
    assert result["requests"][0]["id"] == 0

    # Test 2: Broadcast numpy array
    shape = (32, 64)
    if mapping.cp_rank == 0:
        arr = np.ones(shape, dtype=np.float32) * (rank + 1)
    else:
        arr = np.zeros(shape, dtype=np.float32)

    result = dist.cp_broadcast(arr, root=0)
    expected_val = 1.0  # From rank 0
    np.testing.assert_array_almost_equal(result, np.ones(shape) * expected_val)


if __name__ == "__main__":
    # Allow running directly with mpirun
    pytest.main([__file__, "-v"])
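The docstring above points to test_ops.py for the TorchDist side. To see the call shape of the non-tensor branch added to TorchDist.cp_broadcast without a multi-process launch, here is a minimal, self-contained sketch (a single-rank gloo group stands in for mapping.cp_group_pg; the init address and port are arbitrary choices for this example):

    import torch
    import torch.distributed as dist


    def main():
        # Single-process gloo group; in TorchDist the group would be
        # mapping.cp_group_pg spanning the CP ranks.
        dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                                rank=0, world_size=1)
        ret = [{"requests": [1, 2, 3]}]
        # Same call shape as the non-tensor branch of cp_broadcast above.
        dist.broadcast_object_list(ret, src=0, device=torch.device("cpu"))
        assert ret[0]["requests"] == [1, 2, 3]
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()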
