44import os
55import re
66from dataclasses import dataclass
7- from typing import Iterable , List , Optional , Tuple , Union
7+ from typing import Any , Iterable , List , Optional , Tuple , Union
88
99import torch
1010from vllm .config import VllmConfig
1111from vllm .distributed .kv_transfer .kv_connector .v1 .base import \
1212 KVConnectorMetadata
13+ from vllm .multimodal .inputs import MultiModalFeatureSpec
1314from vllm .utils import cdiv , logger
1415from vllm .v1 .core .sched .output import NewRequestData
1516
@@ -128,18 +129,21 @@ def _make_key_by_hash(self,
128129 chunk_hash ,
129130 )
130131
131- def _hash (
132- self ,
133- tokens : Union [torch .Tensor , List [int ]],
134- prefix_hash : str ,
135- ) -> str :
132+ def _hash (self , tokens : Union [torch .Tensor , List [int ]], prefix_hash : str ,
133+ extra_keys : Optional [tuple [Any , ...]]) -> str :
136134 # TODO: change it to a more efficient hash function
137135 if isinstance (tokens , torch .Tensor ):
138136 tokens_bytes = tokens .cpu ().to (torch .uint32 ).numpy ().tobytes ()
139137 elif isinstance (tokens , list ):
140138 tokens_bytes = array .array ("I" , tokens ).tobytes ()
141- return hashlib .sha256 (prefix_hash .encode ("ascii" ) +
142- tokens_bytes ).hexdigest ()
139+ if extra_keys is not None :
140+ extra_bytes = json .dumps (extra_keys ,
141+ separators = (',' , ':' )).encode ("utf-8" )
142+ else :
143+ extra_bytes = b""
144+ return hashlib .sha256 (
145+ prefix_hash .encode ("ascii" ) + tokens_bytes +
146+ extra_bytes ).hexdigest ()
143147
144148 def _chunk_tokens (
145149 self ,
@@ -160,16 +164,24 @@ def _chunk_tokens(
160164 def _prefix_hash (
161165 self ,
162166 token_chunks : Iterable [Union [torch .Tensor , List [int ]]],
167+ mm_features : Optional [list [MultiModalFeatureSpec ]] = None ,
163168 ) -> Iterable [str ]:
164169 prefix_hash = ''
165- for token_chunk in token_chunks :
166- prefix_hash = self ._hash (token_chunk , prefix_hash )
170+ curr_mm_idx = 0
171+ for chunk_id , token_chunk in enumerate (token_chunks ):
172+ start_idx = chunk_id * self .metadata .block_size
173+ end_idx = start_idx + len (token_chunk )
174+ extra_keys , curr_mm_idx = self ._gen_mm_extra_hash_keys (
175+ mm_features , start_idx , end_idx , curr_mm_idx )
176+ prefix_hash = self ._hash (token_chunk , prefix_hash ,
177+ tuple (extra_keys ))
167178 yield prefix_hash
168179
169180 def process_tokens (
170181 self ,
171182 tokens : Union [torch .Tensor , List [int ]],
172183 mask : Optional [torch .Tensor ] = None ,
184+ mm_features : Optional [list [MultiModalFeatureSpec ]] = None ,
173185 ) -> Iterable [Tuple [int , int , MooncakeEngineKey ]]:
174186 """Process the tokens and return the corresponding cache engine keys.
175187
@@ -203,9 +215,8 @@ def process_tokens(
203215 total_len = len (tokens )
204216
205217 token_chunks = self ._chunk_tokens (tokens )
206- prefix_hashes = self ._prefix_hash (token_chunks )
218+ prefix_hashes = self ._prefix_hash (token_chunks , mm_features )
207219
208- start_idx = 0
209220 for chunk_id , hash_val in enumerate (prefix_hashes ):
210221 start_idx = chunk_id * self .metadata .block_size
211222 end_idx = min (start_idx + self .metadata .block_size , total_len )
@@ -214,6 +225,69 @@ def process_tokens(
214225 else :
215226 yield start_idx , end_idx , self ._make_key_by_hash (hash_val )
216227
228+ def _gen_mm_extra_hash_keys (self , mm_features : Optional [
229+ list [MultiModalFeatureSpec ]], start_token_idx : int , end_token_idx : int ,
230+ start_mm_idx : int ) -> tuple [list [Any ], int ]:
231+ """This method refers to: vllm/vllm/v1/core/kv_cache_utils/_gen_mm_extra_hash_keys
232+ Generate extra keys related to MultiModal request for block hash
233+ computation. For multi-modal inputs, the extra keys are
234+ (mm_hash, start_offset) that indicate a mm input contained in the
235+ block and its starting offset in the block tokens.
236+
237+ Args:
238+ mm_features: The multimodel_input of the request.
239+ start_token_idx: The start token index of the block.
240+ end_token_idx: The end token index of the block.
241+ start_mm_idx: The start multi-modal index of the block.
242+
243+ Returns:
244+ A tuple of extra keys and the next multi-modal index.
245+ """
246+ extra_keys : list [Any ] = []
247+
248+ if not mm_features :
249+ return extra_keys , start_mm_idx
250+
251+ # Note that we assume mm_features are sorted by mm_position.offset.
252+ # We do not need to check all mm inputs if the start token index is out of
253+ # range. This usually happens in the late prefill phase and decoding phase.
254+ last_pos = mm_features [- 1 ].mm_position
255+ if last_pos .offset + last_pos .length < start_token_idx :
256+ return extra_keys , start_mm_idx
257+
258+ # Support start_mm_idx == -1 to indicate the last mm input.
259+ if start_mm_idx < 0 :
260+ assert - start_mm_idx <= len (mm_features )
261+ start_mm_idx = len (mm_features ) + start_mm_idx
262+
263+ curr_mm_idx = start_mm_idx
264+ while mm_features and curr_mm_idx < len (mm_features ):
265+ mm_feature = mm_features [curr_mm_idx ]
266+ assert mm_feature .identifier is not None
267+ offset = mm_feature .mm_position .offset
268+ length = mm_feature .mm_position .length
269+ if end_token_idx > offset :
270+ if start_token_idx > offset + length :
271+ # This block has passed the current mm input.
272+ curr_mm_idx += 1
273+ continue
274+
275+ # The block contains the current mm input.
276+ extra_keys .append (mm_feature .identifier )
277+
278+ if end_token_idx >= offset + length :
279+ # If this block contains the end of the current mm input,
280+ # move to the next mm input as this block may also contain
281+ # the next mm input.
282+ curr_mm_idx += 1
283+ else :
284+ # Otherwise this block is done with mm inputs.
285+ break
286+ else :
287+ # This block has not reached the current mm input.
288+ break
289+ return extra_keys , curr_mm_idx
290+
217291
218292@dataclass
219293class LoadSpec :
@@ -241,6 +315,9 @@ class RequestTracker:
241315 # The token ids that has been scheduled so far
242316 token_ids : list [int ]
243317
318+ # Multi-modal related
319+ mm_features : list [MultiModalFeatureSpec ]
320+
244321 # The block ids that has been allocated so far
245322 # NOTE: allocated blocks could be more than the number of tokens
246323 # FIXME: need to check whether the block ids will be changed after
@@ -279,6 +356,7 @@ def from_new_request(
279356 req_id = new_request .req_id ,
280357 token_ids = new_request .prompt_token_ids [:num_tokens_to_compute ].
281358 copy (),
359+ mm_features = new_request .mm_features ,
282360 allocated_block_ids = unfolded_block_ids ,
283361 num_saved_tokens = 0 ,
284362 )
@@ -323,6 +401,8 @@ class ReqMeta:
323401
324402 is_last_chunk : Optional [bool ] = None
325403
404+ mm_features : Optional [list [MultiModalFeatureSpec ]] = None
405+
326406 @staticmethod
327407 def from_request_tracker (
328408 tracker : RequestTracker ,
@@ -372,6 +452,9 @@ def from_request_tracker(
372452 # OPTIMIZATION: pre-allocate the buffer for token ids and block ids
373453 token_ids = torch .tensor (input_token_ids )[:num_tokens_to_save ]
374454
455+ # Multi-modal related
456+ mm_features = tracker .mm_features
457+
375458 # # For load operation: check whether the request is scheduled to load
376459 if load_spec is not None and load_spec .can_load :
377460 logger .debug (
@@ -388,6 +471,7 @@ def from_request_tracker(
388471 return ReqMeta (
389472 req_id = tracker .req_id ,
390473 token_ids = token_ids ,
474+ mm_features = mm_features ,
391475 block_ids = tracker .allocated_block_ids ,
392476 save_spec = save_spec ,
393477 load_spec = load_spec ,
0 commit comments