33"""KV-Cache Utilities."""
44import os
55from collections import deque
6- from collections .abc import Sequence
6+ from collections .abc import Iterable , Sequence
77from dataclasses import dataclass
88from typing import Any , Callable , NamedTuple , Optional
99
1010from vllm .config import VllmConfig
1111from vllm .logger import init_logger
12- from vllm .utils import GiB_bytes , sha256
12+ from vllm .utils import GiB_bytes , cdiv , sha256
1313from vllm .v1 .kv_cache_interface import (FullAttentionSpec , KVCacheConfig ,
1414 KVCacheGroupSpec , KVCacheSpec ,
1515 KVCacheTensor , SlidingWindowSpec )
@@ -468,6 +468,15 @@ def hash_request_tokens(hash_function: Any, block_size: int,
468468 return ret
469469
470470
def max_memory_usage_bytes(vllm_config: VllmConfig,
                           kv_cache_specs: Iterable[KVCacheSpec]) -> int:
    """Return the total maximum KV-cache memory usage in bytes.

    Each spec reports its own worst-case footprint via
    ``spec.max_memory_usage_bytes(vllm_config)``; the result is the sum
    over all given specs (0 for an empty iterable).
    """
    total = 0
    for spec in kv_cache_specs:
        total += spec.max_memory_usage_bytes(vllm_config)
    return total
478+
479+
471480def estimate_max_model_len (vllm_config : VllmConfig ,
472481 kv_cache_spec : dict [str , KVCacheSpec ],
473482 available_memory : int ) -> int :
@@ -489,11 +498,8 @@ def fits_in_memory(model_len: int) -> bool:
489498 # Modify the max_model_len for this calculation
490499 vllm_config .model_config .max_model_len = model_len
491500 # Calculate memory needed for the given model length
492- memory_needed = sum (
493- (layer_spec .max_memory_usage_bytes (vllm_config )
494- for layer_spec in kv_cache_spec .values ()),
495- start = 0 ,
496- )
501+ memory_needed = max_memory_usage_bytes (vllm_config ,
502+ kv_cache_spec .values ())
497503 return memory_needed <= available_memory
498504
499505 # Binary search for the maximum model length
@@ -538,9 +544,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
538544 "initializing the engine." )
539545
540546 max_model_len = vllm_config .model_config .max_model_len
541- needed_memory = 0
542- for layer_spec in kv_cache_spec .values ():
543- needed_memory += layer_spec .max_memory_usage_bytes (vllm_config )
547+ needed_memory = max_memory_usage_bytes (vllm_config , kv_cache_spec .values ())
544548
545549 if needed_memory > available_memory :
546550 # Estimate the maximum model length that can fit in the available memory
@@ -606,6 +610,24 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
606610 return len (layer_keys ) == 1
607611
608612
def get_max_concurrency_for_kv_cache_config(
        vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float:
    """Get the maximum concurrency for the given KV cache configuration.

    Estimates how many maximum-length requests the KV cache can serve at
    once: the total number of blocks divided by the blocks one request
    needs (rounded up via ``cdiv``). The per-request footprint is scaled
    by the layer count of the largest KV cache group.
    """
    groups = kv_cache_config.kv_cache_groups
    # Groups may differ in size; scale by the largest layer count.
    layers_per_group = max(len(g.layer_names) for g in groups)
    per_request_bytes = layers_per_group * max_memory_usage_bytes(
        vllm_config, (g.kv_cache_spec for g in groups))
    # NOTE(review): assumes all groups share one page size, so the first
    # group's spec is representative — confirm against KVCacheConfig.
    block_bytes = groups[0].kv_cache_spec.page_size_bytes * layers_per_group
    blocks_per_request = cdiv(per_request_bytes, block_bytes)
    return kv_cache_config.num_blocks / blocks_per_request
629+
630+
609631def _get_kv_cache_config_uniform_type (vllm_config : VllmConfig ,
610632 kv_cache_spec : dict [str , KVCacheSpec ],
611633 available_memory : int ) -> KVCacheConfig :
@@ -637,14 +659,6 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
637659 "num_gpu_blocks_override=%d" , num_blocks , num_gpu_blocks_override )
638660 num_blocks = num_gpu_blocks_override
639661
640- num_tokens = num_blocks * vllm_config .cache_config .block_size
641- num_tokens_str = f"{ num_tokens :,} "
642- logger .info ("GPU KV cache size: %s tokens" , num_tokens_str )
643- max_model_len_str = f"{ vllm_config .model_config .max_model_len :,} "
644- max_concurrency = num_tokens / vllm_config .model_config .max_model_len
645- logger .info ("Maximum concurrency for %s tokens per request: %.2fx" ,
646- max_model_len_str , max_concurrency )
647-
648662 per_layer_size = page_size * num_blocks
649663 # All layers have the same KV cache spec, so we create one kv cache group
650664 # for all layers.
@@ -659,6 +673,15 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
659673 kv_cache_groups = create_kv_cache_group_specs (kv_cache_spec ,
660674 grouped_layer_names ),
661675 )
676+
677+ num_tokens = num_blocks * vllm_config .cache_config .block_size
678+ num_tokens_str = f"{ num_tokens :,} "
679+ logger .info ("GPU KV cache size: %s tokens" , num_tokens_str )
680+ max_model_len_str = f"{ vllm_config .model_config .max_model_len :,} "
681+ max_concurrency = get_max_concurrency_for_kv_cache_config (
682+ vllm_config , kv_cache_config )
683+ logger .info ("Maximum concurrency for %s tokens per request: %.2fx" ,
684+ max_model_len_str , max_concurrency )
662685 return kv_cache_config
663686
664687
@@ -705,8 +728,8 @@ def get_kv_cache_config(vllm_config: VllmConfig,
705728 Returns:
706729 The generated KVCacheConfigs
707730 """
708- check_enough_kv_cache_memory (vllm_config , kv_cache_spec , available_memory )
709731 unify_hybrid_kv_cache_specs (kv_cache_spec )
732+ check_enough_kv_cache_memory (vllm_config , kv_cache_spec , available_memory )
710733 if is_kv_cache_type_uniform (kv_cache_spec ):
711734 # KV cache of all layers are the same, which is true for
712735 # most models. Allocate the same amount of memory for
0 commit comments