diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index ed61109dc87..c1411eb1bf9 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -495,10 +495,15 @@ def get_bindings_model_config(self,
         # TODO smor- currently assuming no rnn layers, no MOE
         from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 
+        # Attention DP should not shard attention heads; use attn_tp_size=1 in that case
+        # so downstream KV calculations see the full (non-partitioned) head count.
+        attn_tp_size = self.mapping.attn_tp_size if not self.mapping.enable_attention_dp else 1
+        attn_cp_size = self.mapping.attn_cp_size
+
         num_heads = self.pretrained_config.num_attention_heads // (
-            self.mapping.tp_size * self.mapping.cp_size)
+            attn_tp_size * attn_cp_size)
 
-        hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
+        hidden_size = self.pretrained_config.hidden_size // attn_tp_size
 
         model_config_cpp = ModelConfigCpp(
             vocab_size=self.pretrained_config.vocab_size,
@@ -523,13 +528,12 @@ def get_bindings_model_config(self,
         if isinstance(num_key_value_heads, (list, tuple)):
             # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
             num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                kv_heads // (attn_tp_size * attn_cp_size)
                 for kv_heads in num_key_value_heads
             ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
-            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
-                                                   self.mapping.cp_size)
+            num_kv_heads = num_key_value_heads // (attn_tp_size * attn_cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)
 
         mlp_hidden_size = None
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index aa4e1f228bd..619f8525c17 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -1138,7 +1138,8 @@ def calculate_max_num_blocks_from_cpp(
             tensor_parallelism=self.mapping.tp_size,
             pipeline_parallelism=self.mapping.pp_size,
             rank=self.mapping.rank,
-            gpus_per_node=self.mapping.gpus_per_node)
+            gpus_per_node=self.mapping.gpus_per_node,
+            enable_attention_dp=self.mapping.enable_attention_dp)
 
         window_size_to_layers = self._get_window_size_to_layers()
         logger.debug(f"window_size_to_layers: {window_size_to_layers}")
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index 03a7b53b5a6..adb11c9416a 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -17,6 +17,7 @@ l0_a10:
   - unittest/_torch/sampler/test_torch_sampler.py
   - unittest/_torch/sampler/test_torch_multi_arange.py
   - unittest/utils/test_util.py
+  - unittest/_torch/test_model_config.py
   - unittest/_torch/modeling/test_modeling_mistral.py
   - unittest/_torch/modeling/test_modeling_pixtral.py
   - unittest/_torch/sampler/test_trtllm_sampler.py
diff --git a/tests/unittest/_torch/test_model_config.py b/tests/unittest/_torch/test_model_config.py
new file mode 100644
index 00000000000..5fb55cce9ee
--- /dev/null
+++ b/tests/unittest/_torch/test_model_config.py
@@ -0,0 +1,90 @@
+import types
+
+import pytest
+import torch
+
+from tensorrt_llm._torch.model_config import ModelConfig
+from tensorrt_llm.mapping import Mapping
+
+
+def make_pretrained_config(
+    *,
+    num_attention_heads: int = 16,
+    num_key_value_heads=8,
+    head_dim: int | None = None,
+    num_hidden_layers: int = 1,
+    vocab_size: int = 3000,
+):
+    # A minimal config object that provides the attributes used by
+    # ModelConfig.get_bindings_model_config().
+    hidden_size = head_dim * num_attention_heads
+    intermediate_size = hidden_size * 4
+
+    return types.SimpleNamespace(
+        architectures=["DummyArchitecture"],
+        num_attention_heads=num_attention_heads,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        num_key_value_heads=num_key_value_heads,
+        head_dim=head_dim,
+        num_hidden_layers=num_hidden_layers,
+        vocab_size=vocab_size,
+        torch_dtype=torch.float16,
+    )
+
+
+@pytest.mark.parametrize(
+    "num_key_value_heads",
+    [
+        pytest.param(8, id="kv_heads_scalar"),
+        pytest.param([8, 20], id="kv_heads_per_layer_varied"),
+    ],
+)
+@pytest.mark.parametrize("enable_attention_dp", [False, True])
+@pytest.mark.parametrize(
+    "mapping_kwargs",
+    [
+        # Both mappings yield the same attention TP/CP split, set up two ways:
+        # - No explicit attn_tp_size: Mapping infers it from tp_size.
+        # - Explicit attn_tp_size: Mapping uses the provided value.
+        dict(world_size=8, tp_size=4, cp_size=2),
+        dict(world_size=4, tp_size=2, cp_size=2, attn_tp_size=4),
+    ],
+)
+def test_get_bindings_model_config_attention_dp_attn_tp_override(
+    enable_attention_dp, mapping_kwargs, num_key_value_heads
+):
+    mapping = Mapping(enable_attention_dp=enable_attention_dp, **mapping_kwargs)
+    cfg = make_pretrained_config(
+        # Keep values consistent:
+        # hidden_size = num_attention_heads * head_dim.
+        num_attention_heads=16,
+        head_dim=4,
+        num_key_value_heads=num_key_value_heads,
+        num_hidden_layers=2,
+    )
+    model_config = ModelConfig(pretrained_config=cfg, mapping=mapping)
+
+    tokens_per_block = 32
+    bindings_cfg = model_config.get_bindings_model_config(tokens_per_block=tokens_per_block)
+
+    # bindings num_heads is sharded by attn_tp_size and attn_cp_size.
+    attn_tp_size = mapping.attn_tp_size if not mapping.enable_attention_dp else 1
+    attn_cp_size = mapping.attn_cp_size
+    assert bindings_cfg.num_heads == cfg.num_attention_heads // (attn_tp_size * attn_cp_size)
+    # bindings hidden_size is sharded by attn_tp_size only.
+    assert bindings_cfg.hidden_size == cfg.hidden_size // attn_tp_size
+    if isinstance(cfg.num_key_value_heads, (list, tuple)):
+        expected_num_kv_heads_per_layer = [
+            kv // (attn_tp_size * attn_cp_size) for kv in cfg.num_key_value_heads
+        ]
+        assert list(bindings_cfg.num_kv_heads_per_layer) == expected_num_kv_heads_per_layer
+        assert bindings_cfg.num_kv_heads(0) == expected_num_kv_heads_per_layer[0]
+    else:
+        assert bindings_cfg.num_kv_heads(0) == cfg.num_key_value_heads // (
+            attn_tp_size * attn_cp_size
+        )
+
+    # mlp_hidden_size depends on mapping.tp_size, not attn_tp_size.
+    assert bindings_cfg.mlp_hidden_size == (cfg.intermediate_size // mapping.tp_size)
+    assert bindings_cfg.tokens_per_block == tokens_per_block
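
Note: the core of this patch is the divisor used for per-rank head counts in
get_bindings_model_config(). The sketch below is illustrative only and not part of
the patch (the helper name shard_attention_heads is made up); it mirrors the new
arithmetic: with attention DP enabled the effective attention TP size is treated as
1, so the bindings ModelConfig keeps the full head count per rank, while plain
TP/CP still divides by attn_tp_size * attn_cp_size.

    def shard_attention_heads(num_attention_heads: int,
                              num_key_value_heads: int,
                              attn_tp_size: int,
                              attn_cp_size: int,
                              enable_attention_dp: bool) -> tuple[int, int]:
        # Attention DP replicates attention instead of sharding it, so the
        # head counts are not divided across TP ranks in that mode.
        if enable_attention_dp:
            attn_tp_size = 1
        divisor = attn_tp_size * attn_cp_size
        return (num_attention_heads // divisor, num_key_value_heads // divisor)

    # Example: 16 query heads, 8 KV heads, attn_tp=4, attn_cp=2.
    assert shard_attention_heads(16, 8, 4, 2, enable_attention_dp=False) == (2, 1)
    assert shard_attention_heads(16, 8, 4, 2, enable_attention_dp=True) == (8, 4)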