Skip to content

Commit 3d5f1c8

Browse files
authored
[Mamba][KVCacheManager] Simplify kv cache manage logic for mamba + MTP (vllm-project#25119)
Signed-off-by: Chen Zhang <[email protected]>
1 parent 1cab2f9 commit 3d5f1c8

File tree

1 file changed

+4
-25
lines changed

1 file changed

+4
-25
lines changed

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -565,35 +565,14 @@ def get_num_common_prefix_blocks(self, request_id: str,
565565
def get_num_blocks_to_allocate(
566566
self, request_id: str, num_tokens: int,
567567
new_computed_blocks: list[KVCacheBlock]) -> int:
568-
"""
569-
Get the number of blocks needed to be allocated for the request.
570-
571-
Args:
572-
request_id: The request ID.
573-
num_tokens: The total number of tokens that need a slot (including
574-
tokens that are already allocated).
575-
new_computed_blocks: The new computed blocks just hitting the
576-
prefix caching.
577-
578-
Returns:
579-
The number of blocks
580-
"""
581-
568+
# Allocate extra `num_speculative_blocks` blocks for
569+
# speculative decoding (MTP/EAGLE) with linear attention.
582570
assert isinstance(self.kv_cache_spec, MambaSpec)
583571
if self.kv_cache_spec.num_speculative_blocks > 0:
584572
num_tokens += (self.kv_cache_spec.block_size *
585573
self.kv_cache_spec.num_speculative_blocks)
586-
num_required_blocks = cdiv(num_tokens, self.block_size)
587-
num_new_blocks = (num_required_blocks - len(new_computed_blocks) -
588-
len(self.req_to_blocks[request_id]))
589-
# If a computed block of a request is an eviction candidate (in the
590-
# free queue and ref_cnt == 0), it will be changed from a free block
591-
# to a computed block when the request is allocated, so we also count
592-
# it as needed to be allocated.
593-
num_evictable_computed_blocks = sum(
594-
blk.ref_cnt == 0 and not blk.is_null
595-
for blk in new_computed_blocks)
596-
return num_new_blocks + num_evictable_computed_blocks
574+
return super().get_num_blocks_to_allocate(request_id, num_tokens,
575+
new_computed_blocks)
597576

598577
def allocate_new_blocks(self, request_id: str,
599578
num_tokens: int) -> list[KVCacheBlock]:

0 commit comments

Comments
 (0)