
Commit a4dcc6a

[TRTLLM-10171][fix] Correct attention handling in ModelConfig and KVCacheManager (#10330)
Signed-off-by: Jaedeok Kim <[email protected]>
1 parent 6ba04eb commit a4dcc6a

4 files changed: +102 −6 lines

tensorrt_llm/_torch/model_config.py

Lines changed: 9 additions & 5 deletions

@@ -495,10 +495,15 @@ def get_bindings_model_config(self,
         # TODO smor- currently assuming no rnn layers, no MOE
         from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp

+        # Attention DP should not shard attention heads; use attn_tp_size=1 in that case
+        # so downstream KV calculations see the full (non-partitioned) head count.
+        attn_tp_size = self.mapping.attn_tp_size if not self.mapping.enable_attention_dp else 1
+        attn_cp_size = self.mapping.attn_cp_size
+
         num_heads = self.pretrained_config.num_attention_heads // (
-            self.mapping.tp_size * self.mapping.cp_size)
+            attn_tp_size * attn_cp_size)

-        hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
+        hidden_size = self.pretrained_config.hidden_size // attn_tp_size

         model_config_cpp = ModelConfigCpp(
             vocab_size=self.pretrained_config.vocab_size,
@@ -523,13 +528,12 @@ def get_bindings_model_config(self,
         if isinstance(num_key_value_heads, (list, tuple)):
             # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
             num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                kv_heads // (attn_tp_size * attn_cp_size)
                 for kv_heads in num_key_value_heads
             ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
-            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
-                                                   self.mapping.cp_size)
+            num_kv_heads = num_key_value_heads // (attn_tp_size * attn_cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)

         mlp_hidden_size = None

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 2 additions & 1 deletion

@@ -1138,7 +1138,8 @@ def calculate_max_num_blocks_from_cpp(
             tensor_parallelism=self.mapping.tp_size,
             pipeline_parallelism=self.mapping.pp_size,
             rank=self.mapping.rank,
-            gpus_per_node=self.mapping.gpus_per_node)
+            gpus_per_node=self.mapping.gpus_per_node,
+            enable_attention_dp=self.mapping.enable_attention_dp)

         window_size_to_layers = self._get_window_size_to_layers()
         logger.debug(f"window_size_to_layers: {window_size_to_layers}")
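
The enable_attention_dp flag matters here because the KV-cache block budget depends on how many KV heads each rank actually stores. The following is a rough, illustrative re-derivation of that budget in Python; it is not the C++ implementation behind calculate_max_num_blocks_from_cpp, and the function name, parameters, and byte factors are assumptions made for the sketch:

def estimate_max_blocks(free_mem_bytes, num_layers, num_kv_heads, head_dim,
                        tokens_per_block, tp_size, enable_attention_dp,
                        kv_factor=2, dtype_bytes=2):
    # Without attention DP, KV heads are split across tensor-parallel ranks;
    # with attention DP, every rank keeps the full KV head count, so each
    # cache block is larger and fewer blocks fit in the same free memory.
    kv_heads_per_rank = num_kv_heads if enable_attention_dp else num_kv_heads // tp_size
    bytes_per_block = (kv_factor * num_layers * kv_heads_per_rank * head_dim *
                       tokens_per_block * dtype_bytes)
    return free_mem_bytes // bytes_per_block


free = 8 << 30  # 8 GiB of free KV-cache memory, picked for illustration
print(estimate_max_blocks(free, 32, 8, 128, 32, tp_size=2, enable_attention_dp=False))  # 4096
print(estimate_max_blocks(free, 32, 8, 128, 32, tp_size=2, enable_attention_dp=True))   # 2048

Forwarding the flag, as this commit does, lets the block estimate account for the larger per-rank KV footprint instead of silently assuming TP-sharded KV heads.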

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ l0_a10:
   - unittest/_torch/sampler/test_torch_sampler.py
   - unittest/_torch/sampler/test_torch_multi_arange.py
   - unittest/utils/test_util.py
+  - unittest/_torch/test_model_config.py
   - unittest/_torch/modeling/test_modeling_mistral.py
   - unittest/_torch/modeling/test_modeling_pixtral.py
   - unittest/_torch/sampler/test_trtllm_sampler.py
tests/unittest/_torch/test_model_config.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+import types
+
+import pytest
+import torch
+
+from tensorrt_llm._torch.model_config import ModelConfig
+from tensorrt_llm.mapping import Mapping
+
+
+def make_pretrained_config(
+    *,
+    num_attention_heads: int = 16,
+    num_key_value_heads=8,
+    head_dim: int | None = None,
+    num_hidden_layers: int = 1,
+    vocab_size: int = 3000,
+):
+    # A minimal config object that provides the attributes used by
+    # ModelConfig.get_bindings_model_config().
+    hidden_size = head_dim * num_attention_heads
+    intermediate_size = hidden_size * 4
+
+    return types.SimpleNamespace(
+        architectures=["DummyArchitecture"],
+        num_attention_heads=num_attention_heads,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        num_key_value_heads=num_key_value_heads,
+        head_dim=head_dim,
+        num_hidden_layers=num_hidden_layers,
+        vocab_size=vocab_size,
+        torch_dtype=torch.float16,
+    )
+
+
+@pytest.mark.parametrize(
+    "num_key_value_heads",
+    [
+        pytest.param(8, id="kv_heads_scalar"),
+        pytest.param([8, 20], id="kv_heads_per_layer_varied"),
+    ],
+)
+@pytest.mark.parametrize("enable_attention_dp", [False, True])
+@pytest.mark.parametrize(
+    "mapping_kwargs",
+    [
+        # Same tp/cp sizes, but different ways of setting attention TP:
+        # - No explicit attn_tp_size: Mapping infers it.
+        # - Explicit attn_tp_size: Mapping uses the provided value.
+        dict(world_size=8, tp_size=4, cp_size=2),
+        dict(world_size=4, tp_size=2, cp_size=2, attn_tp_size=4),
+    ],
+)
+def test_get_bindings_model_config_attention_dp_attn_tp_override(
+    enable_attention_dp, mapping_kwargs, num_key_value_heads
+):
+    mapping = Mapping(enable_attention_dp=enable_attention_dp, **mapping_kwargs)
+    cfg = make_pretrained_config(
+        # Keep values consistent:
+        # hidden_size = num_attention_heads * head_dim.
+        num_attention_heads=16,
+        head_dim=4,
+        num_key_value_heads=num_key_value_heads,
+        num_hidden_layers=2,
+    )
+    model_config = ModelConfig(pretrained_config=cfg, mapping=mapping)
+
+    tokens_per_block = 32
+    bindings_cfg = model_config.get_bindings_model_config(tokens_per_block=tokens_per_block)
+
+    # bindings num_heads is sharded by attn_tp_size and attn_cp_size.
+    attn_tp_size = mapping.attn_tp_size if not mapping.enable_attention_dp else 1
+    attn_cp_size = mapping.attn_cp_size
+    assert bindings_cfg.num_heads == cfg.num_attention_heads // (attn_tp_size * attn_cp_size)
+    # bindings hidden_size is sharded by attn_tp_size.
+    assert bindings_cfg.hidden_size == cfg.hidden_size // attn_tp_size
+    if isinstance(cfg.num_key_value_heads, (list, tuple)):
+        expected_num_kv_heads_per_layer = [
+            kv // (attn_tp_size * attn_cp_size) for kv in cfg.num_key_value_heads
+        ]
+        assert list(bindings_cfg.num_kv_heads_per_layer) == expected_num_kv_heads_per_layer
+        assert bindings_cfg.num_kv_heads(0) == expected_num_kv_heads_per_layer[0]
+    else:
+        assert bindings_cfg.num_kv_heads(0) == cfg.num_key_value_heads // (
+            attn_tp_size * attn_cp_size
+        )
+
+    # tp_size-dependent value (uses mapping.tp_size, not attn_tp_size).
+    assert bindings_cfg.mlp_hidden_size == (cfg.intermediate_size // mapping.tp_size)
+    assert bindings_cfg.tokens_per_block == tokens_per_block
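
To read the test's expectations concretely, here is the arithmetic for one parameter combination, worked by hand. This is a sketch: it assumes, for illustration, that Mapping(world_size=8, tp_size=4, cp_size=2) resolves attn_tp_size=4 and attn_cp_size=2 when they are not passed explicitly; the test itself reads these values off the constructed Mapping rather than hard-coding them.

# Assumed resolved values for the world_size=8, tp_size=4, cp_size=2 case.
num_attention_heads, head_dim, num_key_value_heads = 16, 4, 8
hidden_size = num_attention_heads * head_dim  # 64

# enable_attention_dp=False: attention heads and hidden size are sharded.
attn_tp_size, attn_cp_size = 4, 2
assert num_attention_heads // (attn_tp_size * attn_cp_size) == 2  # bindings num_heads
assert hidden_size // attn_tp_size == 16                          # bindings hidden_size
assert num_key_value_heads // (attn_tp_size * attn_cp_size) == 1  # bindings num_kv_heads

# enable_attention_dp=True: attn_tp_size is overridden to 1, heads stay whole per rank.
attn_tp_size = 1
assert num_attention_heads // (attn_tp_size * attn_cp_size) == 8
assert hidden_size // attn_tp_size == 64
assert num_key_value_heads // (attn_tp_size * attn_cp_size) == 4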
