
Commit 4e69221

Add unittest
Signed-off-by: Jaedeok Kim <[email protected]>
1 parent 2786519 commit 4e69221


2 files changed: +91 -0 lines changed


tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ l0_a10:
   - unittest/_torch/sampler/test_torch_sampler.py
   - unittest/_torch/sampler/test_torch_multi_arange.py
   - unittest/utils/test_util.py
+  - unittest/_torch/test_model_config.py
   - unittest/_torch/modeling/test_modeling_mistral.py
   - unittest/_torch/modeling/test_modeling_pixtral.py
   - unittest/_torch/sampler/test_trtllm_sampler.py
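For a quick local check of the newly registered entry, the test can also be invoked directly through pytest's Python API. This is a minimal sketch; the file path is an assumption inferred from the l0_a10.yml entry above (unittest/ paths in the test list resolve under the repository's tests/ directory), not something stated by this commit.

# Minimal sketch: run only the new model-config unit test via pytest.
# The path below is an assumption based on the l0_a10.yml entry above.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["tests/unittest/_torch/test_model_config.py", "-q"]))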
tests/unittest/_torch/test_model_config.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import types

import pytest
import torch

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm.mapping import Mapping


def make_pretrained_config(
    *,
    num_attention_heads: int = 16,
    num_key_value_heads=8,
    head_dim: int | None = None,
    num_hidden_layers: int = 1,
    vocab_size: int = 3000,
):
    # A minimal config object that provides the attributes used by
    # ModelConfig.get_bindings_model_config().
    hidden_size = head_dim * num_attention_heads
    intermediate_size = hidden_size * 4

    return types.SimpleNamespace(
        architectures=["DummyArchitecture"],
        num_attention_heads=num_attention_heads,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_key_value_heads=num_key_value_heads,
        head_dim=head_dim,
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        torch_dtype=torch.float16,
    )


@pytest.mark.parametrize(
    "num_key_value_heads",
    [
        pytest.param(8, id="kv_heads_scalar"),
        pytest.param([8, 20], id="kv_heads_per_layer_varied"),
    ],
)
@pytest.mark.parametrize("enable_attention_dp", [False, True])
@pytest.mark.parametrize(
    "mapping_kwargs",
    [
        # Same tp/cp sizes, but different ways of setting attention TP:
        # - No explicit attn_tp_size: Mapping infers it.
        # - Explicit attn_tp_size: Mapping uses the provided value.
        dict(world_size=8, tp_size=4, cp_size=2),
        dict(world_size=4, tp_size=2, cp_size=2, attn_tp_size=4),
    ],
)
def test_get_bindings_model_config_attention_dp_attn_tp_override(
    enable_attention_dp, mapping_kwargs, num_key_value_heads
):
    mapping = Mapping(enable_attention_dp=enable_attention_dp, **mapping_kwargs)
    cfg = make_pretrained_config(
        # Keep values consistent:
        # hidden_size = num_attention_heads * head_dim.
        num_attention_heads=16,
        head_dim=4,
        num_key_value_heads=num_key_value_heads,
        num_hidden_layers=2,
    )
    model_config = ModelConfig(pretrained_config=cfg, mapping=mapping)

    tokens_per_block = 32
    bindings_cfg = model_config.get_bindings_model_config(tokens_per_block=tokens_per_block)

    # bindings num_heads is sharded by attn_tp_size and attn_cp_size.
    assert bindings_cfg.num_heads == cfg.num_attention_heads // (
        mapping.attn_tp_size * mapping.attn_cp_size
    )
    # bindings hidden_size is sharded by tp_size (not attention TP size).
    assert bindings_cfg.hidden_size == cfg.hidden_size // mapping.tp_size
    if isinstance(cfg.num_key_value_heads, (list, tuple)):
        expected_num_kv_heads_per_layer = [
            kv // (mapping.attn_tp_size * mapping.attn_cp_size) for kv in cfg.num_key_value_heads
        ]
        assert list(bindings_cfg.num_kv_heads_per_layer) == expected_num_kv_heads_per_layer
        assert bindings_cfg.num_kv_heads(0) == expected_num_kv_heads_per_layer[0]
    else:
        assert bindings_cfg.num_kv_heads(0) == cfg.num_key_value_heads // (
            mapping.attn_tp_size * mapping.attn_cp_size
        )

    # tp_size-dependent value (uses mapping.tp_size, not attn_tp_size).
    assert bindings_cfg.mlp_hidden_size == (cfg.intermediate_size // mapping.tp_size)
    assert bindings_cfg.tokens_per_block == tokens_per_block
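The assertions above reduce to integer division of the pretrained-config sizes by the mapping's parallel group sizes. Below is a small self-contained sketch of that arithmetic with illustrative values; attn_tp_size, attn_cp_size, and tp_size are assumed constants here, not values read from Mapping.

# Worked sketch of the sharding arithmetic checked by the test above.
# attn_tp_size, attn_cp_size, and tp_size are illustrative assumptions.
num_attention_heads = 16
head_dim = 4
hidden_size = num_attention_heads * head_dim      # 64, as built by make_pretrained_config
intermediate_size = hidden_size * 4               # 256

tp_size = 4
attn_tp_size = 4
attn_cp_size = 2
attn_shard = attn_tp_size * attn_cp_size          # 8

expected_num_heads = num_attention_heads // attn_shard           # 16 // 8 = 2
expected_hidden_size = hidden_size // tp_size                    # 64 // 4 = 16
expected_mlp_hidden_size = intermediate_size // tp_size          # 256 // 4 = 64
expected_kv_heads_per_layer = [kv // attn_shard for kv in [8, 20]]   # [1, 2]

print(expected_num_heads, expected_hidden_size, expected_mlp_hidden_size,
      expected_kv_heads_per_layer)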

0 commit comments
