|
11 | 11 | import tensorrt_llm.bindings |
12 | 12 | from tensorrt_llm._utils import mpi_disabled |
13 | 13 | from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE |
| 14 | +from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PybindMirror |
14 | 15 | from tensorrt_llm.lora_helper import LoraConfig |
15 | 16 | from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig |
16 | 17 | from tensorrt_llm.runtime import ModelConfig as ModelConfigPython |
|
31 | 32 |
|
32 | 33 | BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager |
33 | 34 | KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager |
34 | | -KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig |
35 | 35 | CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType |
36 | 36 | ModelConfigCpp = tensorrt_llm.bindings.ModelConfig |
37 | 37 | DataType = tensorrt_llm.bindings.DataType |
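
The import swap above is the core of this change: the manager now takes the Python-side `KvCacheConfig` from `tensorrt_llm.llmapi.llm_args` in place of the deleted pybind alias `KvCacheConfigCpp`. A minimal sketch of what a caller builds under the new contract — the fields shown (`free_gpu_memory_fraction`, `enable_block_reuse`) follow the public llmapi config and are illustrative, not exhaustive:

```python
from tensorrt_llm.llmapi.llm_args import KvCacheConfig

# Python-side config replaces tensorrt_llm.bindings.executor.KvCacheConfig.
kv_cache_config = KvCacheConfig(
    free_gpu_memory_fraction=0.9,  # fraction of free GPU memory given to the KV cache
    enable_block_reuse=True,       # reuse cached blocks across requests
)
```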
@@ -145,7 +145,7 @@ class KVCacheManager(BaseResourceManager): |
145 | 145 |
|
146 | 146 | def __init__( |
147 | 147 | self, |
148 | | - kv_cache_config: KvCacheConfigCpp, |
| 148 | + kv_cache_config: KvCacheConfig, |
149 | 149 | kv_cache_type: CacheTypeCpp, |
150 | 150 | *, |
151 | 151 | num_layers: int, |
@@ -268,8 +268,8 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], |
268 | 268 | ) |
269 | 269 | # kv cache config check |
270 | 270 | assert isinstance( |
271 | | - kv_cache_config, KvCacheConfigCpp |
272 | | - ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfigCpp" |
| 271 | + kv_cache_config, KvCacheConfig |
| 272 | + ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig" |
273 | 273 | blocks_per_window = self.calculate_max_num_blocks_from_cpp( |
274 | 274 | kv_cache_config=kv_cache_config, |
275 | 275 | model_config=model_config, |
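
The updated guard pins down the new contract: only the Python-side config is accepted, so call sites still passing the pybind `executor.KvCacheConfig` fail fast. A hypothetical illustration of the distinction the `isinstance` check relies on (both configs are assumed default-constructible here):

```python
import tensorrt_llm.bindings
from tensorrt_llm.llmapi.llm_args import KvCacheConfig

cpp_cfg = tensorrt_llm.bindings.executor.KvCacheConfig()  # old pybind config
py_cfg = KvCacheConfig()                                  # new Python-side config

assert not isinstance(cpp_cfg, KvCacheConfig)  # would trip the manager's assert
assert isinstance(py_cfg, KvCacheConfig)       # accepted
```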
@@ -370,28 +370,6 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], |
370 | 370 | def shutdown(self): |
371 | 371 | self.impl.release_pools() |
372 | 372 |
|
373 | | - @classmethod |
374 | | - def from_model_config(cls, |
375 | | - model_config: ModelConfigCpp, |
376 | | - kv_cache_config: KvCacheConfigCpp, |
377 | | - mapping: Mapping, |
378 | | - kv_cache_type: CacheTypeCpp = CacheTypeCpp.SELF, |
379 | | - dtype: DataType = DataType.HALF) -> "KVCacheManager": |
380 | | - return cls( |
381 | | - kv_cache_config, |
382 | | - kv_cache_type, |
383 | | - num_layers=model_config.num_attention_layers(mapping.pp_size), |
384 | | - # NOTE: this preserves existing behavior in KV cache manager. |
385 | | - # But we should change this to pass a list at some point. |
386 | | - # We're assuming the KV cache is homogeneous here. |
387 | | - num_kv_heads=model_config.num_kv_heads(0), |
388 | | - head_dim=model_config.size_per_head, |
389 | | - tokens_per_block=model_config.tokens_per_block, |
390 | | - max_seq_len=model_config.max_seq_len, |
391 | | - max_batch_size=model_config.max_batch_size, |
392 | | - mapping=mapping, |
393 | | - dtype=dtype) |
394 | | - |
395 | 373 | def get_max_resource_count(self) -> int: |
396 | 374 | return self.impl.max_num_blocks |
397 | 375 |
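
With `from_model_config` removed, call sites construct `KVCacheManager` directly. A sketch of the equivalent direct call, reconstructed one-for-one from the deleted classmethod body — it carries over the removed code's homogeneous-KV-cache assumption, hence `num_kv_heads(0)`:

```python
# Direct construction replacing the deleted from_model_config classmethod;
# the argument mapping mirrors the removed body, with the old defaults
# (CacheTypeCpp.SELF, DataType.HALF) written out explicitly.
kv_cache_manager = KVCacheManager(
    kv_cache_config,  # Python-side KvCacheConfig
    CacheTypeCpp.SELF,
    num_layers=model_config.num_attention_layers(mapping.pp_size),
    num_kv_heads=model_config.num_kv_heads(0),  # homogeneous-KV assumption
    head_dim=model_config.size_per_head,
    tokens_per_block=model_config.tokens_per_block,
    max_seq_len=model_config.max_seq_len,
    max_batch_size=model_config.max_batch_size,
    mapping=mapping,
    dtype=DataType.HALF,
)
```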
|
@@ -566,7 +544,7 @@ def calculate_scaling_factor_size_bytes( |
566 | 544 | scaling_factor_dtype) |
567 | 545 |
|
568 | 546 | def calculate_max_num_blocks(self, |
569 | | - kv_cache_config: KvCacheConfigCpp, |
| 547 | + kv_cache_config: KvCacheConfig, |
570 | 548 | head_dim: int, |
571 | 549 | tokens_per_block: int, |
572 | 550 | mapping: Mapping, |
@@ -772,7 +750,7 @@ def _get_window_size_to_layers(self) -> dict[int, list[int]]: |
772 | 750 | def adjust_window_sizes_for_vswa( |
773 | 751 | window_size_to_layers: Dict[int, List[int]], |
774 | 752 | max_attention_window_vec: List[int], |
775 | | - kv_cache_config: KvCacheConfigCpp, |
| 753 | + kv_cache_config: KvCacheConfig, |
776 | 754 | model_config: ModelConfigCpp, |
777 | 755 | pool_memory_bytes: int, |
778 | 756 | kv_factor: int, |
@@ -887,7 +865,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: |
887 | 865 |
|
888 | 866 | def calculate_max_num_blocks_from_cpp( |
889 | 867 | self, |
890 | | - kv_cache_config: KvCacheConfigCpp, |
| 868 | + kv_cache_config: KvCacheConfig, |
891 | 869 | model_config: ModelConfigCpp, |
892 | 870 | extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]: |
893 | 871 | """ |
@@ -945,7 +923,7 @@ def calculate_max_num_blocks_from_cpp( |
945 | 923 | self.max_attention_window_vec = max_attention_window_vec |
946 | 924 |
|
947 | 925 | blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks( |
948 | | - config=kv_cache_config, |
| 926 | + config=PybindMirror.maybe_to_pybind(kv_cache_config), |
949 | 927 | # TODO: support cross attention |
950 | 928 | is_cross_attention=is_cross_attention, |
951 | 929 | dtype=self.dtype, |
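
At the C++ boundary the Python config is translated back with `PybindMirror.maybe_to_pybind`, so the binding still receives the pybind type it expects. A rough sketch of the assumed conversion semantics — the real helper lives in `tensorrt_llm.llmapi.llm_args`, and the `_to_pybind` hook shown here is an assumption about the mirror protocol:

```python
def maybe_to_pybind_sketch(config):
    # Assumed behavior: mirror objects convert themselves to their pybind
    # counterpart; anything else is passed through unchanged.
    if isinstance(config, PybindMirror):
        return config._to_pybind()
    return config
```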
|