2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
3
"""KV-Cache Utilities."""
4
4
5
- import os
6
5
from collections import defaultdict , deque
7
- from collections .abc import Iterable , Sequence
6
+ from collections .abc import Iterable
8
7
from dataclasses import dataclass
9
- from typing import Any , Callable , NamedTuple , Optional
8
+ from typing import NamedTuple , Optional
10
9
11
10
from vllm .config import VllmConfig
12
11
from vllm .logger import init_logger
13
- from vllm .utils import GiB_bytes , cdiv , sha256_cbor_64bit
12
+ from vllm .utils import GiB_bytes , cdiv
13
+ from vllm .v1 .core .kv_cache_common import BlockHash
14
14
from vllm .v1 .kv_cache_interface import (ChunkedLocalAttentionSpec ,
15
15
FullAttentionSpec , KVCacheConfig ,
16
16
KVCacheGroupSpec , KVCacheSpec ,
17
17
KVCacheTensor , SlidingWindowSpec )
18
18
from vllm .v1 .metrics .stats import PrefixCacheStats
19
- from vllm .v1 .request import Request
20
19
21
20
logger = init_logger (__name__ )
22
21
23
22
24
- class BlockHash (NamedTuple ):
25
- """Hash value of a block (int), the token IDs in the block, and extra keys.
26
- We keep a tuple of token IDs and extra keys to reduce the likelihood of
27
- hash collisions when the hash value is the same. By using SHA256 however,
28
- hash collisions are practically impossible.
29
- """
30
- # Hash value of the block in an integer.
31
- hash_value : int
32
- # Token IDs in the block.
33
- token_ids : tuple [int , ...]
34
- # Extra keys for the block.
35
- extra_keys : Optional [Any ] = None
36
-
37
-
38
23
class BlockHashWithGroupId (NamedTuple ):
39
24
# The hash value for the contents (e.g., token_ids) of a block without group
40
25
# ID. The value is the same for blocks representing the same tokens but for
@@ -47,32 +32,6 @@ def get_hash_value(self) -> int:
47
32
return self .block_hash .hash_value
48
33
49
34
50
- # The hash seed for the first block of any prefix block sequence.
51
- #
52
- # We use a random value to avoid hash collisions or PYTHONHASHSEED environment
53
- # variable if set such that processes can share the seed if needed.
54
- # This aligns with the behavior of Python's hash() function, which also uses
55
- # a random seed if PYTHONHASHSEED is not set.
56
- #
57
- # The function `init_none_hash` initializes this variable globally.
58
- NONE_HASH : int
59
-
60
-
61
- def init_none_hash (hash_fn : Callable ):
62
- global NONE_HASH
63
-
64
- hash_seed = os .getenv ("PYTHONHASHSEED" )
65
- if hash_seed is None and hash_fn is sha256_cbor_64bit :
66
- logger .warning (
67
- "PYTHONHASHSEED is not set. This will lead to non-reproducible "
68
- "block-hashes when using sha256_cbor_64bit as the hash function."
69
- "Consider setting PYTHONHASHSEED to a fixed value for "
70
- "reproducibility." )
71
-
72
- NONE_HASH = (int .from_bytes (os .urandom (32 ), byteorder = "big" )
73
- if hash_seed is None else hash_fn (hash_seed ))
74
-
75
-
76
35
class PrefixCachingMetrics :
77
36
"""Metrics for prefix caching with a hit rate of the max recent N requests.
78
37
@@ -335,207 +294,6 @@ def get_all_free_blocks(self) -> list[KVCacheBlock]:
335
294
return ret
336
295
337
296
338
- def need_extra_keys (request : Request ) -> bool :
339
- """Check whether the blocks allocated to this request need extra hash keys.
340
-
341
- Args:
342
- request (Request): The request.
343
-
344
- Returns:
345
- bool: Whether blocks allocated to this request need extra hash keys.
346
- """
347
-
348
- # Multimodal requests need to include the MM hash.
349
- # LoRA requests need to include the LoRA ID.
350
- # Request with provided cache salt need to include the salt.
351
- return bool (request .mm_positions ) or (request .lora_request
352
- is not None ) or (request .cache_salt
353
- is not None )
354
-
355
-
356
- def _gen_mm_extra_hash_keys (request : Request , start_token_idx : int ,
357
- end_token_idx : int ,
358
- start_mm_idx : int ) -> tuple [list [Any ], int ]:
359
- """Generate extra keys related to MultiModal request for block hash
360
- computation. For multi-modal inputs, the extra keys are
361
- (mm_hash, start_offset) that indicate a mm input contained in the
362
- block and its starting offset in the block tokens.
363
-
364
- Args:
365
- request: The request object.
366
- start_token_idx: The start token index of the block.
367
- end_token_idx: The end token index of the block.
368
- start_mm_idx: The start multi-modal index of the block.
369
-
370
- Returns:
371
- A tuple of extra keys and the next multi-modal index.
372
- """
373
- extra_keys : list [Any ] = []
374
-
375
- mm_positions , mm_hashes = request .mm_positions , request .mm_hashes
376
- if not mm_positions :
377
- return extra_keys , start_mm_idx
378
-
379
- if mm_positions and len (mm_positions ) != len (mm_hashes ):
380
- raise ValueError (
381
- "The number of multi-modal positions and hashes must match. This "
382
- "is likely because you do not enable MM preprocessor hashing. "
383
- "Please set disable_mm_preprocessor_cache=False." )
384
-
385
- # Note that we assume mm_positions is sorted by offset.
386
- # We do not need to check all mm inputs if the start token index is out of
387
- # range. This usually happens in the late prefill phase and decoding phase.
388
- if mm_positions [- 1 ].offset + mm_positions [- 1 ].length < start_token_idx :
389
- return extra_keys , start_mm_idx
390
-
391
- # Support start_mm_idx == -1 to indicate the last mm input.
392
- if start_mm_idx < 0 :
393
- assert - start_mm_idx <= len (mm_positions )
394
- start_mm_idx = len (mm_positions ) + start_mm_idx
395
-
396
- curr_mm_idx = start_mm_idx
397
- while mm_positions and curr_mm_idx < len (mm_positions ):
398
- assert mm_hashes [curr_mm_idx ] is not None
399
- offset = mm_positions [curr_mm_idx ].offset
400
- length = mm_positions [curr_mm_idx ].length
401
- if end_token_idx > offset :
402
- if start_token_idx > offset + length :
403
- # This block has passed the current mm input.
404
- curr_mm_idx += 1
405
- continue
406
-
407
- # The block contains the current mm input.
408
- extra_keys .append (mm_hashes [curr_mm_idx ])
409
-
410
- if end_token_idx >= offset + length :
411
- # If this block contains the end of the current mm input,
412
- # move to the next mm input as this block may also contain
413
- # the next mm input.
414
- curr_mm_idx += 1
415
- else :
416
- # Otherwise this block is done with mm inputs.
417
- break
418
- else :
419
- # This block has not reached the current mm input.
420
- break
421
- return extra_keys , curr_mm_idx
422
-
423
-
424
- def _gen_lora_extra_hash_keys (request : Request ) -> list [int ]:
425
- """Generate extra keys related to LoRA for block hash computation.
426
-
427
- Args:
428
- request: The request object.
429
-
430
- Returns:
431
- Return LoRA id of the request if it is a LoRA request. Return empty
432
- list otherwise.
433
- """
434
- if not request .lora_request :
435
- return []
436
- return [request .lora_request .lora_int_id ]
437
-
438
-
439
- def generate_block_hash_extra_keys (
440
- request : Request , start_token_idx : int , end_token_idx : int ,
441
- start_mm_idx : int ) -> tuple [Optional [tuple [Any , ...]], int ]:
442
- """Generate extra keys for the block hash. The extra keys can come from
443
- the multi-modal inputs and request specific metadata (e.g., LoRA ID).
444
-
445
- Args:
446
- request: The request object.
447
- start_token_idx: The start token index of the block.
448
- end_token_idx: The end token index of the block.
449
- start_mm_idx: The start multi-modal index of the block.
450
-
451
- Returns:
452
- A tuple of extra keys and the next multi-modal index.
453
- """
454
- mm_extra_keys : list [Any ]
455
- mm_extra_keys , new_start_mm_idx = _gen_mm_extra_hash_keys (
456
- request , start_token_idx , end_token_idx , start_mm_idx )
457
- lora_extra_keys : list [int ] = _gen_lora_extra_hash_keys (request )
458
- cache_salt_keys : list [str ] = [request .cache_salt ] if (
459
- start_token_idx == 0 and request .cache_salt ) else []
460
-
461
- extra_keys : list [Any ] = lora_extra_keys + mm_extra_keys + cache_salt_keys
462
-
463
- if not extra_keys :
464
- return None , new_start_mm_idx
465
-
466
- return tuple (extra_keys ), new_start_mm_idx
467
-
468
-
469
- def hash_block_tokens (
470
- hash_function : Callable ,
471
- parent_block_hash : Optional [int ],
472
- curr_block_token_ids : Sequence [int ],
473
- extra_keys : Optional [tuple [Any , ...]] = None ) -> BlockHash :
474
- """Computes a hash value corresponding to the contents of a block and
475
- the contents of the preceding block(s). The hash value is used for
476
- prefix caching. We use LRU cache for this function to avoid recomputing
477
- hash values for the same block contents.
478
-
479
- Args:
480
- parent_block_hash: The hash of the parent block. None
481
- if this is the first block.
482
- curr_block_token_ids: A list of token ids in the current
483
- block. The current block is assumed to be full.
484
- extra_keys: Extra keys for the block.
485
-
486
- Returns:
487
- The hash value of the block and the token ids in the block.
488
- The entire tuple is used as the hash key of the block.
489
- """
490
- if not parent_block_hash :
491
- parent_block_hash = NONE_HASH
492
-
493
- curr_block_token_ids_tuple = tuple (curr_block_token_ids )
494
- return BlockHash (
495
- hash_function (
496
- (parent_block_hash , curr_block_token_ids_tuple , extra_keys )),
497
- curr_block_token_ids_tuple , extra_keys )
498
-
499
-
500
- def hash_request_tokens (hash_function : Any , block_size : int ,
501
- request : Request ) -> list [BlockHash ]:
502
- """Computes hash values of a chain of blocks given a sequence of
503
- token IDs. The hash value is used for prefix caching.
504
-
505
- Args:
506
- block_size: The size of each block.
507
- request: The request object.
508
-
509
- Returns:
510
- The list of computed hash values.
511
- """
512
- token_ids = request .all_token_ids
513
-
514
- req_need_extra_keys = need_extra_keys (request )
515
- req_extra_keys = None
516
- curr_mm_idx = 0
517
-
518
- ret = []
519
- parent_block_hash_value = None
520
- for start in range (0 , len (token_ids ), block_size ):
521
- end = start + block_size
522
- block_token_ids = token_ids [start :end ]
523
- # Do not hash the block if it is not full.
524
- if len (block_token_ids ) < block_size :
525
- break
526
-
527
- if req_need_extra_keys :
528
- # MM and LoRA requests need extra keys for block-hash computation.
529
- req_extra_keys , curr_mm_idx = generate_block_hash_extra_keys (
530
- request , start , end , curr_mm_idx )
531
-
532
- block_hash = hash_block_tokens (hash_function , parent_block_hash_value ,
533
- block_token_ids , req_extra_keys )
534
- ret .append (block_hash )
535
- parent_block_hash_value = block_hash .hash_value
536
- return ret
537
-
538
-
539
297
def max_memory_usage_bytes (vllm_config : VllmConfig ,
540
298
kv_cache_specs : Iterable [KVCacheSpec ]) -> int :
541
299
"""
0 commit comments