
Commit e363d4d

Correct attention handling in ModelConfig and KVCacheManager
Fix the model config binding so KVCacheManager can compute the required cache blocks with accurate head/hidden sizes.

Changes:
- Updated the hidden-size and key-value-head calculations to use the correct attention TP and CP sizes.
- Added the enable_attention_dp parameter to KVCacheManager for improved resource management.

Signed-off-by: Jaedeok Kim <jaedeokk@nvidia.com>
1 parent 6732c76 commit e363d4d

2 files changed (+11, −6 lines)

tensorrt_llm/_torch/model_config.py

Lines changed: 9 additions & 5 deletions
```diff
@@ -495,10 +495,15 @@ def get_bindings_model_config(self,
         # TODO smor- currently assuming no rnn layers, no MOE
         from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp

+        # Attention DP should not shard attention heads; use attn_tp_size=1 in that case
+        # so downstream KV calculations see the full (non-partitioned) head count.
+        attn_tp_size = self.mapping.attn_tp_size if not self.mapping.enable_attention_dp else 1
+        attn_cp_size = self.mapping.attn_cp_size
+
         num_heads = self.pretrained_config.num_attention_heads // (
-            self.mapping.tp_size * self.mapping.cp_size)
+            attn_tp_size * attn_cp_size)

-        hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
+        hidden_size = self.pretrained_config.hidden_size // attn_tp_size

         model_config_cpp = ModelConfigCpp(
             vocab_size=self.pretrained_config.vocab_size,
@@ -523,13 +528,12 @@ def get_bindings_model_config(self,
         if isinstance(num_key_value_heads, (list, tuple)):
             # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
             num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                kv_heads // (attn_tp_size * attn_cp_size)
                 for kv_heads in num_key_value_heads
            ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
-            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
-                                                   self.mapping.cp_size)
+            num_kv_heads = num_key_value_heads // (attn_tp_size * attn_cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)

         mlp_hidden_size = None
```
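
For context, a minimal, self-contained sketch (not the actual TensorRT-LLM code) of the per-rank arithmetic the diff switches to: with attention DP enabled, the effective attention TP size becomes 1, so heads and hidden size are not divided across ranks. The `Mapping` dataclass and the function below are hypothetical stand-ins that only mirror the field names used in the diff.

```python
# Illustrative sketch only -- mirrors the arithmetic in the diff above, not the real API.
from dataclasses import dataclass


@dataclass
class Mapping:
    tp_size: int
    cp_size: int
    attn_tp_size: int
    attn_cp_size: int
    enable_attention_dp: bool


def per_rank_attention_shapes(num_attention_heads: int,
                              num_key_value_heads: int,
                              hidden_size: int,
                              mapping: Mapping) -> dict:
    # With attention DP, attention heads are replicated per rank rather than
    # sharded, so the effective attention TP size is 1.
    attn_tp_size = 1 if mapping.enable_attention_dp else mapping.attn_tp_size
    attn_cp_size = mapping.attn_cp_size
    return {
        "num_heads": num_attention_heads // (attn_tp_size * attn_cp_size),
        "num_kv_heads": num_key_value_heads // (attn_tp_size * attn_cp_size),
        "hidden_size": hidden_size // attn_tp_size,
    }


# Example: 8-way TP with attention DP enabled -- each rank keeps the full head count.
print(per_rank_attention_shapes(32, 8, 4096,
                                Mapping(tp_size=8, cp_size=1, attn_tp_size=8,
                                        attn_cp_size=1, enable_attention_dp=True)))
# {'num_heads': 32, 'num_kv_heads': 8, 'hidden_size': 4096}
```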

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1133,7 +1133,8 @@ def calculate_max_num_blocks_from_cpp(
             tensor_parallelism=self.mapping.tp_size,
             pipeline_parallelism=self.mapping.pp_size,
             rank=self.mapping.rank,
-            gpus_per_node=self.mapping.gpus_per_node)
+            gpus_per_node=self.mapping.gpus_per_node,
+            enable_attention_dp=self.mapping.enable_attention_dp)

         window_size_to_layers = self._get_window_size_to_layers()
         logger.debug(f"window_size_to_layers: {window_size_to_layers}")
```
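
To illustrate why `enable_attention_dp` has to reach the block-count calculation: the per-rank KV-cache block size scales with the per-rank KV head count, which is the full head count under attention DP rather than the TP-sharded one. The helper below assumes a simplified layout (2 × layers × KV heads × head_dim × tokens per block × dtype bytes); it is a sketch, not KVCacheManager's actual formula.

```python
# Illustrative block-count estimate under an assumed, simplified KV-cache layout.
def estimate_max_num_blocks(free_gpu_mem_bytes: int,
                            num_layers: int,
                            num_kv_heads: int,
                            head_dim: int,
                            tokens_per_block: int,
                            kv_dtype_bytes: int,
                            tp_size: int,
                            enable_attention_dp: bool) -> int:
    # Attention DP keeps every KV head on each rank; plain TP shards them.
    kv_heads_per_rank = num_kv_heads if enable_attention_dp else max(1, num_kv_heads // tp_size)
    # 2 accounts for both the K and the V cache.
    bytes_per_block = 2 * num_layers * kv_heads_per_rank * head_dim * tokens_per_block * kv_dtype_bytes
    return free_gpu_mem_bytes // bytes_per_block


# With attention DP, each block is ~tp_size times larger than in the sharded case,
# so fewer blocks fit per rank -- which is why the flag must reach this calculation.
print(estimate_max_num_blocks(free_gpu_mem_bytes=40 * 2**30, num_layers=32,
                              num_kv_heads=8, head_dim=128, tokens_per_block=64,
                              kv_dtype_bytes=2, tp_size=8, enable_attention_dp=True))
```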
