
Commit cdce68c

Njuappchzblych and Yanchao Lu authored
[TRTLLM-6741][fix] Add heuristics for lm head tp size when enable_lm_head_tp_in_adp=True (#7891)
Signed-off-by: Cheng Hang <[email protected]>
Co-authored-by: Yanchao Lu <[email protected]>
1 parent 6396cb9 commit cdce68c

8 files changed: +72 -33 lines changed

jenkins/L0_Test.groovy

Lines changed: 3 additions & 2 deletions
@@ -2205,8 +2205,9 @@ def launchTestJobs(pipeline, testFilter)
         // "GB200-8_GPUs-2_Nodes-PyTorch-5": ["gb200-multi-node", "l0_gb200_multi_nodes", 5, 5, 8, 2],
         // ]
         multiNodesSBSAConfigs = [:]
-        multiNodesSBSAConfigs += (1..7).collectEntries { i ->
-            ["GB200-8_GPUs-2_Nodes-PyTorch-Post-Merge-${i}".toString(), ["gb200-multi-node", "l0_gb200_multi_nodes", i, 7, 8, 2]]
+        def numMultiNodeTests = 9
+        multiNodesSBSAConfigs += (1..numMultiNodeTests).collectEntries { i ->
+            ["GB200-8_GPUs-2_Nodes-PyTorch-Post-Merge-${i}".toString(), ["gb200-multi-node", "l0_gb200_multi_nodes", i, numMultiNodeTests, 8, 2]]
         }
         fullSet += multiNodesSBSAConfigs.keySet()
 

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 8 additions & 8 deletions
@@ -404,12 +404,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         self.norm = RMSNorm(hidden_size=config.hidden_size,
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)
-        if self.model_config.mapping.enable_attention_dp and \
-                getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
-            self.mapping_lm_head_tp = create_lm_head_tp_mapping(
-                self.model_config.mapping)
-        else:
-            self.mapping_lm_head_tp = self.model_config.mapping
+
+        self.mapping_lm_head_tp = None
 
     @torch.compile(options={"max-autotune": True})
     def get_last_token_states(self, hidden_states, attn_metadata):
@@ -433,19 +429,23 @@ def forward(self,
             hidden_states = hidden_states[-1].unsqueeze(0)
 
         enable_attention_dp = self.model_config.mapping.enable_attention_dp
-        enable_lm_head_tp_in_adp = self.model_config.mapping.enable_lm_head_tp_in_adp
+        enable_lm_head_tp_in_adp = enable_attention_dp and self.model_config.mapping.enable_lm_head_tp_in_adp
 
         # Add pre-lm gather logic
         if enable_lm_head_tp_in_adp:
             # ADP + LM TP mode: perform All-Gather before LM_head
+            self.mapping_lm_head_tp = create_lm_head_tp_mapping(
+                self.model_config.mapping, hidden_states.shape[0])
             hidden_states = allgather(hidden_states,
                                       self.mapping_lm_head_tp,
                                       dim=0)
 
         # Temporarily disable gather_output when not in ADP mode or (in ADP mode and LM TP is enabled)
        if not enable_attention_dp or enable_lm_head_tp_in_adp:
             lm_head.gather_output = False
-        logits = lm_head(hidden_states, is_spec_decoding_head=True)
+        logits = lm_head(hidden_states,
+                         mapping_lm_head_tp=self.mapping_lm_head_tp,
+                         is_spec_decoding_head=True)
         if not enable_attention_dp or enable_lm_head_tp_in_adp:
             lm_head.gather_output = True
         return logits
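The hunk above stops building the LM-head TP mapping eagerly in __init__ and instead creates it in forward from the runtime token count, passing it into lm_head explicitly. As a sanity check of the math rather than of this code, here is a minimal single-process sketch (plain torch, no distributed collectives, illustrative sizes) of why gathering tokens across ADP ranks and letting each rank score only a vocab slice reproduces the plain per-rank, full-vocab logits once the slices are concatenated.

import torch

torch.manual_seed(0)
tp_size, tokens_per_rank, hidden, vocab = 4, 2, 16, 32    # illustrative sizes
weight = torch.randn(vocab, hidden)                        # full lm_head weight
rank_tokens = [torch.randn(tokens_per_rank, hidden) for _ in range(tp_size)]

# Plain ADP path: every rank scores its own tokens against the full vocab.
reference = torch.cat([t @ weight.T for t in rank_tokens], dim=0)

# ADP + LM-head-TP path: all-gather tokens, score one vocab slice per rank,
# then gather the logit slices along the vocab dimension.
gathered = torch.cat(rank_tokens, dim=0)                   # stands in for allgather(..., dim=0)
slice_width = -(-vocab // tp_size)                         # ceil_div(out_features, tp_size)
slices = []
for tp_rank in range(tp_size):
    start = tp_rank * slice_width
    end = min((tp_rank + 1) * slice_width, vocab)
    slices.append(gathered @ weight[start:end, :].T)       # F.linear on the weight slice
sharded = torch.cat(slices, dim=-1)

assert torch.allclose(reference, sharded, atol=1e-5)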

tensorrt_llm/_torch/modules/embedding.py

Lines changed: 10 additions & 16 deletions
@@ -7,9 +7,9 @@
 
 from tensorrt_llm.functional import AllReduceParams
 from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.math_utils import ceil_div
 
 from ..distributed import allgather
-from ..utils import create_lm_head_tp_mapping
 from .linear import Linear, TensorParallelMode
 
 
@@ -38,8 +38,6 @@ def __init__(
         mapping = mapping or Mapping()
         self.enable_lm_head_tp_in_adp = mapping.enable_attention_dp and \
             getattr(mapping, 'enable_lm_head_tp_in_adp', False)
-        if self.enable_lm_head_tp_in_adp:
-            mapping = create_lm_head_tp_mapping(mapping)
 
         tp_size = mapping.tp_size
 
@@ -78,18 +76,6 @@ def __init__(
         self.weight = Parameter(torch.empty(weight_shape, dtype=dtype))
         self.register_parameter("bias", None)
 
-        # For LM head TP in ADP, we need to slice the weight for the LM head
-        self.lm_head_slice_obj = None
-        if self.enable_lm_head_tp_in_adp:
-            tp_rank = self.mapping.tp_rank
-            tp_size = self.mapping.tp_size
-            slice_width = math.ceil(self.out_features / tp_size)
-            slice_start = tp_rank * slice_width
-            slice_end = min((tp_rank + 1) * slice_width, self.out_features)
-            slice_obj = [slice(None)] * len(self.weight.shape)
-            slice_obj[0] = slice(slice_start, slice_end)
-            self.lm_head_slice_obj = tuple(slice_obj)
-
     @property
     def vocab_size_padded(self) -> int:
         if self.tp_mode == TensorParallelMode.COLUMN and self.gather_output:
@@ -102,10 +88,18 @@ def forward(
         input: torch.Tensor,
         *,
         all_reduce_params: Optional[AllReduceParams] = None,
+        mapping_lm_head_tp: Optional[Mapping] = None,
         is_spec_decoding_head: bool = False,
     ) -> torch.Tensor:
         if is_spec_decoding_head and self.enable_lm_head_tp_in_adp:
-            output = F.linear(input, self.weight[self.lm_head_slice_obj], None)
+            # For LM head TP in ADP, we need to slice the weight for the LM head
+            tp_rank = mapping_lm_head_tp.tp_rank
+            tp_size = mapping_lm_head_tp.tp_size
+            slice_width = ceil_div(self.out_features, tp_size)
+            slice_start = tp_rank * slice_width
+            slice_end = min((tp_rank + 1) * slice_width, self.out_features)
+            output = F.linear(input, self.weight[slice_start:slice_end, :],
+                              None)
         else:
             output = super().forward(input, all_reduce_params=all_reduce_params)
         if (self.tp_mode == TensorParallelMode.COLUMN and self.gather_output
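The forward change above derives the per-rank weight slice on the fly from mapping_lm_head_tp instead of the precomputed lm_head_slice_obj removed in __init__. A tiny standalone sketch of those slice bounds follows; ceil_div comes from tensorrt_llm.math_utils in the diff, so the local definition and the vocab size 50257 here are purely illustrative.

def ceil_div(a: int, b: int) -> int:          # local stand-in for tensorrt_llm.math_utils.ceil_div
    return -(-a // b)

out_features, tp_size = 50257, 4               # illustrative vocab size and LM-head TP size
slice_width = ceil_div(out_features, tp_size)
for tp_rank in range(tp_size):
    slice_start = tp_rank * slice_width
    slice_end = min((tp_rank + 1) * slice_width, out_features)
    print(tp_rank, slice_start, slice_end, slice_end - slice_start)
# Ranks 0-2 each own 12565 vocab rows; rank 3 owns the remaining 12562,
# which is why forward clamps slice_end with min(..., self.out_features).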

tensorrt_llm/_torch/utils.py

Lines changed: 7 additions & 3 deletions
@@ -1,5 +1,4 @@
 import contextlib
-import os
 import threading
 from dataclasses import dataclass
 from enum import Enum
@@ -288,8 +287,13 @@ def get_per_request_piecewise_cuda_graph_flag() -> bool:
     return getattr(_global_attrs, 'per_request_piecewise_cuda_graph_flag', True)
 
 
-def create_lm_head_tp_mapping(mapping: Mapping) -> Mapping:
-    lm_head_tp_size = int(os.getenv('LM_HEAD_TP_SIZE', 2))
+def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
+    # Use a heuristic to pick lm_head_tp_size: around token_count=256 the
+    # LM-head GEMM becomes math-bound, so 256 // token_count gives the raw
+    # TP size before snapping it into the allowed buckets.
+    lm_head_tp_size_raw = 256 // token_count
+    lm_head_tp_size = nearest_in_buckets(lm_head_tp_size_raw,
+                                         [1, mapping.gpus_per_node])
     assert mapping.tp_size % lm_head_tp_size == 0
     lm_head_pp_size = mapping.pp_size * mapping.tp_size // lm_head_tp_size
 
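A rough feel for the heuristic: fewer tokens mean a thinner GEMM, so a larger LM-head TP size pays off, while around 256 tokens the problem is already math-bound and TP 1 suffices. nearest_in_buckets is not shown in this diff, so the stand-in below (pick whichever bucket value is closer to the raw estimate) is only an assumed approximation of its behavior, and gpus_per_node=4 is an illustrative value.

def nearest_in_buckets(value: int, buckets: list[int]) -> int:
    # Assumed behavior: snap the raw estimate to the closest listed bucket.
    return min(buckets, key=lambda b: abs(b - value))

gpus_per_node = 4                                   # illustrative
for token_count in (1, 8, 32, 64, 128, 256, 512):
    lm_head_tp_size_raw = 256 // token_count
    lm_head_tp_size = nearest_in_buckets(lm_head_tp_size_raw, [1, gpus_per_node])
    print(token_count, lm_head_tp_size_raw, lm_head_tp_size)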

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 38 additions & 4 deletions
@@ -1903,7 +1903,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
 
     @skip_pre_blackwell
     @pytest.mark.parametrize(
-        "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend",
+        "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,enable_lm_head_tp_in_adp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend",
         [
             # Use a larger batch_size to speed up the tests
             pytest.param(8,
@@ -1912,6 +1912,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
                          3,
                          False,
                          False,
+                         False,
                          True,
                          True,
                          32,
@@ -1923,6 +1924,31 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
                          3,
                          False,
                          False,
+                         False,
+                         True,
+                         True,
+                         32,
+                         "TRTLLM",
+                         marks=pytest.mark.skip_less_mpi_world_size(8)),
+            pytest.param(8,
+                         1,
+                         4,
+                         3,
+                         False,
+                         True,
+                         True,
+                         True,
+                         True,
+                         32,
+                         "CUTLASS",
+                         marks=pytest.mark.skip_less_mpi_world_size(8)),
+            pytest.param(8,
+                         1,
+                         4,
+                         3,
+                         False,
+                         True,
+                         True,
                          True,
                          True,
                          32,
@@ -1934,6 +1960,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
                          0,
                          True,
                          True,
+                         False,
                          True,
                          True,
                          32,
@@ -1945,6 +1972,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
                          0,
                          True,
                          True,
+                         False,
                          True,
                          True,
                          32,
@@ -1956,6 +1984,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
                          0,
                          True,
                          True,
+                         False,
                          True,
                          True,
                          16,
@@ -1967,6 +1996,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
                          1,
                          True,
                          True,
+                         False,
                          True,
                          True,
                          32,
@@ -1978,19 +2008,22 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
                          1,
                          True,
                          True,
+                         False,
                          True,
                          True,
                          8,
                          "CUTLASS",
                          marks=pytest.mark.skip_less_mpi_world_size(8)),
         ],
         ids=[
-            "latency", "latency_trtllmgen", "throughput", "throughput_tp8",
+            "latency", "latency_trtllmgen", "latency_adp_lmtp",
+            "latency_trtllmgen_adp_lmtp", "throughput", "throughput_tp8",
             "throughput_tp4", "throughput_mtp", "throughput_bs8_mtp"
         ])
     def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
-                              attention_dp, cuda_graph, overlap_scheduler,
-                              max_batch_size, moe_backend):
+                              attention_dp, enable_lm_head_tp_in_adp,
+                              cuda_graph, overlap_scheduler, max_batch_size,
+                              moe_backend):
         if moe_backend == "TRTLLM" and (get_sm_version() == 120
                                         or get_sm_version() == 121):
             pytest.skip(
@@ -2016,6 +2049,7 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
                     kv_cache_config=kv_cache_config,
                     **pytorch_config,
                     enable_attention_dp=attention_dp,
+                    enable_lm_head_tp_in_adp=enable_lm_head_tp_in_adp,
                     speculative_config=mtp_config) as llm:
 
             assert llm.args.moe_config.backend == moe_backend
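Outside the parametrization, the new flag simply flows into the LLM constructor next to enable_attention_dp, as the last hunk shows. Below is a hedged sketch of what the new latency_adp_lmtp configuration amounts to at the API level; the model path and the omission of the test's other kwargs (kv_cache_config, speculative_config, MoE backend selection) are illustrative rather than copied from the test.

from tensorrt_llm import LLM

# Illustrative: roughly what the latency_adp_lmtp test case configures.
llm = LLM(
    model="deepseek-ai/DeepSeek-R1",        # placeholder model path
    tensor_parallel_size=8,
    moe_expert_parallel_size=4,
    enable_attention_dp=True,               # attention data parallelism (ADP)
    enable_lm_head_tp_in_adp=True,          # the flag exercised by the new *_adp_lmtp ids
)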

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 2 additions & 0 deletions
@@ -483,6 +483,8 @@ accuracy/test_llm_api_pytorch.py::TestNemotronNas::test_auto_dtype_tp8
 accuracy/test_llm_api_pytorch.py::TestQwen2_7BInstruct::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen_adp_lmtp]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4]

tests/integration/test_lists/qa/llm_function_core_sanity.txt

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_fp8
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen_adp_lmtp]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput]

tests/integration/test_lists/test-db/l0_gb200_multi_nodes.yml

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ l0_gb200_multi_nodes:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen] TIMEOUT (180)
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp] TIMEOUT (180)
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen_adp_lmtp] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cutlass] TIMEOUT (90)
   - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] TIMEOUT (90)
   - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_attention_dp] TIMEOUT (90)
