
Commit c36531f

[TRTLLM-9467][fix] Fix TP+CP combination with helix parallelism

Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 270be80

File tree

8 files changed (+194 / -22 lines)

tensorrt_llm/_torch/modules/attention.py

Lines changed: 7 additions & 1 deletion
@@ -814,10 +814,13 @@ def __init__(
 
         # tensor parallel
         config = config or ModelConfig()
+        override_tp_rank_for_o_proj = None
         if mapping_with_cp is not None:
             logger.warning(
                 "[MLA::__init__] Overriding mapping with CP detected.")
             self.mapping = mapping_with_cp
+            override_tp_rank_for_o_proj = mapping_with_cp.get_helix_overridden_tp_rank(
+            )
         else:
             self.mapping = config.mapping
         tp_size = self.mapping.tp_size
@@ -952,7 +955,10 @@ def __init__(
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             reduce_output=reduce_output,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            # override_tp_rank is only used for helix parallelism.
+            override_tp_rank=override_tp_rank_for_o_proj,
+        )
 
 def yarn_get_mscale(scale=1, mscale=1):
     if scale <= 1:

tensorrt_llm/_torch/modules/linear.py

Lines changed: 7 additions & 2 deletions
@@ -148,8 +148,11 @@ def load_weights_vanilla_helper(module: Linear,
         assert "bias" in weights[0]
     device = torch.device('cuda')
 
+    # Use override_tp_rank if set, otherwise fall back to tp_rank. Currently, this is only used
+    # for o_proj in MLA when using helix parallelism.
+    effective_tp_rank = module.override_tp_rank if module.override_tp_rank is not None else module.tp_rank
     weight = load_weight_shard(weights[0]['weight'], module.tp_size,
-                               module.tp_rank, module.tp_mode,
+                               effective_tp_rank, module.tp_mode,
                                device) if "weight" in weights[0] else None
 
     if weight is not None:
@@ -166,7 +169,7 @@ def load_weights_vanilla_helper(module: Linear,
 
     if module.bias is not None:
         bias = load_weight_shard(weights[0]['bias'], module.tp_size,
-                                 module.tp_rank, module.tp_mode,
+                                 effective_tp_rank, module.tp_mode,
                                  device) if "bias" in weights[0] else None
         if bias is not None:
             copy_weight(module.bias, bias_transform(bias))
@@ -2065,6 +2068,7 @@ def __init__(
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
         nvfp4_allowed_backends: Optional[List[str]] = None,
+        override_tp_rank: Optional[int] = None,
     ):
         """
         Args:
@@ -2105,6 +2109,7 @@ def __init__(
             'cutlass', 'cublaslt', 'cuda_core'
         ]
 
+        self.override_tp_rank = override_tp_rank
         local_in_features = in_features
         local_out_features = out_features
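The precedence rule described in the new comment can be illustrated in isolation. Below is a minimal sketch, not the library's load_weight_shard: pick_shard is a hypothetical helper assuming plain sharding along dim 0 with torch.chunk, while the real helper also handles TP modes, devices, and quantized layouts.

import torch

# Toy illustration of the fallback added above: an explicit override wins
# over the mapping-derived TP rank when choosing which shard to load.
def pick_shard(weight: torch.Tensor, tp_size: int, tp_rank: int,
               override_tp_rank=None) -> torch.Tensor:
    effective_tp_rank = (override_tp_rank
                         if override_tp_rank is not None else tp_rank)
    # Shard along dim 0 and return this rank's slice.
    return torch.chunk(weight, tp_size, dim=0)[effective_tp_rank]

w = torch.arange(8.0).reshape(4, 2)
assert torch.equal(pick_shard(w, 4, 1), w[1:2])                       # falls back to tp_rank
assert torch.equal(pick_shard(w, 4, 1, override_tp_rank=3), w[3:4])   # override wins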

tensorrt_llm/mapping.py

Lines changed: 20 additions & 0 deletions
@@ -246,6 +246,26 @@ def has_cp_helix(self):
         return self.cp_size > 1 and self.cp_config.get(
             "cp_type") == CpType.HELIX
 
+    def get_helix_overridden_tp_rank(self) -> int:
+        """Get the overridden TP rank when repurposing helix CP to TP.
+
+        In helix parallelism, CP groups are structured differently than TP groups.
+        For example, with tp_size=2, cp_size=2:
+        - CP groups: [0, 2], [1, 3] (accumulated order: [0, 2, 1, 3])
+        - When repurposed to TP: [0, 1, 2, 3]
+
+        The helix accumulated order iterates through TP ranks, and for each TP rank
+        iterates through CP ranks. So the position in helix order is:
+            helix_position = tp_rank * cp_size + cp_rank
+
+        This function computes the TP rank in the repurposed mapping, accounting
+        for the reordering from helix accumulated order to standard TP order.
+
+        Returns:
+            The TP rank in the repurposed (tp_size * cp_size, cp_size=1) mapping.
+        """
+        return self.tp_rank * self.cp_size + self.cp_rank
+
     def get_node_rank(self, rank: int):
         return rank // self.gpus_per_node
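The docstring's formula can be checked standalone. The sketch below is plain Python, independent of tensorrt_llm, and assumes the rank layout used by the unit tests further down (TP is the fastest-varying dimension, then CP): it rebuilds the accumulated helix order for tp_size=2, cp_size=2 and compares each rank's position against tp_rank * cp_size + cp_rank.

# Standalone check of the helix reordering formula for tp_size=2, cp_size=2.
# CP group for tp_rank t is [t, t + tp_size, ...]; concatenating the groups
# gives the accumulated helix order [0, 2, 1, 3] from the docstring.
tp_size, cp_size = 2, 2
helix_order = [tp + cp * tp_size for tp in range(tp_size)
               for cp in range(cp_size)]
assert helix_order == [0, 2, 1, 3]
for rank in range(tp_size * cp_size):
    tp_rank, cp_rank = rank % tp_size, rank // tp_size
    # The repurposed TP rank is this rank's position in the helix order.
    assert helix_order.index(rank) == tp_rank * cp_size + cp_rank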

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 10 additions & 6 deletions
@@ -871,7 +871,8 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
-    @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)],
+                             ids=["pp1tp2cp2", "pp2tp1cp2"])
     @pytest.mark.parametrize("cuda_graph_config", [
         None,
         {
@@ -888,8 +889,10 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         "cudagraph:with_padding"
     ])
     @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
-    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
+    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
+                                   gen_pp, gen_tp, gen_cp):
         use_nccl_for_alltoall = comms_medium == "nccl"
+        gen_ep = gen_tp * gen_cp
         kv_cache_config = {
             "free_gpu_memory_fraction": 0.5,
             "enable_block_reuse": False,
@@ -898,7 +901,7 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
         }
         ctx_server_config = {
             "pipeline_parallel_size": 1,
-            "tensor_parallel_size": 2,
+            "tensor_parallel_size": 4,
             "context_parallel_size": 1,
             "disable_overlap_scheduler": True,
             "kv_cache_config": kv_cache_config,
@@ -909,9 +912,10 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
             },
         }
         gen_server_config = {
-            "tensor_parallel_size": 1,
-            "pipeline_parallel_size": 1,
-            "context_parallel_size": 2,
+            "tensor_parallel_size": gen_tp,
+            "pipeline_parallel_size": gen_pp,
+            "context_parallel_size": gen_cp,
+            "moe_expert_parallel_size": gen_ep,
             "cp_config": {
                 "cp_type": "HELIX",
                 "tokens_per_block": 32,

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 6 additions & 6 deletions
@@ -540,12 +540,12 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -92,6 +93,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)

tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml

Lines changed: 0 additions & 6 deletions
@@ -72,8 +72,6 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
   - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
 - condition:
     ranges:
@@ -89,10 +87,6 @@ l0_gb200_multi_gpus:
       stage: post_merge
       backend: pytorch
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]

tests/unittest/others/test_mapping.py

Lines changed: 142 additions & 1 deletion
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest
+from collections import namedtuple
 
-from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.mapping import CpType, Mapping
 
 
 class TestMapping(unittest.TestCase):
@@ -81,3 +82,143 @@ def test_mapping(self):
         self.assertEqual(m.next_pp_rank(), 1)
         self.assertEqual(m.prev_cp_rank(), 15)
         self.assertEqual(m.next_cp_rank(), 11)
+
+    def test_helix_overridden_tp_rank(self):
+        # Test case for helix overridden TP rank: (pp_size, tp_size, cp_size, expected_mapping)
+        # where expected_mapping is a list of (rank, expected_helix_tp_rank) tuples.
+        HelixTestCase = namedtuple(
+            'HelixTestCase',
+            ['pp_size', 'tp_size', 'cp_size', 'expected_mapping'])
+        test_cases = [
+            # Case: pp_size=1, tp_size=2, cp_size=2.
+            # CP groups: [0, 2], [1, 3] -> helix order: [0, 2, 1, 3].
+            HelixTestCase(pp_size=1,
+                          tp_size=2,
+                          cp_size=2,
+                          expected_mapping=[
+                              (0, 0),
+                              (2, 1),
+                              (1, 2),
+                              (3, 3),
+                          ]),
+            # Case: pp_size=1, tp_size=4, cp_size=2.
+            # CP groups: [0, 4], [1, 5], [2, 6], [3, 7] -> helix order: [0, 4, 1, 5, 2, 6, 3, 7].
+            HelixTestCase(pp_size=1,
+                          tp_size=4,
+                          cp_size=2,
+                          expected_mapping=[
+                              (0, 0),
+                              (4, 1),
+                              (1, 2),
+                              (5, 3),
+                              (2, 4),
+                              (6, 5),
+                              (3, 6),
+                              (7, 7),
+                          ]),
+            # Case: pp_size=1, tp_size=2, cp_size=4.
+            # CP groups: [0, 2, 4, 6], [1, 3, 5, 7] -> helix order: [0, 2, 4, 6, 1, 3, 5, 7].
+            HelixTestCase(pp_size=1,
+                          tp_size=2,
+                          cp_size=4,
+                          expected_mapping=[
+                              (0, 0),
+                              (2, 1),
+                              (4, 2),
+                              (6, 3),
+                              (1, 4),
+                              (3, 5),
+                              (5, 6),
+                              (7, 7),
+                          ]),
+            # Case: pp_size=1, tp_size=4, cp_size=4.
+            # CP groups: [0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]
+            # -> helix order: [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15].
+            HelixTestCase(pp_size=1,
+                          tp_size=4,
+                          cp_size=4,
+                          expected_mapping=[
+                              (0, 0),
+                              (4, 1),
+                              (8, 2),
+                              (12, 3),
+                              (1, 4),
+                              (5, 5),
+                              (9, 6),
+                              (13, 7),
+                              (2, 8),
+                              (6, 9),
+                              (10, 10),
+                              (14, 11),
+                              (3, 12),
+                              (7, 13),
+                              (11, 14),
+                              (15, 15),
+                          ]),
+            # Case: pp_size=2, tp_size=4, cp_size=2.
+            # PP stage 0 CP groups: [0, 4], [1, 5], [2, 6], [3, 7] -> helix order: [0, 4, 1, 5, 2, 6, 3, 7].
+            # PP stage 1 CP groups: [8, 12], [9, 13], [10, 14], [11, 15] -> helix order: [8, 12, 9, 13, 10, 14, 11, 15].
+            HelixTestCase(
+                pp_size=2,
+                tp_size=4,
+                cp_size=2,
+                expected_mapping=[
+                    (0, 0),
+                    (4, 1),
+                    (1, 2),
+                    (5, 3),
+                    (2, 4),
+                    (6, 5),
+                    (3, 6),
+                    (7, 7),  # PP stage 0
+                    (8, 0),
+                    (12, 1),
+                    (9, 2),
+                    (13, 3),
+                    (10, 4),
+                    (14, 5),
+                    (11, 6),
+                    (15, 7),  # PP stage 1
+                ]),
+            # Case: pp_size=2, tp_size=2, cp_size=4.
+            # PP stage 0 CP groups: [0, 2, 4, 6], [1, 3, 5, 7] -> helix order: [0, 2, 4, 6, 1, 3, 5, 7].
+            # PP stage 1 CP groups: [8, 10, 12, 14], [9, 11, 13, 15] -> helix order: [8, 10, 12, 14, 9, 11, 13, 15].
+            HelixTestCase(
+                pp_size=2,
+                tp_size=2,
+                cp_size=4,
+                expected_mapping=[
+                    (0, 0),
+                    (2, 1),
+                    (4, 2),
+                    (6, 3),
+                    (1, 4),
+                    (3, 5),
+                    (5, 6),
+                    (7, 7),  # PP stage 0
+                    (8, 0),
+                    (10, 1),
+                    (12, 2),
+                    (14, 3),
+                    (9, 4),
+                    (11, 5),
+                    (13, 6),
+                    (15, 7),  # PP stage 1
+                ]),
+        ]
+
+        for case in test_cases:
+            world_size = case.pp_size * case.tp_size * case.cp_size
+            with self.subTest(pp_size=case.pp_size,
+                              tp_size=case.tp_size,
+                              cp_size=case.cp_size):
+                for rank, expected in case.expected_mapping:
+                    m = Mapping(world_size=world_size,
+                                rank=rank,
+                                tp_size=case.tp_size,
+                                pp_size=case.pp_size,
+                                cp_size=case.cp_size,
+                                cp_config={"cp_type": CpType.HELIX})
+                    self.assertEqual(
+                        m.get_helix_overridden_tp_rank(), expected,
+                        f"Failed for rank={rank}: expected {expected}, got {m.get_helix_overridden_tp_rank()}"
+                    )
