 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (PeftCacheConfig, SamplerType,
-                                          SpeculativeConfig, SparseAttentionConfig)
+                                          SparseAttentionConfig,
+                                          SpeculativeConfig)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import (LoraConfig,
                                       get_default_trtllm_modules_to_hf_modules)
 GB = 1 << 30


-def get_kv_cache_manager_cls(model_config: ModelConfig,
-                             executor_config: ExecutorConfig):
+def get_kv_cache_manager_cls(model_config: ModelConfig):
     config = model_config.pretrained_config
-    sparse_attn_config = executor_config.sparse_attention_config
+    sparse_attn_config = model_config.sparse_attention_config
     if is_mla(config):
         return KVCacheManager
     elif is_nemotron_hybrid(config):
@@ -93,46 +93,7 @@ def __init__(
         self._max_seq_len = max_seq_len
         self._max_batch_size = max_batch_size
         self._kv_cache_manager_cls = get_kv_cache_manager_cls(
-            model_engine.model.model_config, executor_config)
-
-    @staticmethod
-    def _get_cache_size_per_token(model_config: ModelConfig,
-                                  mapping: Mapping) -> int:
-        mem_per_token = 2
-        quant_config = model_config.quant_config
-        if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
-        ):
-            mem_per_token = 1
-
-        config = model_config.pretrained_config
-
-        num_key_value_heads = getattr(config, 'num_key_value_heads',
-                                      config.num_attention_heads)
-        if isinstance(num_key_value_heads, Iterable):
-            num_key_value_heads = sum(num_key_value_heads) / len(
-                num_key_value_heads)
-
-        mla = is_mla(config)
-        tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size
-
-        kv_factor = 2
-        if mla:
-            # MLA has kv_lora_rank and qk_rope_head_dim
-            head_dim = config.kv_lora_rank + config.qk_rope_head_dim
-            kv_factor = 1
-        else:
-            _head_dim = getattr(config, 'head_dim', None)
-            if not isinstance(_head_dim, int):
-                _head_dim = config.hidden_size // config.num_attention_heads
-            head_dim = _head_dim * num_key_value_heads // tp_size
-
-        # provide at least 1 layer to prevent division by zero cache size
-        num_attention_layers = max(
-            len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
-        mem_per_token *= num_attention_layers * head_dim
-        # K and V
-        mem_per_token *= kv_factor
-        return mem_per_token
+            model_engine.model.model_config)

     def _get_free_gpu_memory_fraction(self) -> float:
         fraction = self._kv_cache_config.free_gpu_memory_fraction
@@ -144,11 +105,11 @@ def _get_kv_size_per_token(self):
         model_config = self._model_engine.model.model_config
         mapping = self._mapping
         kv_size_per_token = self._kv_cache_manager_cls.get_cache_size_per_token(
-            model_config, self._executor_config, mapping)
+            model_config, self._tokens_per_block, mapping)
         if self._draft_model_engine is not None:
             draft_model_config = self._draft_model_engine.model.model_config
             kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
-                draft_model_config, self._executor_config, mapping)
+                draft_model_config, self._tokens_per_block, mapping)
         return kv_size_per_token

     def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
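
The deleted `_get_cache_size_per_token` staticmethod is what encoded the per-token KV-cache footprint; the call sites above now delegate to a `get_cache_size_per_token` classmethod on the selected KV-cache manager, which additionally receives `self._tokens_per_block`. The sketch below only reproduces the arithmetic of the removed helper for illustration; `_AttnConfig` and `estimate_kv_bytes_per_token` are made-up names, not tensorrt_llm APIs, and the simplified config fields are assumptions.

```python
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class _AttnConfig:
    """Simplified stand-in for a pretrained config (illustrative only)."""
    num_attention_heads: int
    hidden_size: int
    num_key_value_heads: Union[int, tuple, None] = None
    head_dim: Optional[int] = None
    # MLA-only fields
    kv_lora_rank: Optional[int] = None
    qk_rope_head_dim: Optional[int] = None


def estimate_kv_bytes_per_token(config: _AttnConfig,
                                num_attention_layers: int,
                                tp_size: int = 1,
                                use_mla: bool = False,
                                fp8_kv_cache: bool = False) -> int:
    """Rough per-token KV-cache footprint, mirroring the removed helper."""
    bytes_per_elem = 1 if fp8_kv_cache else 2  # FP8 vs. FP16/BF16 cache

    num_kv_heads = config.num_key_value_heads or config.num_attention_heads
    if isinstance(num_kv_heads, Iterable):
        # Per-layer KV-head counts are averaged, as in the removed code.
        num_kv_heads = sum(num_kv_heads) / len(num_kv_heads)

    if use_mla:
        # MLA caches one compressed latent: kv_lora_rank + qk_rope_head_dim.
        head_dim = config.kv_lora_rank + config.qk_rope_head_dim
        kv_factor = 1
    else:
        per_head = config.head_dim or config.hidden_size // config.num_attention_heads
        head_dim = per_head * num_kv_heads // tp_size
        kv_factor = 2  # separate K and V caches

    # Keep at least one attention layer so the cache size never ends up zero.
    layers = max(num_attention_layers, 1)
    return int(bytes_per_elem * layers * head_dim * kv_factor)


# Example: a Llama-3-8B-like shape, TP=1, BF16 KV cache.
print(estimate_kv_bytes_per_token(
    _AttnConfig(num_attention_heads=32, hidden_size=4096, num_key_value_heads=8),
    num_attention_layers=32))  # 2 B * 32 layers * (128 * 8) dims * 2 (K and V) = 131072
```

Per the updated call sites, the real classmethod is invoked as `kv_cache_manager_cls.get_cache_size_per_token(model_config, tokens_per_block, mapping)`; how each manager uses `tokens_per_block` is not shown in this diff.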