
Commit 8edbf78

Reimplement on top of PR vllm-project#21329
Signed-off-by: Jialin Ouyang <[email protected]>
1 parent c3752dc commit 8edbf78

File tree

7 files changed: 279 additions & 259 deletions

vllm/v1/core/block_pool.py

Lines changed: 3 additions & 4 deletions
@@ -8,10 +8,9 @@
                                         BlockStored, KVCacheEvent)
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
-                                         FreeKVCacheBlockQueue, KVCacheBlock,
-                                         generate_block_hash_extra_keys,
-                                         hash_block_tokens)
-from vllm.v1.request import Request
+                                         FreeKVCacheBlockQueue, KVCacheBlock)
+from vllm.v1.request import (Request, generate_block_hash_extra_keys,
+                             hash_block_tokens)

 logger = init_logger(__name__)

vllm/v1/core/kv_cache_common.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""KV-Cache Types."""
+from typing import Any, NamedTuple, Optional
+
+
+class BlockHash(NamedTuple):
+    """Hash value of a block (int), the token IDs in the block, and extra keys.
+    We keep a tuple of token IDs and extra keys to reduce the likelihood of
+    hash collisions when the hash value is the same. By using SHA256 however,
+    hash collisions are practically impossible.
+    """
+    # Hash value of the block in an integer.
+    hash_value: int
+    # Token IDs in the block.
+    token_ids: tuple[int, ...]
+    # Extra keys for the block.
+    extra_keys: Optional[Any] = None
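
Note: this new module lets kv_cache_utils.py, kv_cache_manager.py, and request.py share the BlockHash type without a circular dependency (kv_cache_utils.py previously imported Request, and request.py now hosts the hashing helpers). A minimal sketch of how the tuple chains for prefix caching, using Python's builtin hash as a stand-in for the configured hash function (the real callers pass caching_hash_fn):

    from typing import Any, NamedTuple, Optional

    class BlockHash(NamedTuple):
        hash_value: int
        token_ids: tuple[int, ...]
        extra_keys: Optional[Any] = None

    # Each block's hash folds in its parent's hash_value, so two requests can
    # share a BlockHash chain only if they share the same token prefix.
    root = BlockHash(hash((None, (1, 2, 3, 4), None)), (1, 2, 3, 4))
    child = BlockHash(hash((root.hash_value, (5, 6, 7, 8), None)), (5, 6, 7, 8))
    print(child.token_ids)  # (5, 6, 7, 8)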

vllm/v1/core/kv_cache_manager.py

Lines changed: 5 additions & 5 deletions
@@ -9,11 +9,11 @@
 from vllm.logger import init_logger
 from vllm.utils import sha256, sha256_cbor_64bit
 from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
-from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
-                                         hash_request_tokens, init_none_hash)
+from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import (Request, RequestStatus, hash_request_tokens,
+                             init_none_hash)

 logger = init_logger(__name__)

@@ -166,8 +166,8 @@ def get_computed_blocks(self,
         block_hashes = self.req_to_block_hashes[request.request_id]
         if not block_hashes:
             assert self.block_size is not None
-            block_hashes = hash_request_tokens(self.caching_hash_fn,
-                                               self.block_size, request)
+            block_hashes = request.precomputed_block_hashes if request.precomputed_block_hashes is not None else hash_request_tokens(
+                self.caching_hash_fn, self.block_size, request)
         self.req_to_block_hashes[request.request_id] = block_hashes

         if self.log_stats:
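
The net effect: the manager now prefers hashes precomputed during request preprocessing and hashes on the scheduler path only as a fallback. A self-contained sketch of that selection logic (illustrative names, not vLLM APIs):

    from typing import Callable, Optional

    def resolve_block_hashes(precomputed: Optional[list],
                             compute: Callable[[], list]) -> list:
        # Prefer hashes attached during preprocessing; fall back to hashing
        # on the scheduler path, mirroring the conditional in the diff above.
        return precomputed if precomputed is not None else compute()

    print(resolve_block_hashes(None, lambda: [1, 2, 3]))    # computed: [1, 2, 3]
    print(resolve_block_hashes([9, 8], lambda: [1, 2, 3]))  # reused: [9, 8]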

vllm/v1/core/kv_cache_utils.py

Lines changed: 4 additions & 246 deletions
@@ -2,39 +2,24 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """KV-Cache Utilities."""

-import os
 from collections import defaultdict, deque
-from collections.abc import Iterable, Sequence
+from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Any, Callable, NamedTuple, Optional
+from typing import NamedTuple, Optional

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
+from vllm.utils import GiB_bytes, cdiv
+from vllm.v1.core.kv_cache_common import BlockHash
 from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         KVCacheTensor, SlidingWindowSpec)
 from vllm.v1.metrics.stats import PrefixCacheStats
-from vllm.v1.request import Request

 logger = init_logger(__name__)


-class BlockHash(NamedTuple):
-    """Hash value of a block (int), the token IDs in the block, and extra keys.
-    We keep a tuple of token IDs and extra keys to reduce the likelihood of
-    hash collisions when the hash value is the same. By using SHA256 however,
-    hash collisions are practically impossible.
-    """
-    # Hash value of the block in an integer.
-    hash_value: int
-    # Token IDs in the block.
-    token_ids: tuple[int, ...]
-    # Extra keys for the block.
-    extra_keys: Optional[Any] = None
-
-
 class BlockHashWithGroupId(NamedTuple):
     # The hash value for the contents (e.g., token_ids) of a block without group
     # ID. The value is the same for blocks representing the same tokens but for

@@ -47,32 +32,6 @@ def get_hash_value(self) -> int:
         return self.block_hash.hash_value


-# The hash seed for the first block of any prefix block sequence.
-#
-# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
-# variable if set such that processes can share the seed if needed.
-# This aligns with the behavior of Python's hash() function, which also uses
-# a random seed if PYTHONHASHSEED is not set.
-#
-# The function `init_none_hash` initializes this variable globally.
-NONE_HASH: int
-
-
-def init_none_hash(hash_fn: Callable):
-    global NONE_HASH
-
-    hash_seed = os.getenv("PYTHONHASHSEED")
-    if hash_seed is None and hash_fn is sha256_cbor_64bit:
-        logger.warning(
-            "PYTHONHASHSEED is not set. This will lead to non-reproducible "
-            "block-hashes when using sha256_cbor_64bit as the hash function."
-            "Consider setting PYTHONHASHSEED to a fixed value for "
-            "reproducibility.")
-
-    NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big")
-                 if hash_seed is None else hash_fn(hash_seed))
-
-
 class PrefixCachingMetrics:
     """Metrics for prefix caching with a hit rate of the max recent N requests.

@@ -335,207 +294,6 @@ def get_all_free_blocks(self) -> list[KVCacheBlock]:
         return ret


-def need_extra_keys(request: Request) -> bool:
-    """Check whether the blocks allocated to this request need extra hash keys.
-
-    Args:
-        request (Request): The request.
-
-    Returns:
-        bool: Whether blocks allocated to this request need extra hash keys.
-    """
-
-    # Multimodal requests need to include the MM hash.
-    # LoRA requests need to include the LoRA ID.
-    # Request with provided cache salt need to include the salt.
-    return bool(request.mm_positions) or (request.lora_request
-                                          is not None) or (request.cache_salt
-                                                           is not None)
-
-
-def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
-                            end_token_idx: int,
-                            start_mm_idx: int) -> tuple[list[Any], int]:
-    """Generate extra keys related to MultiModal request for block hash
-    computation. For multi-modal inputs, the extra keys are
-    (mm_hash, start_offset) that indicate a mm input contained in the
-    block and its starting offset in the block tokens.
-
-    Args:
-        request: The request object.
-        start_token_idx: The start token index of the block.
-        end_token_idx: The end token index of the block.
-        start_mm_idx: The start multi-modal index of the block.
-
-    Returns:
-        A tuple of extra keys and the next multi-modal index.
-    """
-    extra_keys: list[Any] = []
-
-    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
-    if not mm_positions:
-        return extra_keys, start_mm_idx
-
-    if mm_positions and len(mm_positions) != len(mm_hashes):
-        raise ValueError(
-            "The number of multi-modal positions and hashes must match. This "
-            "is likely because you do not enable MM preprocessor hashing. "
-            "Please set disable_mm_preprocessor_cache=False.")
-
-    # Note that we assume mm_positions is sorted by offset.
-    # We do not need to check all mm inputs if the start token index is out of
-    # range. This usually happens in the late prefill phase and decoding phase.
-    if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
-        return extra_keys, start_mm_idx
-
-    # Support start_mm_idx == -1 to indicate the last mm input.
-    if start_mm_idx < 0:
-        assert -start_mm_idx <= len(mm_positions)
-        start_mm_idx = len(mm_positions) + start_mm_idx
-
-    curr_mm_idx = start_mm_idx
-    while mm_positions and curr_mm_idx < len(mm_positions):
-        assert mm_hashes[curr_mm_idx] is not None
-        offset = mm_positions[curr_mm_idx].offset
-        length = mm_positions[curr_mm_idx].length
-        if end_token_idx > offset:
-            if start_token_idx > offset + length:
-                # This block has passed the current mm input.
-                curr_mm_idx += 1
-                continue
-
-            # The block contains the current mm input.
-            extra_keys.append(mm_hashes[curr_mm_idx])
-
-            if end_token_idx >= offset + length:
-                # If this block contains the end of the current mm input,
-                # move to the next mm input as this block may also contain
-                # the next mm input.
-                curr_mm_idx += 1
-            else:
-                # Otherwise this block is done with mm inputs.
-                break
-        else:
-            # This block has not reached the current mm input.
-            break
-    return extra_keys, curr_mm_idx
-
-
-def _gen_lora_extra_hash_keys(request: Request) -> list[int]:
-    """Generate extra keys related to LoRA for block hash computation.
-
-    Args:
-        request: The request object.
-
-    Returns:
-        Return LoRA id of the request if it is a LoRA request. Return empty
-        list otherwise.
-    """
-    if not request.lora_request:
-        return []
-    return [request.lora_request.lora_int_id]
-
-
-def generate_block_hash_extra_keys(
-        request: Request, start_token_idx: int, end_token_idx: int,
-        start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]:
-    """Generate extra keys for the block hash. The extra keys can come from
-    the multi-modal inputs and request specific metadata (e.g., LoRA ID).
-
-    Args:
-        request: The request object.
-        start_token_idx: The start token index of the block.
-        end_token_idx: The end token index of the block.
-        start_mm_idx: The start multi-modal index of the block.
-
-    Returns:
-        A tuple of extra keys and the next multi-modal index.
-    """
-    mm_extra_keys: list[Any]
-    mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
-        request, start_token_idx, end_token_idx, start_mm_idx)
-    lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request)
-    cache_salt_keys: list[str] = [request.cache_salt] if (
-        start_token_idx == 0 and request.cache_salt) else []
-
-    extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys
-
-    if not extra_keys:
-        return None, new_start_mm_idx
-
-    return tuple(extra_keys), new_start_mm_idx
-
-
-def hash_block_tokens(
-        hash_function: Callable,
-        parent_block_hash: Optional[int],
-        curr_block_token_ids: Sequence[int],
-        extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
-    """Computes a hash value corresponding to the contents of a block and
-    the contents of the preceding block(s). The hash value is used for
-    prefix caching. We use LRU cache for this function to avoid recomputing
-    hash values for the same block contents.
-
-    Args:
-        parent_block_hash: The hash of the parent block. None
-            if this is the first block.
-        curr_block_token_ids: A list of token ids in the current
-            block. The current block is assumed to be full.
-        extra_keys: Extra keys for the block.
-
-    Returns:
-        The hash value of the block and the token ids in the block.
-        The entire tuple is used as the hash key of the block.
-    """
-    if not parent_block_hash:
-        parent_block_hash = NONE_HASH
-
-    curr_block_token_ids_tuple = tuple(curr_block_token_ids)
-    return BlockHash(
-        hash_function(
-            (parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
-        curr_block_token_ids_tuple, extra_keys)
-
-
-def hash_request_tokens(hash_function: Any, block_size: int,
-                        request: Request) -> list[BlockHash]:
-    """Computes hash values of a chain of blocks given a sequence of
-    token IDs. The hash value is used for prefix caching.
-
-    Args:
-        block_size: The size of each block.
-        request: The request object.
-
-    Returns:
-        The list of computed hash values.
-    """
-    token_ids = request.all_token_ids
-
-    req_need_extra_keys = need_extra_keys(request)
-    req_extra_keys = None
-    curr_mm_idx = 0
-
-    ret = []
-    parent_block_hash_value = None
-    for start in range(0, len(token_ids), block_size):
-        end = start + block_size
-        block_token_ids = token_ids[start:end]
-        # Do not hash the block if it is not full.
-        if len(block_token_ids) < block_size:
-            break
-
-        if req_need_extra_keys:
-            # MM and LoRA requests need extra keys for block-hash computation.
-            req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
-                request, start, end, curr_mm_idx)
-
-        block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
-                                       block_token_ids, req_extra_keys)
-        ret.append(block_hash)
-        parent_block_hash_value = block_hash.hash_value
-    return ret
-
-
 def max_memory_usage_bytes(vllm_config: VllmConfig,
                            kv_cache_specs: Iterable[KVCacheSpec]) -> int:
     """

vllm/v1/engine/async_llm.py

Lines changed: 6 additions & 0 deletions
@@ -39,6 +39,7 @@
 from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
 from vllm.v1.metrics.prometheus import shutdown_prometheus
 from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.request import init_none_hash

 logger = init_logger(__name__)

@@ -131,6 +132,11 @@ def __init__(
             self.logger_manager.log_engine_initialized()

         self.output_handler: Optional[asyncio.Task] = None
+
+        # logger.info("===jialino init_none_hash")
+        # TODO(Jialin): Extract the right hash function from vllm_config @nocommit
+        init_none_hash(hash)
+
         try:
             # Start output handler eagerly if we are in the asyncio eventloop.
             asyncio.get_running_loop()
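
init_none_hash, now imported from vllm.v1.request, seeds the root of every block-hash chain; the TODO notes that the builtin hash passed here is a placeholder until the hash function is read from vllm_config. A sketch of the seeding contract, mirroring the version removed from kv_cache_utils.py above:

    import os

    def none_hash_seed(hash_fn) -> int:
        # With PYTHONHASHSEED set, every process derives the same root seed,
        # so block hashes are reproducible across processes; otherwise a
        # random 256-bit value is drawn per process.
        hash_seed = os.getenv("PYTHONHASHSEED")
        if hash_seed is None:
            return int.from_bytes(os.urandom(32), byteorder="big")
        return hash_fn(hash_seed)

    print(none_hash_seed(hash) == none_hash_seed(hash))  # True only if seed set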

vllm/v1/engine/core.py

Lines changed: 7 additions & 2 deletions
@@ -42,7 +42,7 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request, RequestStatus, hash_request_tokens
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.version import __version__ as VLLM_VERSION

@@ -396,7 +396,12 @@ def _preprocess_add_request(self, request: EngineCoreRequest) -> Request:

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward"""
-        return Request.from_engine_core_request(request)
+        converted_request = Request.from_engine_core_request(request)
+        # TODO(Jialin): Use the right hash function here
+        # TODO(Jialin): Use the right block size here
+        converted_request.precomputed_block_hashes = hash_request_tokens(
+            hash, 16, converted_request)
+        return converted_request


 class EngineCoreProc(EngineCore):
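
Since _preprocess_add_request is designed to run on the input-processing thread, precomputing block hashes there overlaps the per-token hashing cost with the model forward pass; the TODOs flag that the builtin hash and block size 16 are placeholders. A toy illustration of the overlap pattern (hypothetical names, not vLLM APIs):

    from concurrent.futures import ThreadPoolExecutor

    def preprocess(token_ids: list[int]) -> list[int]:
        # Stand-in for _preprocess_add_request: hash full 16-token blocks up
        # front so the scheduler finds them already attached to the request.
        return [hash(tuple(token_ids[i:i + 16]))
                for i in range(0, len(token_ids) - 15, 16)]

    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(preprocess, list(range(64)))
        # ... the model forward for the previous batch would run here ...
        print(len(future.result()))  # 4 full blocks of 16 tokens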
