8 changes: 7 additions & 1 deletion tensorrt_llm/_torch/modules/attention.py
@@ -814,10 +814,13 @@ def __init__(

# tensor parallel
config = config or ModelConfig()
override_tp_rank_for_o_proj = None
if mapping_with_cp is not None:
logger.warning(
"[MLA::__init__] Overriding mapping with CP detected.")
self.mapping = mapping_with_cp
override_tp_rank_for_o_proj = mapping_with_cp.get_helix_overridden_tp_rank(
)
else:
self.mapping = config.mapping
tp_size = self.mapping.tp_size
@@ -952,7 +955,10 @@ def __init__(
skip_create_weights_in_init=config.skip_create_weights_in_init,
reduce_output=reduce_output,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
# override_tp_rank is only used for helix parallelism.
override_tp_rank=override_tp_rank_for_o_proj,
)

def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
9 changes: 7 additions & 2 deletions tensorrt_llm/_torch/modules/linear.py
@@ -148,8 +148,11 @@ def load_weights_vanilla_helper(module: Linear,
assert "bias" in weights[0]
device = torch.device('cuda')

# Use override_tp_rank if set, otherwise fall back to tp_rank. Currently, this is only used
# for o_proj in MLA when using helix parallelism.
effective_tp_rank = module.override_tp_rank if module.override_tp_rank is not None else module.tp_rank
weight = load_weight_shard(weights[0]['weight'], module.tp_size,
module.tp_rank, module.tp_mode,
effective_tp_rank, module.tp_mode,
device) if "weight" in weights[0] else None

if weight is not None:
@@ -166,7 +169,7 @@ def load_weights_vanilla_helper(module: Linear,

if module.bias is not None:
bias = load_weight_shard(weights[0]['bias'], module.tp_size,
module.tp_rank, module.tp_mode,
effective_tp_rank, module.tp_mode,
device) if "bias" in weights[0] else None
if bias is not None:
copy_weight(module.bias, bias_transform(bias))
@@ -2065,6 +2068,7 @@ def __init__(
disable_deep_gemm: bool = False,
fused_weight_shard_indices_mapping: Optional[dict] = None,
nvfp4_allowed_backends: Optional[List[str]] = None,
override_tp_rank: Optional[int] = None,
):
"""
Args:
@@ -2105,6 +2109,7 @@ def __init__(
'cutlass', 'cublaslt', 'cuda_core'
]

self.override_tp_rank = override_tp_rank
local_in_features = in_features
local_out_features = out_features

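As a side note, here is a minimal sketch of the fallback rule used by `load_weights_vanilla_helper` above; the helper name `pick_shard_rank` is made up for illustration and is not part of the PR.

```python
# Illustration only: the rank-selection rule from load_weights_vanilla_helper,
# pulled out into a standalone function with a made-up name.
from typing import Optional


def pick_shard_rank(tp_rank: int, override_tp_rank: Optional[int]) -> int:
    # Fall back to the module's regular TP rank unless an override is set
    # (in this PR, only o_proj in MLA under helix parallelism sets one).
    return override_tp_rank if override_tp_rank is not None else tp_rank


assert pick_shard_rank(tp_rank=1, override_tp_rank=None) == 1  # ordinary TP path
assert pick_shard_rank(tp_rank=1, override_tp_rank=3) == 3     # helix-repurposed path
```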
20 changes: 20 additions & 0 deletions tensorrt_llm/mapping.py
@@ -246,6 +246,26 @@ def has_cp_helix(self):
return self.cp_size > 1 and self.cp_config.get(
"cp_type") == CpType.HELIX

def get_helix_overridden_tp_rank(self) -> int:
"""Get the overridden TP rank when repurposing helix CP to TP.

In helix parallelism, CP groups are structured differently than TP groups.
For example, with tp_size=2, cp_size=2:
- CP groups: [0, 2], [1, 3] (accumulated order: [0, 2, 1, 3])
- When repurposed to TP: [0, 1, 2, 3]

The helix accumulated order iterates through TP ranks, and for each TP rank
iterates through CP ranks. So the position in helix order is:
helix_position = tp_rank * cp_size + cp_rank

This function computes the TP rank in the repurposed mapping, accounting
for the reordering from helix accumulated order to standard TP order.

Returns:
The TP rank in the repurposed (tp_size * cp_size, cp_size=1) mapping.
"""
return self.tp_rank * self.cp_size + self.cp_rank

def get_node_rank(self, rank: int):
return rank // self.gpus_per_node

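To make the reordering described in the docstring concrete, here is a small self-contained sketch of the same formula, re-implemented outside the `Mapping` class purely for illustration and checked against the tp_size=2, cp_size=2 example.

```python
# Standalone re-implementation of the formula in get_helix_overridden_tp_rank,
# for illustration only (no tensorrt_llm import needed).
def helix_overridden_tp_rank(tp_rank: int, cp_rank: int, cp_size: int) -> int:
    # Helix accumulated order walks TP ranks and, within each, its CP ranks,
    # so a rank's position in that walk is its repurposed TP rank.
    return tp_rank * cp_size + cp_rank


# tp_size=2, cp_size=2: CP groups [0, 2] and [1, 3] -> helix order [0, 2, 1, 3],
# which becomes TP ranks [0, 1, 2, 3] in the repurposed (tp_size*cp_size) mapping.
assert helix_overridden_tp_rank(tp_rank=0, cp_rank=0, cp_size=2) == 0  # global rank 0
assert helix_overridden_tp_rank(tp_rank=0, cp_rank=1, cp_size=2) == 1  # global rank 2
assert helix_overridden_tp_rank(tp_rank=1, cp_rank=0, cp_size=2) == 2  # global rank 1
assert helix_overridden_tp_rank(tp_rank=1, cp_rank=1, cp_size=2) == 3  # global rank 3
```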
16 changes: 10 additions & 6 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -871,7 +871,8 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)],
ids=["pp1tp2cp2", "pp2tp1cp2"])
@pytest.mark.parametrize("cuda_graph_config", [
None,
{
@@ -888,8 +889,10 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
"cudagraph:with_padding"
])
@pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
gen_pp, gen_tp, gen_cp):
use_nccl_for_alltoall = comms_medium == "nccl"
gen_ep = gen_tp * gen_cp
kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False,
@@ -898,7 +901,7 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
}
ctx_server_config = {
"pipeline_parallel_size": 1,
"tensor_parallel_size": 2,
"tensor_parallel_size": 4,
"context_parallel_size": 1,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
@@ -909,9 +912,10 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
},
}
gen_server_config = {
"tensor_parallel_size": 1,
"pipeline_parallel_size": 1,
"context_parallel_size": 2,
"tensor_parallel_size": gen_tp,
"pipeline_parallel_size": gen_pp,
"context_parallel_size": gen_cp,
"moe_expert_parallel_size": gen_ep,
"cp_config": {
"cp_type": "HELIX",
"tokens_per_block": 32,
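For orientation, the GPU counts implied by the new parametrization (simple arithmetic, not code from the PR): both generation layouts use four ranks, matching the four-way context server.

```python
# Back-of-the-envelope check of the parametrized sizes (illustration only).
for gen_pp, gen_tp, gen_cp in [(1, 2, 2), (2, 1, 2)]:
    gen_ep = gen_tp * gen_cp          # matches gen_ep in the test body
    gen_ranks = gen_pp * gen_tp * gen_cp
    print(f"pp{gen_pp}tp{gen_tp}cp{gen_cp}: gen ranks={gen_ranks}, moe ep={gen_ep}")
# -> pp1tp2cp2: gen ranks=4, moe ep=4
# -> pp2tp1cp2: gen ranks=4, moe ep=2
```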
12 changes: 6 additions & 6 deletions tests/integration/test_lists/qa/llm_function_core.txt
@@ -540,12 +540,12 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -66,6 +66,7 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -92,6 +93,7 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)
6 changes: 0 additions & 6 deletions tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
@@ -72,8 +72,6 @@ l0_gb200_multi_gpus:
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
- condition:
ranges:
@@ -89,10 +87,6 @@ l0_gb200_multi_gpus:
stage: post_merge
backend: pytorch
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
143 changes: 142 additions & 1 deletion tests/unittest/others/test_mapping.py
@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from collections import namedtuple

from tensorrt_llm.mapping import Mapping
from tensorrt_llm.mapping import CpType, Mapping


class TestMapping(unittest.TestCase):
@@ -81,3 +82,143 @@ def test_mapping(self):
self.assertEqual(m.next_pp_rank(), 1)
self.assertEqual(m.prev_cp_rank(), 15)
self.assertEqual(m.next_cp_rank(), 11)

def test_helix_overridden_tp_rank(self):
        # Test cases for the helix overridden TP rank: each is (pp_size, tp_size, cp_size, expected_mapping),
        # where expected_mapping is a list of (rank, expected_helix_tp_rank) tuples.
HelixTestCase = namedtuple(
'HelixTestCase',
['pp_size', 'tp_size', 'cp_size', 'expected_mapping'])
test_cases = [
# Case: pp_size=1, tp_size=2, cp_size=2.
# CP groups: [0, 2], [1, 3] -> helix order: [0, 2, 1, 3].
HelixTestCase(pp_size=1,
tp_size=2,
cp_size=2,
expected_mapping=[
(0, 0),
(2, 1),
(1, 2),
(3, 3),
]),
# Case: pp_size=1, tp_size=4, cp_size=2.
# CP groups: [0, 4], [1, 5], [2, 6], [3, 7] -> helix order: [0, 4, 1, 5, 2, 6, 3, 7].
HelixTestCase(pp_size=1,
tp_size=4,
cp_size=2,
expected_mapping=[
(0, 0),
(4, 1),
(1, 2),
(5, 3),
(2, 4),
(6, 5),
(3, 6),
(7, 7),
]),
# Case: pp_size=1, tp_size=2, cp_size=4.
# CP groups: [0, 2, 4, 6], [1, 3, 5, 7] -> helix order: [0, 2, 4, 6, 1, 3, 5, 7].
HelixTestCase(pp_size=1,
tp_size=2,
cp_size=4,
expected_mapping=[
(0, 0),
(2, 1),
(4, 2),
(6, 3),
(1, 4),
(3, 5),
(5, 6),
(7, 7),
]),
# Case: pp_size=1, tp_size=4, cp_size=4.
# CP groups: [0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15] -> helix order: [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15].
HelixTestCase(pp_size=1,
tp_size=4,
cp_size=4,
expected_mapping=[
(0, 0),
(4, 1),
(8, 2),
(12, 3),
(1, 4),
(5, 5),
(9, 6),
(13, 7),
(2, 8),
(6, 9),
(10, 10),
(14, 11),
(3, 12),
(7, 13),
(11, 14),
(15, 15),
]),
# Case: pp_size=2, tp_size=4, cp_size=2.
# PP stage 0 CP groups: [0,4], [1,5], [2,6], [3,7] -> helix order: [0, 4, 1, 5, 2, 6, 3, 7].
# PP stage 1 CP groups: [8,12], [9,13], [10,14], [11,15] -> helix order: [8, 12, 9, 13, 10, 14, 11, 15].
HelixTestCase(
pp_size=2,
tp_size=4,
cp_size=2,
expected_mapping=[
(0, 0),
(4, 1),
(1, 2),
(5, 3),
(2, 4),
(6, 5),
(3, 6),
(7, 7), # PP stage 0
(8, 0),
(12, 1),
(9, 2),
(13, 3),
(10, 4),
(14, 5),
(11, 6),
(15, 7), # PP stage 1
]),
# Case: pp_size=2, tp_size=2, cp_size=4.
# PP stage 0 CP groups: [0, 2, 4, 6], [1, 3, 5, 7] -> helix order: [0, 2, 4, 6, 1, 3, 5, 7].
# PP stage 1 CP groups: [8, 10, 12, 14], [9, 11, 13, 15] -> helix order: [8, 10, 12, 14, 9, 11, 13, 15].
HelixTestCase(
pp_size=2,
tp_size=2,
cp_size=4,
expected_mapping=[
(0, 0),
(2, 1),
(4, 2),
(6, 3),
(1, 4),
(3, 5),
(5, 6),
(7, 7), # PP stage 0
(8, 0),
(10, 1),
(12, 2),
(14, 3),
(9, 4),
(11, 5),
(13, 6),
(15, 7), # PP stage 1
]),
]

for case in test_cases:
world_size = case.pp_size * case.tp_size * case.cp_size
with self.subTest(pp_size=case.pp_size,
tp_size=case.tp_size,
cp_size=case.cp_size):
for rank, expected in case.expected_mapping:
m = Mapping(world_size=world_size,
rank=rank,
tp_size=case.tp_size,
pp_size=case.pp_size,
cp_size=case.cp_size,
cp_config={"cp_type": CpType.HELIX})
self.assertEqual(
m.get_helix_overridden_tp_rank(), expected,
f"Failed for rank={rank}: expected {expected}, got {m.get_helix_overridden_tp_rank()}"
)