
Commit 3363fd3

Author: liujiacheng
Commit message: fix
Parent: 4f9269f

File tree: 2 files changed, +52 additions, -4 deletions


lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 15 additions & 2 deletions

@@ -157,6 +157,7 @@ def init_model(self, kvargs):
 
         self.logger.info(f"loaded model class {self.model.__class__}")
         g_infer_context.register(
+            backend=self,
             req_manager=self.model.req_manager,
             radix_cache=self.radix_cache,
             shm_req_manager=self.shm_req_manager,
@@ -326,7 +327,10 @@ def _read_reqs_buffer_and_init_reqs(self):
                 req: InferReq = g_infer_context.requests_mapping[obj.req_id]
                 req.infer_aborted = True
             else:
-                self._init_reqs(reqs=cmds)
+                req_ids = self._init_reqs(reqs=cmds)
+                if self.args.enable_cpu_cache:
+                    self._fill_cpu_cache_to_reqs(req_ids=req_ids)
+
         return
 
     # Some reusable, general-purpose helper functions
@@ -348,6 +352,13 @@ def _init_reqs(self, reqs: List[Tuple]):
         req_ids = [e[0] for e in reqs]
         return req_ids
 
+    def _fill_cpu_cache_to_reqs(self, req_ids):
+        req_objs: List[InferReq] = [g_infer_context.requests_mapping[req_id] for req_id in req_ids]
+        g_infer_state_lock.acquire()
+        self.multi_level_cache_manager.fill_cpu_cache_to_reqs(reqs=req_objs)
+        g_infer_state_lock.release()
+        return
+
     # Some reusable, general-purpose helper functions
     def _get_classed_reqs(
         self,
@@ -374,6 +385,8 @@ def _get_classed_reqs(
        4. prefill_reqs: requests that need a prefill step
        5. decode_reqs: requests that need a decode step
        """
+        if self.args.enable_cpu_cache:
+            self.multi_level_cache_manager.update_kv_cache_offload_task_states()
 
        if req_ids is None:
            req_ids = g_infer_context.infer_req_ids
@@ -486,7 +499,7 @@ def _cpu_kv_cache_task_handle(self, finished_reqs: List[InferReq]) -> List[Infer
             else:
                 # Offload the request's kv cache into the cpu cache
                 multi_level_cache_manager = self.multi_level_cache_manager
-                trans_task = multi_level_cache_manager.req_to_cpu_cache_task(
+                trans_task = multi_level_cache_manager.start_kv_cache_offload_task(
                     req=req, cpu_kv_cache_stream=g_infer_context.get_cpu_kv_cache_stream()
                 )
                 if trans_task is not None:
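
Taken together, the base_backend.py changes add two CPU-cache hooks to the request lifecycle: newly initialized requests are filled from matched CPU-cache pages, and pending GPU-to-CPU offload tasks have their states refreshed before requests are classified. Below is a minimal sketch of that control flow; the class and method names (SketchBackend, SketchCacheManager, read_reqs_and_init, classify_reqs) are illustrative stand-ins, not the actual lightllm API.

# Minimal sketch of the request lifecycle hooks added in this commit.
# All names here are illustrative stand-ins, not real lightllm symbols.
from typing import List, Tuple


class SketchCacheManager:
    def fill_cpu_cache_to_reqs(self, req_ids: List[int]) -> None:
        # Stand-in for MultiLevelCacheManager.fill_cpu_cache_to_reqs:
        # copy matched CPU-cache pages into the requests' GPU kv cache.
        print(f"fill CPU-cached kv for reqs {req_ids}")

    def update_kv_cache_offload_task_states(self) -> None:
        # Stand-in for the renamed update_kv_cache_offload_task_states:
        # drain finished GPU->CPU offload tasks before classifying requests.
        print("refresh offload task states")


class SketchBackend:
    def __init__(self, enable_cpu_cache: bool):
        self.enable_cpu_cache = enable_cpu_cache
        self.cache_manager = SketchCacheManager()

    def init_reqs(self, cmds: List[Tuple]) -> List[int]:
        # Mirrors _init_reqs: the first tuple element is the request id.
        return [cmd[0] for cmd in cmds]

    def read_reqs_and_init(self, cmds: List[Tuple]) -> None:
        # New behavior: after initializing requests, optionally reuse CPU cache.
        req_ids = self.init_reqs(cmds)
        if self.enable_cpu_cache:
            self.cache_manager.fill_cpu_cache_to_reqs(req_ids)

    def classify_reqs(self) -> None:
        # New behavior: refresh offload task states before grouping requests.
        if self.enable_cpu_cache:
            self.cache_manager.update_kv_cache_offload_task_states()


if __name__ == "__main__":
    backend = SketchBackend(enable_cpu_cache=True)
    backend.read_reqs_and_init([(1, "prompt a"), (2, "prompt b")])
    backend.classify_reqs()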

lightllm/server/router/model_infer/mode_backend/multi_level_cache_manager.py

Lines changed: 37 additions & 2 deletions

@@ -9,6 +9,7 @@
 from ..infer_batch import InferReq
 from lightllm.utils.dist_utils import create_new_group_for_current_dp
 from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu
+from lightllm.server.router.model_infer.infer_batch import g_infer_context
 
 
 class MultiLevelCacheManager(object):
@@ -20,11 +21,14 @@ def __init__(self, backend):
         self.gloo_group = create_new_group_for_current_dp("gloo")
         self.filter_group = create_new_group_for_current_dp("gloo")
         self.sync_group = create_new_group_for_current_dp("nccl")
+        self.init_sync_group = create_new_group_for_current_dp("nccl")
 
         self.cpu_cache_handle_queue = deque()
         self.cpu_cache_client = CpuKvCacheClient(init_shm_data=False)
 
-    def req_to_cpu_cache_task(self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream) -> Optional["TransTask"]:
+    def start_kv_cache_offload_task(
+        self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream
+    ) -> Optional["TransTask"]:
         with torch.cuda.stream(cpu_kv_cache_stream):
             all_token_hash_list = req.shm_req.token_hash_list.get_all()
             block_size = req.cur_kv_len // self.args.cpu_cache_token_chuncked_size
@@ -79,7 +83,7 @@ def req_to_cpu_cache_task(self, req: InferReq, cpu_kv_cache_stream: torch.cuda.S
 
         return trans_task
 
-    def handle_task_queue(self):
+    def update_kv_cache_offload_task_states(self):
         if self.backend.is_master_in_dp:
             trans_ok_reqs = []
             while len(self.cpu_cache_handle_queue) != 0:
@@ -110,6 +114,37 @@ def handle_task_queue(self):
                 req.req_obj.cpu_cache_task_finished = True
         return
 
+    def fill_cpu_cache_to_reqs(self, reqs: List[InferReq]):
+        idle_token_num = g_infer_context.get_can_alloc_token_num()
+        token_chuncked_size = self.args.cpu_cache_token_chuncked_size
+        all_page_list = []
+        for req in reqs:
+            page_list = req.shm_req.cpu_cache_match_page_indexes.get_all()
+            match_tokens = len(page_list) * token_chuncked_size
+            need_token_num = match_tokens - req.cur_kv_len
+            # Only copy once a sufficient number of extra tokens has been matched; otherwise the copy is not worth the overhead.
+            if need_token_num > 256:
+                if need_token_num <= idle_token_num:
+                    if self.backend.radix_cache is not None:
+                        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num)
+
+                    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num)
+                    idle_token_num -= need_token_num
+                    g_infer_context.req_manager.req_to_token_indexs[
+                        req.req_idx, req.cur_kv_len : (req.cur_kv_len + need_token_num)
+                    ] = mem_indexes
+                    req.cur_kv_len = req.cur_kv_len + need_token_num
+                    if self.backend.is_master_in_dp:
+                        req.shm_req.shm_cur_kv_len = req.cur_kv_len
+
+            all_page_list.extend(page_list)
+
+        if self.backend.is_master_in_dp:
+            self.cpu_cache_client.lock.acquire_sleep1ms()
+            self.cpu_cache_client.deref_pages(page_list=all_page_list)
+            self.cpu_cache_client.lock.release()
+        return
+
 
 @dataclasses.dataclass
 class TransTask:
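
The new fill_cpu_cache_to_reqs method only copies CPU-cached kv into a request when the matched pages extend the request's existing kv by more than 256 tokens and the extra tokens fit into the currently free GPU token pool. Below is a minimal, self-contained sketch of that admission arithmetic; the helper name and example numbers are assumptions for illustration, drawn from the logic in the diff rather than from the lightllm API.

# Sketch of the admission check used by fill_cpu_cache_to_reqs: copy CPU-cached
# kv into a request only when the match is large enough to be worth the copy and
# the GPU token pool can hold the extra tokens. Names here are illustrative.

def tokens_to_copy(matched_pages: int,
                   chunk_size: int,
                   cur_kv_len: int,
                   idle_token_num: int,
                   min_extra_tokens: int = 256) -> int:
    """Return how many kv tokens should be copied from CPU cache (0 = skip)."""
    match_tokens = matched_pages * chunk_size       # tokens covered by matched CPU pages
    need_token_num = match_tokens - cur_kv_len      # tokens beyond what the req already holds
    if need_token_num > min_extra_tokens and need_token_num <= idle_token_num:
        return need_token_num
    return 0


if __name__ == "__main__":
    # 10 matched pages of 64 tokens = 640 tokens; the request already holds 128,
    # so 512 extra tokens would be copied if the free pool can take them.
    print(tokens_to_copy(matched_pages=10, chunk_size=64, cur_kv_len=128, idle_token_num=4096))  # -> 512
    # Too small a gain: only 64 extra tokens, below the 256-token threshold, so skip.
    print(tokens_to_copy(matched_pages=3, chunk_size=64, cur_kv_len=128, idle_token_num=4096))   # -> 0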
