@@ -39,6 +39,9 @@ class RequestMeta:
     hbm_hit_block_num: int = 0
     # local_computed_block + external_computed_block
     total_hit_block_num: int = 0
+    num_token_ids: int = 0
+    vllm_block_ids: list[int] = field(default_factory=list)
+    token_processed: int = 0
 
 
 @dataclass
@@ -207,7 +210,6 @@ def get_num_new_matched_tokens(
         request: "Request",
         num_computed_tokens: int,
     ) -> tuple[int, bool]:
-
         assert num_computed_tokens % self.block_size == 0
         hbm_hit_block_num = num_computed_tokens // self.block_size
 
@@ -242,13 +244,16 @@ def get_num_new_matched_tokens(
         # When all the tokens are cached in SSD or HBM,
         # we need to recompute the last token. This if condition will be removed
         # once the vLLM scheduler provides a better solution in the future.
-        if total_hit_block_num * self.block_size == request.num_tokens:
+        num_total_hit_tokens = total_hit_block_num * self.block_size
+        if num_total_hit_tokens == request.num_tokens:
             external_hit_tokens -= 1
 
         self.requests_meta[request.request_id] = RequestMeta(
             ucm_block_ids=ucm_block_ids,
             hbm_hit_block_num=hbm_hit_block_num,
             total_hit_block_num=total_hit_block_num,
+            num_token_ids=len(request.all_token_ids),
+            token_processed=num_total_hit_tokens,
         )
 
         return external_hit_tokens, False
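
Note on the fully-cached adjustment above: the sketch below walks through assumed numbers (block_size 16, a 32-token request whose blocks are all cached). How external_hit_tokens is initially derived is outside this hunk, so that line is an assumption for illustration only.

# Hypothetical walk-through of the fully-cached case (all values are assumptions).
block_size = 16
num_tokens = 32              # request.num_tokens
hbm_hit_block_num = 1        # one block already resident in HBM
total_hit_block_num = 2      # HBM hits plus external (SSD) hits

# Assumption: external hit tokens are the SSD-only blocks, counted in tokens.
external_hit_tokens = (total_hit_block_num - hbm_hit_block_num) * block_size  # 16

num_total_hit_tokens = total_hit_block_num * block_size  # 32
if num_total_hit_tokens == num_tokens:
    # Everything is cached, so leave one token for the scheduler to recompute.
    external_hit_tokens -= 1

print(external_hit_tokens)  # 15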
@@ -277,22 +282,16 @@ def _generate_dispatch_meta(
         | scheduled_block_num |
         """
 
-        new_blocks_num = new_tokens // self.block_size
         hbm_hit_block_num = req_meta.hbm_hit_block_num
         total_hit_block_num = req_meta.total_hit_block_num
-        scheduled_block_num = total_hit_block_num + new_blocks_num
         ucm_block_ids = req_meta.ucm_block_ids
+        req_meta.vllm_block_ids.extend(vllm_block_ids)
 
-        dump_ucm_block_ids = ucm_block_ids[total_hit_block_num:scheduled_block_num]
-        if need_load:
-            dump_vllm_block_ids = vllm_block_ids[
-                total_hit_block_num:scheduled_block_num
-            ]
-        else:
-            dump_vllm_block_ids = vllm_block_ids
-
-        # after this round, req_meta will be updated
-        req_meta.total_hit_block_num = scheduled_block_num
+        start_idx = req_meta.token_processed // self.block_size
+        end_idx = (req_meta.token_processed + new_tokens) // self.block_size
+        dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx]
+        dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx]
+        req_meta.token_processed += new_tokens
 
         load_ucm_block_ids, load_vllm_block_ids = [], []
         if need_load:
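
To make the new token_processed bookkeeping concrete, here is a minimal sketch with made-up block IDs over two scheduling rounds. The assumption that the first round's vllm_block_ids covers every block allocated so far (including the already-hit blocks, so that list indices line up with token offsets) is not shown in this diff and is stated here only for the example.

# Minimal sketch (assumed values): block_size = 16, two rounds for one request.
block_size = 16
ucm_block_ids = ["u0", "u1", "u2", "u3"]    # storage-side block IDs for the request
accumulated_vllm_block_ids: list[int] = []  # mirrors req_meta.vllm_block_ids
token_processed = 32                        # two blocks already hit in HBM/SSD

rounds = [
    (16, [10, 11, 12]),  # assumed: first round passes every block allocated so far
    (16, [13]),          # later round: only the newly allocated block
]
for new_tokens, vllm_block_ids in rounds:
    accumulated_vllm_block_ids.extend(vllm_block_ids)
    start_idx = token_processed // block_size
    end_idx = (token_processed + new_tokens) // block_size
    dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx]
    dump_vllm_block_ids = accumulated_vllm_block_ids[start_idx:end_idx]
    token_processed += new_tokens
    print(dump_ucm_block_ids, dump_vllm_block_ids)
# round 1 -> ['u2'] [12]; round 2 -> ['u3'] [13]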
@@ -332,10 +331,13 @@ def build_connector_meta(
                 continue
             req_meta = self.requests_meta.get(request_id)
             if req_meta:
+                new_block_ids = []
+                if scheduled_cached_reqs.new_block_ids[i] is not None:
+                    new_block_ids = scheduled_cached_reqs.new_block_ids[i][0]
                 requests_dispatch_meta[request_id] = self._generate_dispatch_meta(
                     req_meta,
                     scheduler_output.num_scheduled_tokens[request_id],
-                    scheduled_cached_reqs.new_block_ids[i][0],
+                    new_block_ids,
                     scheduled_cached_reqs.resumed_from_preemption[i],
                 )
             else:
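
The guard added above covers entries of scheduled_cached_reqs.new_block_ids that are None. The sketch below uses a hypothetical list of entries shaped as one tuple of per-KV-cache-group block-ID lists; that shape is an assumption for illustration, not taken from this diff.

# Hypothetical entries: a tuple of per-group block-ID lists, or None when the
# scheduler allocated no new blocks for that cached request in this step.
entries = [([5, 6],), None]

for entry in entries:
    new_block_ids = []
    if entry is not None:
        new_block_ids = entry[0]  # first KV cache group
    print(new_block_ids)          # [5, 6] then []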