@@ -84,7 +84,6 @@ def allocated_slots(self, request: Request):
84
84
return len (request .block_tables ) * self .config .cache_config .block_size
85
85
86
86
def get_new_block_nums (self , request : Request , num_new_tokens : int ):
87
- self .check_and_free_block_tables ()
88
87
return (
89
88
request .num_computed_tokens + num_new_tokens + self .config .cache_config .block_size - 1
90
89
) // self .config .cache_config .block_size - len (request .block_tables )
@@ -119,7 +118,7 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
119
118
preempted_req .status = RequestStatus .PREEMPTED
120
119
preempted_req .num_computed_tokens = 0
121
120
self ._free_blocks (preempted_req )
122
- preempted_req .prefill_block_num = None
121
+ preempted_req .cached_block_num = 0
123
122
self .to_be_rescheduled_request_id_set .add (preempted_req .request_id )
124
123
preempted_reqs .append (preempted_req )
125
124
scheduled_reqs .append (self ._prepare_preempt_task (preempted_req ))
@@ -282,14 +281,6 @@ def schedule(self):
282
281
if request .num_computed_tokens >= request .need_prefill_tokens : # to be decoding
283
282
if request .num_total_tokens > request .need_prefill_tokens : # has generated tokens
284
283
request .num_computed_tokens = request .num_total_tokens - 1
285
- else : # prefill finished
286
- if (
287
- self .config .cache_config .enable_prefix_caching
288
- and request .get ("prefill_block_num" , None ) is None
289
- ):
290
- # update prefill cache blocks for prefix caching
291
- request .prefill_block_num = len (request .block_tables )
292
- self .cache_manager .update_cache_blocks (request , self .config .cache_config .block_size )
293
284
if (
294
285
self .allocated_slots (request ) - request .num_total_tokens
295
286
<= self .config .cache_config .prealloc_dec_block_slot_num_threshold
@@ -339,6 +330,10 @@ def schedule(self):
339
330
scheduled_reqs .append (self ._prepare_prefill_task (request , num_new_tokens ))
340
331
token_budget -= num_new_tokens
341
332
request .num_computed_tokens += num_new_tokens
333
+ if self .config .cache_config .enable_prefix_caching :
334
+ self .cache_manager .update_cache_blocks (
335
+ request , self .config .cache_config .block_size , request .num_computed_tokens
336
+ )
342
337
req_index += 1
343
338
# schedule the WAITING requests.
344
339
if not preempted_reqs :
@@ -371,6 +366,10 @@ def schedule(self):
371
366
request .schedule_start_time = time .time ()
372
367
token_budget -= num_new_tokens
373
368
request .num_computed_tokens += num_new_tokens
369
+ if self .config .cache_config .enable_prefix_caching :
370
+ self .cache_manager .update_cache_blocks (
371
+ request , self .config .cache_config .block_size , request .num_computed_tokens
372
+ )
374
373
request .status = RequestStatus .RUNNING
375
374
main_process_metrics .num_requests_waiting .dec (1 )
376
375
main_process_metrics .num_requests_running .inc (1 )
@@ -403,6 +402,10 @@ def schedule(self):
403
402
scheduled_reqs .append (self ._prepare_prefill_task (request , num_new_tokens ))
404
403
token_budget -= num_new_tokens
405
404
request .num_computed_tokens += num_new_tokens
405
+ if self .config .cache_config .enable_prefix_caching :
406
+ self .cache_manager .update_cache_blocks (
407
+ request , self .config .cache_config .block_size , request .num_computed_tokens
408
+ )
406
409
request .status = RequestStatus .RUNNING
407
410
main_process_metrics .num_requests_waiting .dec (1 )
408
411
main_process_metrics .num_requests_running .inc (1 )
@@ -447,7 +450,7 @@ def get_prefix_cached_blocks(self, request: Request):
447
450
448
451
matched_block_num = len (common_block_ids )
449
452
no_cache_block_num = self .cache_manager .get_required_block_num (
450
- request .prompt_token_ids_len - matched_token_num ,
453
+ request .need_prefill_tokens - matched_token_num ,
451
454
self .config .cache_config .block_size ,
452
455
)
453
456
@@ -463,7 +466,7 @@ def get_prefix_cached_blocks(self, request: Request):
463
466
main_process_metrics .prefix_gpu_cache_token_num .inc (request .gpu_cache_token_num )
464
467
main_process_metrics .prefix_cpu_cache_token_num .inc (request .cpu_cache_token_num )
465
468
466
- if matched_token_num == request .prompt_token_ids_len :
469
+ if matched_token_num == request .need_prefill_tokens :
467
470
request .num_computed_tokens = matched_token_num - self .config .cache_config .block_size
468
471
request .skip_allocate = True
469
472
else :
@@ -481,16 +484,8 @@ def add_request(self, request: Request) -> None:
481
484
482
485
def _free_blocks (self , request : Request ):
483
486
if self .config .cache_config .enable_prefix_caching :
484
- # TODO(chengyanfu): support cache ouput blocks for prefix caching
485
- if request .get ("prefill_block_num" , None ) is None :
486
- leaf_node = self .cache_manager .req_leaf_map [request .request_id ]
487
- self .cache_manager .decrease_request_share_count (request .request_id )
488
- self .cache_manager .free_nodes_directly (leaf_node )
489
- self .cache_manager .recycle_gpu_blocks (request .block_tables [request .cache_info [0 ] :])
490
-
491
- else :
492
- self .cache_manager .release_block_ids_async (request )
493
- self .cache_manager .recycle_gpu_blocks (request .block_tables [request .prefill_block_num :])
487
+ self .cache_manager .release_block_ids (request )
488
+ self .cache_manager .recycle_gpu_blocks (request .block_tables [request .cached_block_num :])
494
489
else :
495
490
self .cache_manager .recycle_gpu_blocks (request .block_tables )
496
491
request .block_tables = []
0 commit comments