Commit 7cbbd4c

[TRTLLM-9466][chore] Fix TP+CP combination with helix parallelism
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent: 13ffe52

File tree

7 files changed: +112 -16 lines

tensorrt_llm/_torch/modules/attention.py
tensorrt_llm/_torch/modules/linear.py
tensorrt_llm/mapping.py
tests/integration/defs/accuracy/test_disaggregated_serving.py
tests/integration/test_lists/qa/llm_function_core.txt
tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml
tests/unittest/others/test_mapping.py

tensorrt_llm/_torch/modules/attention.py

Lines changed: 7 additions & 1 deletion
@@ -792,10 +792,13 @@ def __init__(

         # tensor parallel
         config = config or ModelConfig()
+        override_tp_rank_for_o_proj = None
         if mapping_with_cp is not None:
             logger.warning(
                 "[MLA::__init__] Overriding mapping with CP detected.")
             self.mapping = mapping_with_cp
+            override_tp_rank_for_o_proj = mapping_with_cp.get_helix_overridden_tp_rank(
+            )
         else:
             self.mapping = config.mapping
         tp_size = self.mapping.tp_size

@@ -930,7 +933,10 @@ def __init__(
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             reduce_output=reduce_output,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            override_tp_rank=
+            override_tp_rank_for_o_proj,  # Only used for helix parallelism.
+        )

 def yarn_get_mscale(scale=1, mscale=1):
     if scale <= 1:

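As an illustrative aside, a minimal standalone sketch of the control flow this change adds, with a hypothetical pick_override_tp_rank helper standing in for the MLA constructor logic (a plain dict stands in for the Mapping object here): the override is computed only when a helix-overridden mapping is supplied, and stays None otherwise, so non-helix paths keep loading o_proj shards by the module's own tp_rank.

    from typing import Optional


    def helix_overridden_tp_rank(tp_rank: int, cp_rank: int, cp_size: int) -> int:
        # Same formula as the new Mapping.get_helix_overridden_tp_rank below.
        return tp_rank * cp_size + cp_rank


    def pick_override_tp_rank(mapping_with_cp: Optional[dict]) -> Optional[int]:
        # Mirrors the MLA.__init__ hunk above: only a helix-overridden mapping
        # produces an override for o_proj; otherwise the Linear keeps its own tp_rank.
        if mapping_with_cp is None:
            return None
        return helix_overridden_tp_rank(mapping_with_cp["tp_rank"],
                                        mapping_with_cp["cp_rank"],
                                        mapping_with_cp["cp_size"])


    print(pick_override_tp_rank(None))                                        # None
    print(pick_override_tp_rank({"tp_rank": 0, "cp_rank": 1, "cp_size": 2}))  # 1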
tensorrt_llm/_torch/modules/linear.py

Lines changed: 7 additions & 2 deletions
@@ -148,8 +148,11 @@ def load_weights_vanilla_helper(module: Linear,
     assert "bias" in weights[0]
     device = torch.device('cuda')

+    # Use override_tp_rank if set, otherwise fall back to tp_rank. Currently, this is only used
+    # for o_proj in MLA when using helix parallelism.
+    effective_tp_rank = module.override_tp_rank if module.override_tp_rank is not None else module.tp_rank
     weight = load_weight_shard(weights[0]['weight'], module.tp_size,
-                               module.tp_rank, module.tp_mode,
+                               effective_tp_rank, module.tp_mode,
                                device) if "weight" in weights[0] else None

     if weight is not None:

@@ -166,7 +169,7 @@ def load_weights_vanilla_helper(module: Linear,

     if module.bias is not None:
         bias = load_weight_shard(weights[0]['bias'], module.tp_size,
-                                 module.tp_rank, module.tp_mode,
+                                 effective_tp_rank, module.tp_mode,
                                  device) if "bias" in weights[0] else None
         if bias is not None:
             copy_weight(module.bias, bias_transform(bias))

@@ -2065,6 +2068,7 @@ def __init__(
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
         nvfp4_allowed_backends: Optional[List[str]] = None,
+        override_tp_rank: Optional[int] = None,
     ):
         """
         Args:

@@ -2105,6 +2109,7 @@ def __init__(
             'cutlass', 'cublaslt', 'cuda_core'
         ]

+        self.override_tp_rank = override_tp_rank
         local_in_features = in_features
         local_out_features = out_features

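To make the effect of that fallback concrete, here is a small standalone sketch (it assumes an even split of the weight along dim 0 and uses a toy shard helper, not the real load_weight_shard): the rank used to index the full weight is the override when set, and the module's own tp_rank otherwise.

    from typing import Optional

    import torch


    def shard(weight: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
        # Toy stand-in for load_weight_shard: equal split along the first dimension.
        return torch.chunk(weight, tp_size, dim=0)[rank]


    def effective_tp_rank(tp_rank: int, override_tp_rank: Optional[int]) -> int:
        # Same pattern as load_weights_vanilla_helper above.
        return override_tp_rank if override_tp_rank is not None else tp_rank


    full_weight = torch.arange(8.0).reshape(4, 2)  # 4 shards of shape (1, 2)
    # Without an override, the GPU with tp_rank=1 loads shard 1 ...
    print(shard(full_weight, tp_size=4, rank=effective_tp_rank(1, None)))
    # ... with override_tp_rank=2 it loads shard 2 instead (the helix o_proj case).
    print(shard(full_weight, tp_size=4, rank=effective_tp_rank(1, 2)))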
tensorrt_llm/mapping.py

Lines changed: 20 additions & 0 deletions
@@ -246,6 +246,26 @@ def has_cp_helix(self):
         return self.cp_size > 1 and self.cp_config.get(
             "cp_type") == CpType.HELIX

+    def get_helix_overridden_tp_rank(self) -> int:
+        """Get the overridden TP rank when repurposing helix CP to TP.
+
+        In helix parallelism, CP groups are structured differently than TP groups.
+        For example, with tp_size=2, cp_size=2:
+        - CP groups: [0, 2], [1, 3] (accumulated order: [0, 2, 1, 3])
+        - When repurposed to TP: [0, 1, 2, 3]
+
+        The helix accumulated order iterates through TP ranks, and for each TP rank
+        iterates through CP ranks. So the position in helix order is:
+            helix_position = tp_rank * cp_size + cp_rank
+
+        This function computes the TP rank in the repurposed mapping, accounting
+        for the reordering from helix accumulated order to standard TP order.
+
+        Returns:
+            The TP rank in the repurposed (tp_size * cp_size, cp_size=1) mapping.
+        """
+        return self.tp_rank * self.cp_size + self.cp_rank
+
     def get_node_rank(self, rank: int):
         return rank // self.gpus_per_node

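A quick standalone check of the docstring's formula, assuming the per-rank layout implied by the examples above and in the new unit test (tp_rank = rank % tp_size, cp_rank = (rank // tp_size) % cp_size within a PP stage); it reproduces the tp_size=2, cp_size=2 case where helix order [0, 2, 1, 3] maps onto TP ranks [0, 1, 2, 3].

    tp_size, cp_size = 2, 2


    def helix_overridden_tp_rank(tp_rank: int, cp_rank: int) -> int:
        # helix_position = tp_rank * cp_size + cp_rank, as in the docstring above.
        return tp_rank * cp_size + cp_rank


    for rank in range(tp_size * cp_size):
        tp_rank = rank % tp_size               # assumed layout: TP is the fastest-varying dim
        cp_rank = (rank // tp_size) % cp_size
        print(f"global rank {rank} -> repurposed TP rank "
              f"{helix_overridden_tp_rank(tp_rank, cp_rank)}")
    # Prints 0->0, 1->2, 2->1, 3->3, matching (0, 0), (2, 1), (1, 2), (3, 3) in the unit test.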
tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 10 additions & 6 deletions
@@ -871,7 +871,9 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @pytest.mark.skip_less_device(4)
+
+    @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 2), (2, 1, 2)],
+                             ids=["pp1tp1cp2", "pp2tp1cp2"])
     @pytest.mark.parametrize("cuda_graph_config", [
         None,
         {

@@ -888,8 +890,9 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         "cudagraph:with_padding"
     ])
     @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
-    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
+    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config, gen_pp, gen_tp, gen_cp):
         use_nccl_for_alltoall = comms_medium == "nccl"
+        gen_ep = gen_tp * gen_cp
         kv_cache_config = {
             "free_gpu_memory_fraction": 0.5,
             "enable_block_reuse": False,

@@ -898,7 +901,7 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
         }
         ctx_server_config = {
             "pipeline_parallel_size": 1,
-            "tensor_parallel_size": 2,
+            "tensor_parallel_size": 4,
             "context_parallel_size": 1,
             "disable_overlap_scheduler": True,
             "kv_cache_config": kv_cache_config,

@@ -909,9 +912,10 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
             },
         }
         gen_server_config = {
-            "tensor_parallel_size": 1,
-            "pipeline_parallel_size": 1,
-            "context_parallel_size": 2,
+            "tensor_parallel_size": gen_tp,
+            "pipeline_parallel_size": gen_pp,
+            "context_parallel_size": gen_cp,
+            "moe_expert_parallel_size": gen_ep,
             "cp_config": {
                 "cp_type": "HELIX",
                 "tokens_per_block": 32,

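For reference, the generation-side sizes implied by the new parametrization, derived the same way the test computes gen_ep (a worked calculation under the assumption that the gen server uses pp * tp * cp ranks):

    # (gen_pp, gen_tp, gen_cp) values from the parametrize decorator above.
    for gen_pp, gen_tp, gen_cp in [(1, 1, 2), (2, 1, 2)]:
        gen_ep = gen_tp * gen_cp              # moe_expert_parallel_size in gen_server_config
        gen_ranks = gen_pp * gen_tp * gen_cp  # assumed rank count for the gen server
        print(f"pp={gen_pp} tp={gen_tp} cp={gen_cp} -> ep={gen_ep}, {gen_ranks} ranks")
    # pp=1 tp=1 cp=2 -> ep=2, 2 ranks
    # pp=2 tp=1 cp=2 -> ep=2, 4 ranks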
tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 4 additions & 6 deletions
@@ -535,12 +535,10 @@
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]

tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@ l0_gb200_multi_nodes:
       stage: pre_merge
       backend: pytorch
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] TIMEOUT (180)

@@ -31,6 +32,7 @@ l0_gb200_multi_nodes:
       stage: post_merge
       backend: pytorch
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] TIMEOUT (180)

tests/unittest/others/test_mapping.py

Lines changed: 62 additions & 1 deletion
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest
+from collections import namedtuple

-from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.mapping import Mapping, CpType


 class TestMapping(unittest.TestCase):

@@ -81,3 +82,63 @@ def test_mapping(self):
         self.assertEqual(m.next_pp_rank(), 1)
         self.assertEqual(m.prev_cp_rank(), 15)
         self.assertEqual(m.next_cp_rank(), 11)
+
+    def test_helix_overridden_tp_rank(self):
+        # Test case for helix overridden TP rank: (pp_size, tp_size, cp_size, expected_mapping)
+        # where expected_mapping is a list of (rank, expected_helix_tp_rank) tuples.
+        HelixTestCase = namedtuple('HelixTestCase', ['pp_size', 'tp_size', 'cp_size', 'expected_mapping'])
+        test_cases = [
+            # Case: pp_size=1, tp_size=2, cp_size=2.
+            # CP groups: [0, 2], [1, 3] -> helix order: [0, 2, 1, 3].
+            HelixTestCase(pp_size=1, tp_size=2, cp_size=2, expected_mapping=[
+                (0, 0), (2, 1), (1, 2), (3, 3),
+            ]),
+            # Case: pp_size=1, tp_size=4, cp_size=2.
+            # CP groups: [0, 4], [1, 5], [2, 6], [3, 7] -> helix order: [0, 4, 1, 5, 2, 6, 3, 7].
+            HelixTestCase(pp_size=1, tp_size=4, cp_size=2, expected_mapping=[
+                (0, 0), (4, 1), (1, 2), (5, 3),
+                (2, 4), (6, 5), (3, 6), (7, 7),
+            ]),
+            # Case: pp_size=1, tp_size=2, cp_size=4.
+            # CP groups: [0, 2, 4, 6], [1, 3, 5, 7] -> helix order: [0, 2, 4, 6, 1, 3, 5, 7].
+            HelixTestCase(pp_size=1, tp_size=2, cp_size=4, expected_mapping=[
+                (0, 0), (2, 1), (4, 2), (6, 3),
+                (1, 4), (3, 5), (5, 6), (7, 7),
+            ]),
+            # Case: pp_size=1, tp_size=4, cp_size=4.
+            # CP groups: [0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15] -> helix order: [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15].
+            HelixTestCase(pp_size=1, tp_size=4, cp_size=4, expected_mapping=[
+                (0, 0), (4, 1), (8, 2), (12, 3),
+                (1, 4), (5, 5), (9, 6), (13, 7),
+                (2, 8), (6, 9), (10, 10), (14, 11),
+                (3, 12), (7, 13), (11, 14), (15, 15),
+            ]),
+            # Case: pp_size=2, tp_size=4, cp_size=2.
+            # PP stage 0 CP groups: [0, 4], [1, 5], [2, 6], [3, 7] -> helix order: [0, 4, 1, 5, 2, 6, 3, 7].
+            # PP stage 1 CP groups: [8, 12], [9, 13], [10, 14], [11, 15] -> helix order: [8, 12, 9, 13, 10, 14, 11, 15].
+            HelixTestCase(pp_size=2, tp_size=4, cp_size=2, expected_mapping=[
+                (0, 0), (4, 1), (1, 2), (5, 3), (2, 4), (6, 5), (3, 6), (7, 7),  # PP stage 0
+                (8, 0), (12, 1), (9, 2), (13, 3), (10, 4), (14, 5), (11, 6), (15, 7),  # PP stage 1
+            ]),
+            # Case: pp_size=2, tp_size=2, cp_size=4.
+            # PP stage 0 CP groups: [0, 2, 4, 6], [1, 3, 5, 7] -> helix order: [0, 2, 4, 6, 1, 3, 5, 7].
+            # PP stage 1 CP groups: [8, 10, 12, 14], [9, 11, 13, 15] -> helix order: [8, 10, 12, 14, 9, 11, 13, 15].
+            HelixTestCase(pp_size=2, tp_size=2, cp_size=4, expected_mapping=[
+                (0, 0), (2, 1), (4, 2), (6, 3), (1, 4), (3, 5), (5, 6), (7, 7),  # PP stage 0
+                (8, 0), (10, 1), (12, 2), (14, 3), (9, 4), (11, 5), (13, 6), (15, 7),  # PP stage 1
+            ]),
+        ]
+
+        for case in test_cases:
+            world_size = case.pp_size * case.tp_size * case.cp_size
+            with self.subTest(pp_size=case.pp_size, tp_size=case.tp_size, cp_size=case.cp_size):
+                for rank, expected in case.expected_mapping:
+                    m = Mapping(
+                        world_size=world_size, rank=rank,
+                        tp_size=case.tp_size, pp_size=case.pp_size, cp_size=case.cp_size,
+                        cp_config={"cp_type": CpType.HELIX}
+                    )
+                    self.assertEqual(
+                        m.get_helix_overridden_tp_rank(), expected,
+                        f"Failed for rank={rank}: expected {expected}, got {m.get_helix_overridden_tp_rank()}"
+                    )
