
Commit 8d77c1c

[Optimize] optimize prefix cache in release22 (#3889)

* optimize prefix cache in release22
* optimize prefix cache in release22
* fix worker
* fix
* fix

Co-authored-by: Jiang-Jia-Jun <[email protected]>

1 parent 41cd3e2 commit 8d77c1c

File tree

4 files changed: +44 -44 lines

* fastdeploy/cache_manager/prefix_cache_manager.py
* fastdeploy/engine/common_engine.py
* fastdeploy/engine/sched/resource_manager_v1.py
* fastdeploy/worker/gpu_model_runner.py

fastdeploy/cache_manager/prefix_cache_manager.py

Lines changed: 24 additions & 19 deletions
```diff
@@ -257,7 +257,12 @@ def can_allocate_gpu_blocks(self, num_blocks: int):
         Check if num_blocks gpu blocks can be allocated.
         """
         if len(self.gpu_free_block_list) < num_blocks:
-            return False
+            if self.cache_config.enable_prefix_caching:
+                self.free_block_ids(num_blocks)
+            if len(self.gpu_free_block_list) < num_blocks:
+                return False
+            else:
+                return True
         else:
             return True
```
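An allocation check can now reclaim cached blocks before giving up: when the free list is short and prefix caching is on, `free_block_ids` (added further down in this file) evicts unreferenced cache blocks, then the free list is checked once more. A self-contained sketch of the same free-then-recheck pattern, with illustrative names only, not the FastDeploy implementation:

```python
# Minimal free-then-recheck sketch, assuming a plain free list and an
# eviction callback standing in for radix-tree eviction.
def can_allocate(free_list, num_blocks, try_reclaim=None):
    if len(free_list) < num_blocks:
        if try_reclaim is not None:
            try_reclaim(num_blocks - len(free_list))  # evict cached blocks
        return len(free_list) >= num_blocks           # re-check after reclaiming
    return True

free = [1, 2]
cache = [3, 4, 5]
# Reclaiming moves evictable cached blocks back onto the free list.
print(can_allocate(free, 4, lambda n: free.extend(cache[:n])))  # True
```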

```diff
@@ -448,7 +453,7 @@ def get_required_block_num(self, input_token_num, block_size):
         """
         return (input_token_num + block_size - 1) // block_size
 
-    def update_cache_blocks(self, task, block_size):
+    def update_cache_blocks(self, task, block_size, num_computed_tokens):
         """
         update cache blocks for a task.
         # TODO(chengyanfu): support async update
```
```diff
@@ -459,12 +464,15 @@ def update_cache_blocks(self, task, block_size):
         """
         try:
             req_id = task.request_id
-            num_cached_tokens = task.num_cached_tokens
             block_tables = task.block_tables
 
-            last_node, input_ids = self.cache_info[req_id]
-            left_input_ids = input_ids[num_cached_tokens:]
+            last_node, num_cached_tokens = self.cache_info[req_id]
+            input_ids = task.prompt_token_ids + task.output_token_ids
+            can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
+            left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
             gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
+            if req_id in self.leaf_req_map[last_node]:  # delete old leaf record, update later
+                self.leaf_req_map[last_node].remove(req_id)
 
             with self.request_release_lock:
                 current_time = time.time()
```
```diff
@@ -480,7 +488,8 @@ def update_cache_blocks(self, task, block_size):
                 )
                 self.req_leaf_map[req_id] = leaf_node
                 self.leaf_req_map[leaf_node].add(req_id)
-                self.cache_info[req_id] = (leaf_node, input_ids)
+                self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
+                task.cached_block_num = can_cache_computed_tokens // block_size
         except Exception as e:
             logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
             raise e
```
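Only whole blocks enter the radix tree, so the computed-token count is rounded down to a block boundary before caching. A quick check of that arithmetic with illustrative numbers:

```python
block_size = 64
num_computed_tokens = 200                  # tokens processed so far (illustrative)
# Round down to a block boundary: the partial last block is not cacheable.
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
assert can_cache_computed_tokens == 192
assert can_cache_computed_tokens // block_size == 3  # becomes task.cached_block_num
```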
```diff
@@ -508,7 +517,7 @@ def request_match_blocks(self, task, block_size, *args):
             hit_info["gpu_cache_blocks"] = 0
             hit_info["cpu_cache_blocks"] = 0
             self.metrics.req_count += 1
-            input_ids = task.prompt_token_ids
+            input_ids = task.prompt_token_ids + task.output_token_ids
             req_id = task.request_id
             logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
             input_token_num = len(input_ids)
```
```diff
@@ -546,9 +555,6 @@ def request_match_blocks(self, task, block_size, *args):
                         "request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
                     )
 
-            # record request cache info
-            self.cache_info[req_id] = (match_block_node, input_ids)
-
             # 3. update metrics
             matched_token_num = gpu_match_token_num + cpu_match_token_num
             common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
```
```diff
@@ -571,6 +577,9 @@ def request_match_blocks(self, task, block_size, *args):
             # set leaf node temporarily, then update it in update_cache_blocks
             self.req_leaf_map[req_id] = match_block_node
             self.leaf_req_map[match_block_node].add(req_id)
+            # record request cache info
+            self.cache_info[req_id] = (match_block_node, matched_token_num)
+            task.cached_block_num = matched_token_num // block_size
             return common_block_ids, matched_token_num, hit_info
         except Exception as e:
             logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
```
```diff
@@ -687,6 +696,11 @@ def release_block_ids_async(self, task):
         """
         return self.executor_pool.submit(self.release_block_ids, task)
 
+    def free_block_ids(self, need_block_num):
+        self.free_block_ids_async(need_block_num)
+        while (self.gpu_free_task_future is not None) and (not self.gpu_free_task_future.done()):
+            time.sleep(0.001)
+
     def release_block_ids(self, task):
         """
         release block ids
```
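The new `free_block_ids` helper turns the existing asynchronous free into a blocking call by polling its future. A self-contained sketch of the same wait-on-future pattern using `concurrent.futures` directly (names here are illustrative):

```python
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def free_some_blocks(n):
    time.sleep(0.01)      # stand-in for walking the cache tree and recycling blocks
    return n

future = executor.submit(free_some_blocks, 8)
while future is not None and not future.done():
    time.sleep(0.001)     # same 1 ms poll as the diff; future.result() would also block
print(future.result())    # 8
executor.shutdown()
```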
```diff
@@ -1108,15 +1122,6 @@ def _update_matched_node_info(self, req_id, last_node, current_time):
             node.req_id_set.add(req_id)
             node = node.parent
 
-    def decrease_request_share_count(self, req_id):
-        """
-        Decrease node shared count
-        """
-        node, input_ids = self.cache_info[req_id]
-        while node != self.radix_tree_root:
-            node.decrement_shared_count()
-            node = node.parent
-
     def build_path(
         self,
         req_id,
```

fastdeploy/engine/common_engine.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -527,9 +527,8 @@ def _fetch_request():
                 self.cfg.max_prefill_batch,
             )
 
-            self.resource_manager.check_and_free_block_tables()
             tasks = self.scheduler.get_requests(
-                available_blocks=self.resource_manager.available_block_num(),
+                available_blocks=self.cfg.cache_config.max_block_num_per_seq,
                 block_size=self.cfg.cache_config.block_size,
                 reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                 max_num_batched_tokens=self.cfg.max_model_len,
```
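With `can_allocate_gpu_blocks` now able to reclaim cached blocks on demand, request fetching no longer sweeps block tables or consults the instantaneous free-block count; the scheduler instead receives the static per-sequence bound `max_block_num_per_seq`.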

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 17 additions & 22 deletions
```diff
@@ -84,7 +84,6 @@ def allocated_slots(self, request: Request):
         return len(request.block_tables) * self.config.cache_config.block_size
 
     def get_new_block_nums(self, request: Request, num_new_tokens: int):
-        self.check_and_free_block_tables()
         return (
             request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
         ) // self.config.cache_config.block_size - len(request.block_tables)
```
```diff
@@ -119,7 +118,7 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
                 preempted_req.status = RequestStatus.PREEMPTED
                 preempted_req.num_computed_tokens = 0
                 self._free_blocks(preempted_req)
-                preempted_req.prefill_block_num = None
+                preempted_req.cached_block_num = 0
                 self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                 preempted_reqs.append(preempted_req)
                 scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
```
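Preemption bookkeeping follows the rename: a preempted request resets `cached_block_num` to 0 rather than clearing the old `prefill_block_num`, so after `_free_blocks` it is rescheduled as if nothing were cached.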
```diff
@@ -282,14 +281,6 @@ def schedule(self):
                 if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
                     if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                         request.num_computed_tokens = request.num_total_tokens - 1
-                    else:  # prefill finished
-                        if (
-                            self.config.cache_config.enable_prefix_caching
-                            and request.get("prefill_block_num", None) is None
-                        ):
-                            # update prefill cache blocks for prefix caching
-                            request.prefill_block_num = len(request.block_tables)
-                            self.cache_manager.update_cache_blocks(request, self.config.cache_config.block_size)
                     if (
                         self.allocated_slots(request) - request.num_total_tokens
                         <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
```
```diff
@@ -339,6 +330,10 @@ def schedule(self):
                     scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                     token_budget -= num_new_tokens
                     request.num_computed_tokens += num_new_tokens
+                    if self.config.cache_config.enable_prefix_caching:
+                        self.cache_manager.update_cache_blocks(
+                            request, self.config.cache_config.block_size, request.num_computed_tokens
+                        )
                 req_index += 1
         # schedule the WAITING requests.
         if not preempted_reqs:
```
```diff
@@ -371,6 +366,10 @@ def schedule(self):
                         request.schedule_start_time = time.time()
                         token_budget -= num_new_tokens
                         request.num_computed_tokens += num_new_tokens
+                        if self.config.cache_config.enable_prefix_caching:
+                            self.cache_manager.update_cache_blocks(
+                                request, self.config.cache_config.block_size, request.num_computed_tokens
+                            )
                         request.status = RequestStatus.RUNNING
                         main_process_metrics.num_requests_waiting.dec(1)
                         main_process_metrics.num_requests_running.inc(1)
```
```diff
@@ -403,6 +402,10 @@ def schedule(self):
                         scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                         token_budget -= num_new_tokens
                         request.num_computed_tokens += num_new_tokens
+                        if self.config.cache_config.enable_prefix_caching:
+                            self.cache_manager.update_cache_blocks(
+                                request, self.config.cache_config.block_size, request.num_computed_tokens
+                            )
                         request.status = RequestStatus.RUNNING
                         main_process_metrics.num_requests_waiting.dec(1)
                         main_process_metrics.num_requests_running.inc(1)
```
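Each call passes the running `num_computed_tokens`, so the tree grows as chunked prefill progresses rather than once at the end. An illustrative trace, assuming a hypothetical 300-token prompt processed in 128-token chunks with a block size of 64:

```python
block_size = 64
num_computed_tokens = 0
for chunk in (128, 128, 44):                      # chunked prefill steps
    num_computed_tokens += chunk
    cacheable = num_computed_tokens - num_computed_tokens % block_size
    print(num_computed_tokens, cacheable, cacheable // block_size)
# 128 128 2 -> first two blocks enter the radix tree
# 256 256 4 -> two more blocks
# 300 256 4 -> the partial last block stays uncached until more tokens land
```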
```diff
@@ -447,7 +450,7 @@ def get_prefix_cached_blocks(self, request: Request):
 
             matched_block_num = len(common_block_ids)
             no_cache_block_num = self.cache_manager.get_required_block_num(
-                request.prompt_token_ids_len - matched_token_num,
+                request.need_prefill_tokens - matched_token_num,
                 self.config.cache_config.block_size,
             )
```

```diff
@@ -463,7 +466,7 @@ def get_prefix_cached_blocks(self, request: Request):
             main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
             main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
 
-            if matched_token_num == request.prompt_token_ids_len:
+            if matched_token_num == request.need_prefill_tokens:
                 request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
                 request.skip_allocate = True
             else:
```
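Matching is now measured against `need_prefill_tokens`, which also covers output tokens carried into a re-prefill after preemption, rather than the prompt length alone. On a full match, `num_computed_tokens` is walked back by one block, presumably so the request still recomputes at least one token and produces logits for the next step.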
```diff
@@ -481,16 +484,8 @@ def add_request(self, request: Request) -> None:
 
     def _free_blocks(self, request: Request):
         if self.config.cache_config.enable_prefix_caching:
-            # TODO(chengyanfu): support cache ouput blocks for prefix caching
-            if request.get("prefill_block_num", None) is None:
-                leaf_node = self.cache_manager.req_leaf_map[request.request_id]
-                self.cache_manager.decrease_request_share_count(request.request_id)
-                self.cache_manager.free_nodes_directly(leaf_node)
-                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cache_info[0] :])
-
-            else:
-                self.cache_manager.release_block_ids_async(request)
-                self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
+            self.cache_manager.release_block_ids(request)
+            self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cached_block_num :])
         else:
             self.cache_manager.recycle_gpu_blocks(request.block_tables)
         request.block_tables = []
```
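The release path collapses to one branch: blocks recorded in the tree are released through reference counting (`release_block_ids`, now synchronous), and only the uncached tail is recycled directly. The slice with illustrative values:

```python
# Hypothetical block table: 6 blocks allocated, 4 already in the prefix tree.
block_tables = [10, 11, 12, 13, 14, 15]
cached_block_num = 4
recycle_directly = block_tables[cached_block_num:]  # [14, 15] -> straight to the free list
```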

fastdeploy/worker/gpu_model_runner.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1347,6 +1347,7 @@ def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
         if (
             not self.cache_config.enable_chunked_prefill
             or self.guided_backend is None
+            or model_forward_batch is None
             or envs.ENABLE_V1_KVCACHE_SCHEDULER
         ):
             return skip_idx_list
```
```diff
@@ -1549,7 +1550,7 @@ def _add_cache(self, model_forward_batch) -> None:
         """
         Add cache for guided decoding.
         """
-        if self.guided_backend is None:
+        if self.guided_backend is None or model_forward_batch is None:
             return
 
         for request in model_forward_batch:
```
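Both worker-side fixes are plain `None` guards: `model_forward_batch` is declared `Optional` in `_get_skip_idx`, and `_add_cache` iterates the batch, so both now return early instead of raising when no batch is passed.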
