
Commit a8da78e

[Bugfix] Max concurrency estimation and check_enough_kv_cache_memory for models with sliding window layers (vllm-project#19029)
Signed-off-by: Chen Zhang <[email protected]>
1 parent 5d96533 commit a8da78e

File tree

2 files changed (+125 -26 lines)

tests/v1/core/test_kv_cache_utils.py

Lines changed: 83 additions & 7 deletions
@@ -12,13 +12,11 @@
 from vllm.v1.core.kv_cache_manager import KVCacheManager
 # disable yapf here as it formats differently than isort such that both fail
 # yapf: disable
-from vllm.v1.core.kv_cache_utils import (FreeKVCacheBlockQueue, KVCacheBlock,
-                                         PrefixCachingMetrics,
-                                         estimate_max_model_len,
-                                         generate_block_hash_extra_keys,
-                                         hash_block_tokens,
-                                         hash_request_tokens,
-                                         unify_kv_cache_configs)
+from vllm.v1.core.kv_cache_utils import (
+    FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
+    estimate_max_model_len, generate_block_hash_extra_keys,
+    get_max_concurrency_for_kv_cache_config, hash_block_tokens,
+    hash_request_tokens, unify_kv_cache_configs)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor,
                                         SlidingWindowSpec)
@@ -597,6 +595,84 @@ def test_estimate_max_model_len(model_id, max_model_len,
     assert estimated_max_len == want_estimated_max_len


+def test_get_max_concurrency_for_kv_cache_config():
+    # Create a VllmConfig
+    model_id = "Qwen/Qwen1.5-7B"
+    max_model_len = 16384
+    model_config = ModelConfig(
+        model_id,
+        task="generate",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        max_model_len=max_model_len,
+    )
+    scheduler_config = SchedulerConfig(max_num_batched_tokens=1024,
+                                       enable_chunked_prefill=True)
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        scheduler_config=scheduler_config,
+    )
+
+    full_attention_spec = FullAttentionSpec(
+        block_size=16,
+        num_kv_heads=32,
+        head_size=128,
+        dtype=torch.float16,
+        use_mla=False,
+    )
+
+    sliding_window_spec = SlidingWindowSpec(
+        block_size=16,
+        num_kv_heads=32,
+        head_size=128,
+        dtype=torch.float16,
+        use_mla=False,
+        sliding_window=1024,
+    )
+
+    kv_cache_config_full_attention = KVCacheConfig(
+        num_blocks=int(1024 * 1.5),
+        tensors={},
+        kv_cache_groups=[
+            KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
+                             full_attention_spec),
+        ],
+    )
+    max_concurrency_full_attention = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config_full_attention)
+    assert max_concurrency_full_attention == 1.5
+
+    kv_cache_config_sliding_window = KVCacheConfig(
+        num_blocks=129 * 3,
+        tensors={},
+        kv_cache_groups=[
+            KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
+                             sliding_window_spec),
+        ],
+    )
+    max_concurrency_sliding_window = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config_sliding_window)
+    assert max_concurrency_sliding_window == 3
+
+    kv_cache_config_hybrid_model = KVCacheConfig(
+        num_blocks=(1024 + 129) * 3,
+        tensors={},
+        kv_cache_groups=[
+            KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
+                             full_attention_spec),
+            KVCacheGroupSpec([f"layer_{i}" for i in range(32, 64)],
+                             sliding_window_spec),
+        ],
+    )
+    max_concurrency_hybrid_model = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config_hybrid_model)
+    assert max_concurrency_hybrid_model == 3
+
+
 def test_allocate_with_lookahead():
     """Verify that lookahead tokens correctly affect block allocation"""
     block_size = 4
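
A quick sanity check of the asserted values (a sketch, not part of the commit; the per-request block counts below are read off the test configuration rather than derived from the spec classes):

# With block_size=16 and max_model_len=16384, the full-attention group needs
# cdiv(16384, 16) = 1024 blocks per request; the sliding-window group in this
# test implies 129 blocks per request (hence num_blocks = 129 * 3).
full_attn_blocks_per_req = (16384 + 16 - 1) // 16  # 1024
sliding_blocks_per_req = 129                       # assumption taken from the test

assert int(1024 * 1.5) / full_attn_blocks_per_req == 1.5  # full attention: 1.5x
assert (129 * 3) / sliding_blocks_per_req == 3            # sliding window only: 3x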

vllm/v1/core/kv_cache_utils.py

Lines changed: 42 additions & 19 deletions
@@ -3,13 +3,13 @@
 """KV-Cache Utilities."""
 import os
 from collections import deque
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from typing import Any, Callable, NamedTuple, Optional

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import GiB_bytes, sha256
+from vllm.utils import GiB_bytes, cdiv, sha256
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         KVCacheTensor, SlidingWindowSpec)
@@ -468,6 +468,15 @@ def hash_request_tokens(hash_function: Any, block_size: int,
     return ret


+def max_memory_usage_bytes(vllm_config: VllmConfig,
+                           kv_cache_specs: Iterable[KVCacheSpec]) -> int:
+    """
+    Get the maximum memory usage in bytes for the given KV cache specs.
+    """
+    return sum(
+        spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs)
+
+
 def estimate_max_model_len(vllm_config: VllmConfig,
                            kv_cache_spec: dict[str, KVCacheSpec],
                            available_memory: int) -> int:
@@ -489,11 +498,8 @@ def fits_in_memory(model_len: int) -> bool:
         # Modify the max_model_len for this calculation
         vllm_config.model_config.max_model_len = model_len
         # Calculate memory needed for the given model length
-        memory_needed = sum(
-            (layer_spec.max_memory_usage_bytes(vllm_config)
-             for layer_spec in kv_cache_spec.values()),
-            start=0,
-        )
+        memory_needed = max_memory_usage_bytes(vllm_config,
+                                               kv_cache_spec.values())
         return memory_needed <= available_memory

     # Binary search for the maximum model length
@@ -538,9 +544,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
                          "initializing the engine.")

     max_model_len = vllm_config.model_config.max_model_len
-    needed_memory = 0
-    for layer_spec in kv_cache_spec.values():
-        needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
+    needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())

     if needed_memory > available_memory:
         # Estimate the maximum model length that can fit in the available memory
@@ -606,6 +610,24 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     return len(layer_keys) == 1


+def get_max_concurrency_for_kv_cache_config(
+        vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float:
+    """
+    Get the maximum concurrency for the given KV cache configuration.
+    """
+    num_layer_per_group = max(
+        len(group.layer_names) for group in kv_cache_config.kv_cache_groups)
+    max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes(
+        vllm_config,
+        (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups))
+    memory_per_block = kv_cache_config.kv_cache_groups[
+        0].kv_cache_spec.page_size_bytes * num_layer_per_group
+    num_block_per_request = cdiv(max_memory_usage_per_request,
+                                 memory_per_block)
+    max_concurrency = kv_cache_config.num_blocks / num_block_per_request
+    return max_concurrency
+
+
 def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                       kv_cache_spec: dict[str, KVCacheSpec],
                                       available_memory: int) -> KVCacheConfig:
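
Worked through on the hybrid configuration from the test above, the function's intermediate values come out as follows (a sketch; the per-request costs of 1024 and 129 blocks for the full-attention and sliding-window specs are taken from that test, and the per-layer page size cancels out of the final ratio):

# page: per-layer block size in bytes; its exact value cancels out below.
page = 1
num_layer_per_group = max(32, 32)  # the larger of the two layer groups

# Every group contributes to one request's footprint, scaled to the largest group.
max_memory_usage_per_request = num_layer_per_group * (1024 * page + 129 * page)
memory_per_block = page * num_layer_per_group
num_block_per_request = -(-max_memory_usage_per_request // memory_per_block)  # cdiv -> 1153

num_blocks = (1024 + 129) * 3
assert num_blocks / num_block_per_request == 3.0
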
@@ -637,14 +659,6 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                     "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
         num_blocks = num_gpu_blocks_override

-    num_tokens = num_blocks * vllm_config.cache_config.block_size
-    num_tokens_str = f"{num_tokens:,}"
-    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
-    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
-    max_concurrency = num_tokens / vllm_config.model_config.max_model_len
-    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
-                max_model_len_str, max_concurrency)
-
     per_layer_size = page_size * num_blocks
     # All layers have the same KV cache spec, so we create one kv cache group
     # for all layers.
@@ -659,6 +673,15 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
         kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
                                                     grouped_layer_names),
     )
+
+    num_tokens = num_blocks * vllm_config.cache_config.block_size
+    num_tokens_str = f"{num_tokens:,}"
+    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
+    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
+    max_concurrency = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config)
+    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
+                max_model_len_str, max_concurrency)
     return kv_cache_config


@@ -705,8 +728,8 @@ def get_kv_cache_config(vllm_config: VllmConfig,
     Returns:
         The generated KVCacheConfigs
     """
-    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
     unify_hybrid_kv_cache_specs(kv_cache_spec)
+    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
     if is_kv_cache_type_uniform(kv_cache_spec):
         # KV cache of all layers are the same, which is true for
         # most models. Allocate the same amount of memory for
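
The net effect for models with sliding-window layers: the old estimate removed above (total cached tokens divided by max_model_len) assumes every request pins max_model_len tokens of KV cache and therefore understates concurrency, while the new block-based estimate uses what a request can actually occupy. A rough comparison with the numbers from the test (129 blocks per request is an assumption taken from that test, not derived here):

# Old, removed formula: num_tokens / max_model_len.
num_blocks, block_size, max_model_len = 129 * 3, 16, 16384
old_estimate = (num_blocks * block_size) / max_model_len  # ~0.38x

# New formula: num_blocks / blocks one request can actually occupy.
blocks_per_request = 129                                  # assumed from the test
new_estimate = num_blocks / blocks_per_request            # 3.0x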
