diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 47e793e4816..c2bb35a488f 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -814,10 +814,13 @@ def __init__(
 
         # tensor parallel
         config = config or ModelConfig()
+        override_tp_rank_for_o_proj = None
         if mapping_with_cp is not None:
             logger.warning(
                 "[MLA::__init__] Overriding mapping with CP detected.")
             self.mapping = mapping_with_cp
+            override_tp_rank_for_o_proj = mapping_with_cp.get_helix_overridden_tp_rank(
+            )
         else:
             self.mapping = config.mapping
         tp_size = self.mapping.tp_size
@@ -952,7 +955,10 @@ def __init__(
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             reduce_output=reduce_output,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            # override_tp_rank is only used for helix parallelism.
+            override_tp_rank=override_tp_rank_for_o_proj,
+        )
 
 
 def yarn_get_mscale(scale=1, mscale=1):
     if scale <= 1:
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 44daa25eb3c..e201031e19c 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -148,8 +148,11 @@ def load_weights_vanilla_helper(module: Linear,
         assert "bias" in weights[0]
 
     device = torch.device('cuda')
+    # Use override_tp_rank if set, otherwise fall back to tp_rank. Currently, this is only used
+    # for o_proj in MLA when using helix parallelism.
+    effective_tp_rank = module.override_tp_rank if module.override_tp_rank is not None else module.tp_rank
     weight = load_weight_shard(weights[0]['weight'], module.tp_size,
-                               module.tp_rank, module.tp_mode,
+                               effective_tp_rank, module.tp_mode,
                                device) if "weight" in weights[0] else None
 
     if weight is not None:
@@ -166,7 +169,7 @@ def load_weights_vanilla_helper(module: Linear,
 
     if module.bias is not None:
         bias = load_weight_shard(weights[0]['bias'], module.tp_size,
-                                 module.tp_rank, module.tp_mode,
+                                 effective_tp_rank, module.tp_mode,
                                  device) if "bias" in weights[0] else None
         if bias is not None:
             copy_weight(module.bias, bias_transform(bias))
@@ -2065,6 +2068,7 @@ def __init__(
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
         nvfp4_allowed_backends: Optional[List[str]] = None,
+        override_tp_rank: Optional[int] = None,
     ):
         """
         Args:
@@ -2105,6 +2109,7 @@ def __init__(
                 'cutlass', 'cublaslt', 'cuda_core'
             ]
+        self.override_tp_rank = override_tp_rank
 
         local_in_features = in_features
         local_out_features = out_features
diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py
index 386d18da747..48ffebeea05 100644
--- a/tensorrt_llm/mapping.py
+++ b/tensorrt_llm/mapping.py
@@ -246,6 +246,26 @@ def has_cp_helix(self):
         return self.cp_size > 1 and self.cp_config.get(
             "cp_type") == CpType.HELIX
 
+    def get_helix_overridden_tp_rank(self) -> int:
+        """Get the overridden TP rank when repurposing helix CP to TP.
+
+        In helix parallelism, CP groups are structured differently from TP groups.
+        For example, with tp_size=2, cp_size=2:
+        - CP groups: [0, 2], [1, 3] (accumulated order: [0, 2, 1, 3])
+        - When repurposed to TP: [0, 1, 2, 3]
+
+        The helix accumulated order iterates through TP ranks and, for each TP
+        rank, through CP ranks. So the position in helix order is:
+            helix_position = tp_rank * cp_size + cp_rank
+
+        This function computes the TP rank in the repurposed mapping, accounting
+        for the reordering from helix accumulated order to standard TP order.
+
+        Returns:
+            The TP rank in the repurposed (tp_size * cp_size, cp_size=1) mapping.
+        """
+        return self.tp_rank * self.cp_size + self.cp_rank
+
     def get_node_rank(self, rank: int):
         return rank // self.gpus_per_node
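For reference (illustrative only, not part of the patch), the reordering that `get_helix_overridden_tp_rank()` performs can be sanity-checked in isolation. A minimal sketch, assuming the rank layout implied by the docstring's CP groups, i.e. TP is the fastest-varying dimension within a pipeline stage, so `tp_rank = rank % tp_size` and `cp_rank = rank // tp_size`:

```python
# Standalone sketch of the helix -> TP rank reordering; plain integers stand
# in for Mapping's tp_rank/cp_rank/cp_size attributes.

def helix_overridden_tp_rank(tp_rank: int, cp_rank: int, cp_size: int) -> int:
    # Helix accumulated order walks TP ranks in the outer loop and CP ranks
    # in the inner loop, so a rank's position in that order is
    # tp_rank * cp_size + cp_rank.
    return tp_rank * cp_size + cp_rank

tp_size, cp_size = 2, 2  # docstring example: CP groups [0, 2] and [1, 3]
for rank in range(tp_size * cp_size):
    tp_rank, cp_rank = rank % tp_size, rank // tp_size  # assumed layout
    print(rank, "->", helix_overridden_tp_rank(tp_rank, cp_rank, cp_size))
# 0 -> 0, 1 -> 2, 2 -> 1, 3 -> 3: the helix accumulated order [0, 2, 1, 3]
# flattened back into the standard TP order [0, 1, 2, 3].
```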
diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py
index c52911fe00b..0a66ae48e09 100644
--- a/tests/integration/defs/accuracy/test_disaggregated_serving.py
+++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -871,7 +871,8 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
-    @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)],
+                             ids=["pp1tp2cp2", "pp2tp1cp2"])
     @pytest.mark.parametrize("cuda_graph_config", [
         None,
         {
@@ -888,8 +889,10 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         "cudagraph:with_padding"
     ])
     @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
-    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
+    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
+                                   gen_pp, gen_tp, gen_cp):
         use_nccl_for_alltoall = comms_medium == "nccl"
+        gen_ep = gen_tp * gen_cp
         kv_cache_config = {
             "free_gpu_memory_fraction": 0.5,
             "enable_block_reuse": False,
@@ -898,7 +901,7 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
         }
         ctx_server_config = {
             "pipeline_parallel_size": 1,
-            "tensor_parallel_size": 2,
+            "tensor_parallel_size": 4,
             "context_parallel_size": 1,
             "disable_overlap_scheduler": True,
             "kv_cache_config": kv_cache_config,
@@ -909,9 +912,10 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
             },
         }
         gen_server_config = {
-            "tensor_parallel_size": 1,
-            "pipeline_parallel_size": 1,
-            "context_parallel_size": 2,
+            "tensor_parallel_size": gen_tp,
+            "pipeline_parallel_size": gen_pp,
+            "context_parallel_size": gen_cp,
+            "moe_expert_parallel_size": gen_ep,
             "cp_config": {
                 "cp_type": "HELIX",
                 "tokens_per_block": 32,
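A quick arithmetic check on the numbers above (a hedged reading, not stated in the patch): both new generation parametrizations keep the generation server at four GPUs, matching the four-GPU context server, and `gen_ep = gen_tp * gen_cp` appears to size expert parallelism to the full helix group, consistent with the CP dimension being repurposed as TP outside attention. None of the code below exists in the test; it only re-derives the implied GPU counts:

```python
# Re-derive the world sizes the two new parametrizations imply.
ctx_world_size = 1 * 4 * 1  # pp * tp * cp from ctx_server_config
assert ctx_world_size == 4

for gen_pp, gen_tp, gen_cp in [(1, 2, 2), (2, 1, 2)]:
    gen_ep = gen_tp * gen_cp  # mirrors the gen_ep assignment in the test
    gen_world_size = gen_pp * gen_tp * gen_cp
    print(f"pp{gen_pp}tp{gen_tp}cp{gen_cp}: gen GPUs={gen_world_size}, ep={gen_ep}")
# pp1tp2cp2: gen GPUs=4, ep=4
# pp2tp1cp2: gen GPUs=4, ep=2
```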
diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt
index 9a7102142bf..0771c17ce68 100644
--- a/tests/integration/test_lists/qa/llm_function_core.txt
+++ b/tests/integration/test_lists/qa/llm_function_core.txt
@@ -540,12 +540,12 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp1tp2cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
index c691acc1fef..2bd53313614 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -66,6 +66,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -92,6 +93,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)
diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
index bf0ca261fe0..2241aea415a 100644
--- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
+++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
@@ -72,8 +72,6 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
   - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
 - condition:
     ranges:
@@ -89,10 +87,6 @@ l0_gb200_multi_gpus:
       stage: post_merge
       backend: pytorch
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
diff --git a/tests/unittest/others/test_mapping.py b/tests/unittest/others/test_mapping.py
index bc9839239bf..1d417aab081 100644
--- a/tests/unittest/others/test_mapping.py
+++ b/tests/unittest/others/test_mapping.py
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest
+from collections import namedtuple
 
-from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.mapping import CpType, Mapping
 
 
 class TestMapping(unittest.TestCase):
@@ -81,3 +82,143 @@ def test_mapping(self):
         self.assertEqual(m.next_pp_rank(), 1)
         self.assertEqual(m.prev_cp_rank(), 15)
         self.assertEqual(m.next_cp_rank(), 11)
+
+    def test_helix_overridden_tp_rank(self):
+        # Test cases for the helix overridden TP rank: (pp_size, tp_size, cp_size, expected_mapping),
+        # where expected_mapping is a list of (rank, expected_helix_tp_rank) tuples.
+        HelixTestCase = namedtuple(
+            'HelixTestCase',
+            ['pp_size', 'tp_size', 'cp_size', 'expected_mapping'])
+        test_cases = [
+            # Case: pp_size=1, tp_size=2, cp_size=2.
+            # CP groups: [0, 2], [1, 3] -> helix order: [0, 2, 1, 3].
+            HelixTestCase(pp_size=1,
+                          tp_size=2,
+                          cp_size=2,
+                          expected_mapping=[
+                              (0, 0),
+                              (2, 1),
+                              (1, 2),
+                              (3, 3),
+                          ]),
+            # Case: pp_size=1, tp_size=4, cp_size=2.
+            # CP groups: [0, 4], [1, 5], [2, 6], [3, 7] -> helix order: [0, 4, 1, 5, 2, 6, 3, 7].
+            HelixTestCase(pp_size=1,
+                          tp_size=4,
+                          cp_size=2,
+                          expected_mapping=[
+                              (0, 0),
+                              (4, 1),
+                              (1, 2),
+                              (5, 3),
+                              (2, 4),
+                              (6, 5),
+                              (3, 6),
+                              (7, 7),
+                          ]),
+            # Case: pp_size=1, tp_size=2, cp_size=4.
+            # CP groups: [0, 2, 4, 6], [1, 3, 5, 7] -> helix order: [0, 2, 4, 6, 1, 3, 5, 7].
+            HelixTestCase(pp_size=1,
+                          tp_size=2,
+                          cp_size=4,
+                          expected_mapping=[
+                              (0, 0),
+                              (2, 1),
+                              (4, 2),
+                              (6, 3),
+                              (1, 4),
+                              (3, 5),
+                              (5, 6),
+                              (7, 7),
+                          ]),
+            # Case: pp_size=1, tp_size=4, cp_size=4.
+            # CP groups: [0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]
+            # -> helix order: [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15].
+            HelixTestCase(pp_size=1,
+                          tp_size=4,
+                          cp_size=4,
+                          expected_mapping=[
+                              (0, 0),
+                              (4, 1),
+                              (8, 2),
+                              (12, 3),
+                              (1, 4),
+                              (5, 5),
+                              (9, 6),
+                              (13, 7),
+                              (2, 8),
+                              (6, 9),
+                              (10, 10),
+                              (14, 11),
+                              (3, 12),
+                              (7, 13),
+                              (11, 14),
+                              (15, 15),
+                          ]),
+            # Case: pp_size=2, tp_size=4, cp_size=2.
+            # PP stage 0 CP groups: [0, 4], [1, 5], [2, 6], [3, 7] -> helix order: [0, 4, 1, 5, 2, 6, 3, 7].
+            # PP stage 1 CP groups: [8, 12], [9, 13], [10, 14], [11, 15] -> helix order: [8, 12, 9, 13, 10, 14, 11, 15].
+            HelixTestCase(
+                pp_size=2,
+                tp_size=4,
+                cp_size=2,
+                expected_mapping=[
+                    (0, 0),
+                    (4, 1),
+                    (1, 2),
+                    (5, 3),
+                    (2, 4),
+                    (6, 5),
+                    (3, 6),
+                    (7, 7),  # PP stage 0
+                    (8, 0),
+                    (12, 1),
+                    (9, 2),
+                    (13, 3),
+                    (10, 4),
+                    (14, 5),
+                    (11, 6),
+                    (15, 7),  # PP stage 1
+                ]),
+            # Case: pp_size=2, tp_size=2, cp_size=4.
+            # PP stage 0 CP groups: [0, 2, 4, 6], [1, 3, 5, 7] -> helix order: [0, 2, 4, 6, 1, 3, 5, 7].
+            # PP stage 1 CP groups: [8, 10, 12, 14], [9, 11, 13, 15] -> helix order: [8, 10, 12, 14, 9, 11, 13, 15].
+            HelixTestCase(
+                pp_size=2,
+                tp_size=2,
+                cp_size=4,
+                expected_mapping=[
+                    (0, 0),
+                    (2, 1),
+                    (4, 2),
+                    (6, 3),
+                    (1, 4),
+                    (3, 5),
+                    (5, 6),
+                    (7, 7),  # PP stage 0
+                    (8, 0),
+                    (10, 1),
+                    (12, 2),
+                    (14, 3),
+                    (9, 4),
+                    (11, 5),
+                    (13, 6),
+                    (15, 7),  # PP stage 1
+                ]),
+        ]
+
+        for case in test_cases:
+            world_size = case.pp_size * case.tp_size * case.cp_size
+            with self.subTest(pp_size=case.pp_size,
+                              tp_size=case.tp_size,
+                              cp_size=case.cp_size):
+                for rank, expected in case.expected_mapping:
+                    m = Mapping(world_size=world_size,
+                                rank=rank,
+                                tp_size=case.tp_size,
+                                pp_size=case.pp_size,
+                                cp_size=case.cp_size,
+                                cp_config={"cp_type": CpType.HELIX})
+                    self.assertEqual(
+                        m.get_helix_overridden_tp_rank(), expected,
+                        f"Failed for rank={rank}: expected {expected}, got {m.get_helix_overridden_tp_rank()}"
+                    )
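The expected tables above can be cross-checked without `tensorrt_llm`. A minimal sketch, assuming ranks are laid out as `rank = (pp_rank * cp_size + cp_rank) * tp_size + tp_rank` (TP fastest, then CP, then PP), which is the layout the CP-group comments imply:

```python
# Library-free reproduction of the expected (rank, helix_tp_rank) tables.

def expected_table(pp_size: int, tp_size: int, cp_size: int):
    table = []
    for pp_rank in range(pp_size):
        for tp_rank in range(tp_size):      # helix accumulated order: TP outer,
            for cp_rank in range(cp_size):  # CP inner
                rank = (pp_rank * cp_size + cp_rank) * tp_size + tp_rank
                table.append((rank, tp_rank * cp_size + cp_rank))
    return table

# Matches the pp_size=1, tp_size=2, cp_size=2 case.
assert expected_table(1, 2, 2) == [(0, 0), (2, 1), (1, 2), (3, 3)]
# Matches PP stage 1 of the pp_size=2, tp_size=2, cp_size=4 case.
assert expected_table(2, 2, 4)[8:] == [(8, 0), (10, 1), (12, 2), (14, 3),
                                       (9, 4), (11, 5), (13, 6), (15, 7)]
```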