Skip to content

Commit f01928d

Browse files
tanruixiangvadiklyutiy
authored andcommitted
[Misc] Improve code readability of KVCacheManager (vllm-project#21673)
Signed-off-by: tanruixiang <[email protected]> Signed-off-by: Ruixiang Tan <[email protected]> Signed-off-by: GitHub <[email protected]>
1 parent b04c499 commit f01928d

File tree

6 files changed

+18
-22
lines changed

6 files changed

+18
-22
lines changed

tests/v1/core/test_kv_cache_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def test_kv_cache_block():
112112
assert block.block_hash is None
113113

114114
# Test reference count manipulation
115-
block.incr_ref()
115+
block.ref_cnt += 1
116116
assert block.ref_cnt == 1
117-
block.decr_ref()
117+
block.ref_cnt -= 1
118118
assert block.ref_cnt == 0
119119

120120
# Test block hash setting and resetting

vllm/v1/core/block_pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
276276
# candidate), so remove it.
277277
if block.ref_cnt == 0 and not block.is_null:
278278
self.free_block_queue.remove(block)
279-
block.incr_ref()
279+
block.ref_cnt += 1
280280

281281
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
282282
"""Free a list of blocks. The blocks should be ordered by their

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,17 @@ def free(self, request_id: str) -> None:
126126
def get_num_common_prefix_blocks(self, request_id: str,
127127
num_running_requests: int) -> list[int]:
128128
"""
129-
Get the number of common prefix blocks for a request.
129+
Get the number of common prefix blocks for all requests in the RUNNING
130+
state for each kv cache group.
130131
131132
Args:
132133
request_id: The request ID.
133-
num_running_requests: The number of requests in the RUNNING state.
134+
num_running_requests: The total number of requests in the RUNNING
135+
state.
134136
135137
Returns:
136-
list[int]: The number of common prefix blocks.
138+
list[int]: The number of common prefix blocks for all requests in
139+
the RUNNING state for each kv cache group.
137140
"""
138141
num_blocks_per_group = [
139142
manager.get_num_common_prefix_blocks(request_id,

vllm/v1/core/kv_cache_manager.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,6 @@ def get_computed_blocks(self,
170170
self.block_size, request)
171171
self.req_to_block_hashes[request.request_id] = block_hashes
172172

173-
if self.log_stats:
174-
assert self.prefix_cache_stats is not None
175-
self.prefix_cache_stats.requests += 1
176-
177173
# NOTE: When all tokens hit the cache, we must recompute the last token
178174
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
179175
# This can trigger recomputation of an entire block, rather than just
@@ -187,6 +183,7 @@ def get_computed_blocks(self,
187183

188184
if self.log_stats:
189185
assert self.prefix_cache_stats is not None
186+
self.prefix_cache_stats.requests += 1
190187
self.prefix_cache_stats.queries += request.num_tokens
191188
self.prefix_cache_stats.hits += num_new_computed_tokens
192189

vllm/v1/core/kv_cache_utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,6 @@ class KVCacheBlock:
154154
# Whether the block is a null block that should never be cached.
155155
is_null: bool = False
156156

157-
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
158-
# avoid function calls.
159-
def incr_ref(self):
160-
self.ref_cnt += 1
161-
162-
def decr_ref(self):
163-
self.ref_cnt -= 1
164-
165157
@property
166158
def block_hash(self) -> Optional[BlockHashWithGroupId]:
167159
return self._block_hash

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import itertools
34
from abc import ABC, abstractmethod
45
from collections import defaultdict
56
from typing import Callable
@@ -177,14 +178,17 @@ def free(self, request_id: str) -> None:
177178
def get_num_common_prefix_blocks(self, request_id: str,
178179
num_running_requests: int) -> int:
179180
"""
180-
Get the number of common prefix blocks for a request.
181+
Get the number of common prefix blocks for all requests in the RUNNING
182+
state.
181183
182184
Args:
183185
request_id: The request ID.
184-
num_running_requests: The number of requests in the RUNNING state.
186+
num_running_requests: The total number of requests in the RUNNING
187+
state.
185188
186189
Returns:
187-
The number of common prefix blocks.
190+
The number of common prefix blocks for all requests in the RUNNING
191+
state.
188192
"""
189193

190194
raise NotImplementedError
@@ -264,7 +268,7 @@ def find_longest_cache_hit(
264268
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
265269
[] for _ in range(len(kv_cache_group_ids)))
266270
max_num_blocks = max_length // kv_cache_spec.block_size
267-
for i, block_hash in zip(range(max_num_blocks), block_hashes):
271+
for block_hash in itertools.islice(block_hashes, max_num_blocks):
268272
# block_hashes is a chain of block hashes. If a block hash is not
269273
# in the cached_block_hash_to_id, the following block hashes are
270274
# not computed yet for sure.

0 commit comments

Comments
 (0)