14 changes: 9 additions & 5 deletions tensorrt_llm/_torch/model_config.py
@@ -495,10 +495,15 @@ def get_bindings_model_config(self,
         # TODO smor- currently assuming no rnn layers, no MOE
         from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp

+        # Attention DP should not shard attention heads; use attn_tp_size=1 in that case
+        # so downstream KV calculations see the full (non-partitioned) head count.
+        attn_tp_size = self.mapping.attn_tp_size if not self.mapping.enable_attention_dp else 1
+        attn_cp_size = self.mapping.attn_cp_size
+
         num_heads = self.pretrained_config.num_attention_heads // (
-            self.mapping.tp_size * self.mapping.cp_size)
+            attn_tp_size * attn_cp_size)

-        hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
+        hidden_size = self.pretrained_config.hidden_size // attn_tp_size

         model_config_cpp = ModelConfigCpp(
             vocab_size=self.pretrained_config.vocab_size,
@@ -523,13 +528,12 @@
         if isinstance(num_key_value_heads, (list, tuple)):
             # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
             num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                kv_heads // (attn_tp_size * attn_cp_size)
                 for kv_heads in num_key_value_heads
             ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
-            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
-                                                   self.mapping.cp_size)
+            num_kv_heads = num_key_value_heads // (attn_tp_size * attn_cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)

         mlp_hidden_size = None
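For orientation, here is a minimal sketch (not part of the diff) of the head-sharding arithmetic this file now uses. The formulas mirror the hunks above; the concrete sizes are made-up example values, and the effective attention TP degree is forced to 1 when attention DP is enabled, so heads are replicated per DP rank rather than sharded.

# Illustrative numbers only; the formulas follow the diff above.
num_attention_heads = 16
hidden_size_full = 64
attn_tp_size, attn_cp_size = 4, 2

for enable_attention_dp in (False, True):
    # Attention DP replicates heads across DP ranks instead of sharding them.
    eff_attn_tp = 1 if enable_attention_dp else attn_tp_size
    num_heads = num_attention_heads // (eff_attn_tp * attn_cp_size)
    hidden_size = hidden_size_full // eff_attn_tp
    print(f"attention_dp={enable_attention_dp}: num_heads={num_heads}, hidden_size={hidden_size}")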
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -1138,7 +1138,8 @@ def calculate_max_num_blocks_from_cpp(
             tensor_parallelism=self.mapping.tp_size,
             pipeline_parallelism=self.mapping.pp_size,
             rank=self.mapping.rank,
-            gpus_per_node=self.mapping.gpus_per_node)
+            gpus_per_node=self.mapping.gpus_per_node,
+            enable_attention_dp=self.mapping.enable_attention_dp)

         window_size_to_layers = self._get_window_size_to_layers()
         logger.debug(f"window_size_to_layers: {window_size_to_layers}")
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
@@ -17,6 +17,7 @@ l0_a10:
   - unittest/_torch/sampler/test_torch_sampler.py
   - unittest/_torch/sampler/test_torch_multi_arange.py
   - unittest/utils/test_util.py
+ - unittest/_torch/test_model_config.py
   - unittest/_torch/modeling/test_modeling_mistral.py
   - unittest/_torch/modeling/test_modeling_pixtral.py
   - unittest/_torch/sampler/test_trtllm_sampler.py
90 changes: 90 additions & 0 deletions tests/unittest/_torch/test_model_config.py
@@ -0,0 +1,90 @@
import types

import pytest
import torch

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm.mapping import Mapping


def make_pretrained_config(
    *,
    num_attention_heads: int = 16,
    num_key_value_heads=8,
    head_dim: int | None = None,
    num_hidden_layers: int = 1,
    vocab_size: int = 3000,
):
    # A minimal config object that provides the attributes used by
    # ModelConfig.get_bindings_model_config().
    hidden_size = head_dim * num_attention_heads
    intermediate_size = hidden_size * 4

    return types.SimpleNamespace(
        architectures=["DummyArchitecture"],
        num_attention_heads=num_attention_heads,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_key_value_heads=num_key_value_heads,
        head_dim=head_dim,
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        torch_dtype=torch.float16,
    )


@pytest.mark.parametrize(
    "num_key_value_heads",
    [
        pytest.param(8, id="kv_heads_scalar"),
        pytest.param([8, 20], id="kv_heads_per_layer_varied"),
    ],
)
@pytest.mark.parametrize("enable_attention_dp", [False, True])
@pytest.mark.parametrize(
    "mapping_kwargs",
    [
        # Same tp/cp sizes, but different ways of setting attention TP:
        # - No explicit attn_tp_size: Mapping infers it.
        # - Explicit attn_tp_size: Mapping uses the provided value.
        dict(world_size=8, tp_size=4, cp_size=2),
        dict(world_size=4, tp_size=2, cp_size=2, attn_tp_size=4),
    ],
)
def test_get_bindings_model_config_attention_dp_attn_tp_override(
    enable_attention_dp, mapping_kwargs, num_key_value_heads
):
    mapping = Mapping(enable_attention_dp=enable_attention_dp, **mapping_kwargs)
    cfg = make_pretrained_config(
        # Keep values consistent:
        # hidden_size = num_attention_heads * head_dim.
        num_attention_heads=16,
        head_dim=4,
        num_key_value_heads=num_key_value_heads,
        num_hidden_layers=2,
    )
    model_config = ModelConfig(pretrained_config=cfg, mapping=mapping)

    tokens_per_block = 32
    bindings_cfg = model_config.get_bindings_model_config(tokens_per_block=tokens_per_block)

    # bindings num_heads is sharded by attn_tp_size and attn_cp_size.
    attn_tp_size = mapping.attn_tp_size if not mapping.enable_attention_dp else 1
    attn_cp_size = mapping.attn_cp_size
    assert bindings_cfg.num_heads == cfg.num_attention_heads // (attn_tp_size * attn_cp_size)
    # bindings hidden_size is sharded by attn_tp_size.
    assert bindings_cfg.hidden_size == cfg.hidden_size // attn_tp_size
    if isinstance(cfg.num_key_value_heads, (list, tuple)):
        expected_num_kv_heads_per_layer = [
            kv // (attn_tp_size * attn_cp_size) for kv in cfg.num_key_value_heads
        ]
        assert list(bindings_cfg.num_kv_heads_per_layer) == expected_num_kv_heads_per_layer
        assert bindings_cfg.num_kv_heads(0) == expected_num_kv_heads_per_layer[0]
    else:
        assert bindings_cfg.num_kv_heads(0) == cfg.num_key_value_heads // (
            attn_tp_size * attn_cp_size
        )

    # tp_size-dependent value (uses mapping.tp_size, not attn_tp_size).
    assert bindings_cfg.mlp_hidden_size == (cfg.intermediate_size // mapping.tp_size)
    assert bindings_cfg.tokens_per_block == tokens_per_block
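As a quick standalone companion to the test above (not part of the new file), the sketch below constructs the same two mappings and prints how the attention-parallel sizes resolve with and without attention DP; it only uses the Mapping constructor arguments and attributes already exercised by the test.

from tensorrt_llm.mapping import Mapping

for kwargs in (dict(world_size=8, tp_size=4, cp_size=2),
               dict(world_size=4, tp_size=2, cp_size=2, attn_tp_size=4)):
    for enable_attention_dp in (False, True):
        mapping = Mapping(enable_attention_dp=enable_attention_dp, **kwargs)
        # Mirror of the override in model_config.py: attention DP keeps the
        # full head count by treating the attention TP degree as 1.
        eff_attn_tp = mapping.attn_tp_size if not mapping.enable_attention_dp else 1
        print(kwargs, enable_attention_dp, eff_attn_tp, mapping.attn_cp_size)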