Commit a829095

Author: wangzaijun
Commit message: fix names
1 parent 7ae1d0c commit a829095

File tree

2 files changed (+63 additions, -57 deletions)

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 9 additions & 3 deletions
@@ -369,7 +369,7 @@ def _read_reqs_buffer_and_init_reqs(self):
         if init_reqs:
             req_ids = self._init_reqs(reqs=init_reqs)
             if self.args.enable_cpu_cache and req_ids:
-                self._fill_cpu_cache_to_reqs(req_ids=req_ids)
+                self._load_cpu_cache_to_reqs(req_ids=req_ids)
         return
 
     def _read_nixl_trans_io_buffer_and_update_req_status(self):
@@ -424,10 +424,10 @@ def _init_reqs(self, reqs: List[Tuple]):
         req_ids = [e[0] for e in reqs]
         return req_ids
 
-    def _fill_cpu_cache_to_reqs(self, req_ids):
+    def _load_cpu_cache_to_reqs(self, req_ids):
         req_objs: List[InferReq] = [g_infer_context.requests_mapping[req_id] for req_id in req_ids]
         g_infer_state_lock.acquire()
-        self.multi_level_cache_module.fill_cpu_cache_to_reqs(reqs=req_objs)
+        self.multi_level_cache_module.load_cpu_cache_to_reqs(reqs=req_objs)
         g_infer_state_lock.release()
         return
 
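
Note: the renamed wrapper holds g_infer_state_lock across the whole cache load. A minimal sketch of an equivalent try/finally form (not part of this commit; shown only to make the locking contract explicit, reusing the same names as above):

    def _load_cpu_cache_to_reqs(self, req_ids):
        req_objs = [g_infer_context.requests_mapping[req_id] for req_id in req_ids]
        g_infer_state_lock.acquire()
        try:
            # Guarded region: mutates shared token allocation and per-request kv state.
            self.multi_level_cache_module.load_cpu_cache_to_reqs(reqs=req_objs)
        finally:
            # Released even if the load raises, so other workers are not blocked forever.
            g_infer_state_lock.release()
        return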

@@ -536,6 +536,12 @@ def _get_classed_reqs(
                 req_obj.wait_pause = True
                 wait_pause_count += 1
             else:
+                # In diverse mode, prefill only uses master-state requests; later inference code
+                # copies the master's state into its slave requests, so slave-state requests are
+                # not put into the prefill reqs queue. Other modes only have master-state requests.
+                if req_obj.is_slave_req():
+                    continue
+
                 token_num = req_obj.prefill_need_token_num(is_chuncked_prefill=not self.disable_chunked_prefill)
                 if prefill_tokens + token_num > self.batch_max_tokens:
                     continue
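
For context, req_obj.is_slave_req() is not defined in this diff. A minimal standalone sketch of the master/slave filtering pattern the comment describes (MockReq and its fields are hypothetical, used only to illustrate the control flow; lightllm's real InferReq differs):

    from dataclasses import dataclass

    @dataclass
    class MockReq:
        group_id: int    # hypothetical: samples from the same prompt share a group
        sample_idx: int  # hypothetical: 0 = master, > 0 = slave

        def is_slave_req(self) -> bool:
            return self.sample_idx > 0

    reqs = [MockReq(0, 0), MockReq(0, 1), MockReq(0, 2), MockReq(1, 0)]
    # Only masters enter the prefill queue; slaves later receive a copy of
    # their master's prefill state from the inference code.
    prefill_reqs = [r for r in reqs if not r.is_slave_req()]
    assert [r.group_id for r in prefill_reqs] == [0, 1]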

lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py

Lines changed: 54 additions & 54 deletions
@@ -38,6 +38,60 @@ def wait(self):
         if attach_shm_handle is not None:
             attach_shm_handle.wait()
 
+    def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):
+        idle_token_num = g_infer_context.get_can_alloc_token_num()
+        token_page_size = self.args.cpu_cache_token_page_size
+        all_page_list = []
+        is_master_in_dp = self.backend.is_master_in_dp
+        for req in reqs:
+            page_list = req.shm_req.cpu_cache_match_page_indexes.get_all()
+            match_tokens = len(page_list) * token_page_size
+            # Update the matched cpu kv cache length.
+            if is_master_in_dp:
+                req.shm_req.cpu_prompt_cache_len = match_tokens
+
+            need_token_num = match_tokens - req.cur_kv_len
+            # Only copy when enough extra tokens were matched; otherwise the copy is not efficient.
+            if need_token_num >= 64:
+                if need_token_num <= idle_token_num:
+                    if self.backend.radix_cache is not None:
+                        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num)
+
+                    # Compute the pages to load (only the part not already on the GPU).
+                    cur_kv_pages = req.cur_kv_len // token_page_size
+                    need_pages = page_list[cur_kv_pages:]  # take only the pages that are needed
+                    actual_need_tokens = len(need_pages) * token_page_size
+
+                    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=actual_need_tokens)
+
+                    # Copy the contents of the cpu pages into the gpu pages.
+                    load_cpu_kv_to_gpu(
+                        mem_indexes=mem_indexes,
+                        gpu_kv_cache=self.backend.model.mem_manager.kv_buffer,
+                        cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor,
+                        page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True),
+                    )
+
+                    torch.cuda.current_stream().synchronize()
+
+                    idle_token_num -= actual_need_tokens
+                    g_infer_context.req_manager.req_to_token_indexs[
+                        req.req_idx, req.cur_kv_len : (req.cur_kv_len + actual_need_tokens)
+                    ] = mem_indexes
+                    req.cur_kv_len = req.cur_kv_len + actual_need_tokens
+                    if self.backend.is_master_in_dp:
+                        req.shm_req.shm_cur_kv_len = req.cur_kv_len
+
+            all_page_list.extend(page_list)
+
+        dist.barrier(group=self.init_sync_group)
+
+        if self.backend.is_master_in_dp:
+            self.cpu_cache_client.lock.acquire_sleep1ms()
+            self.cpu_cache_client.deref_pages(page_list=all_page_list)
+            self.cpu_cache_client.lock.release()
+        return
+
     def handle_finished_reqs(self, finished_reqs: List[InferReq]) -> List[InferReq]:
         """
         Process requests that satisfy the cpu kv cache offload conditions, and return the list of requests that really need to exit.
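
The page arithmetic in load_cpu_cache_to_reqs is easiest to see with concrete numbers. A small self-contained check (the values are illustrative; a page size of 64 tokens and a page-aligned cur_kv_len are assumptions, not fixed by this commit):

    token_page_size = 64                                     # assumed cpu_cache_token_page_size
    page_list = [7, 12, 3, 9]                                # 4 matched cpu cache page indexes
    cur_kv_len = 128                                         # 2 pages of kv already on the GPU

    match_tokens = len(page_list) * token_page_size          # 256
    need_token_num = match_tokens - cur_kv_len               # 128 >= 64, so the copy is worthwhile

    cur_kv_pages = cur_kv_len // token_page_size             # 2 pages already covered
    need_pages = page_list[cur_kv_pages:]                    # [3, 9]: only the uncovered tail
    actual_need_tokens = len(need_pages) * token_page_size   # 128 gpu token slots to allocate

    # With a page-aligned cur_kv_len, the loaded tail lines up exactly with the match.
    assert cur_kv_len + actual_need_tokens == match_tokens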
@@ -181,60 +235,6 @@ def update_cpu_cache_task_states(self):
             task.req_obj.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED
         return
 
-    def fill_cpu_cache_to_reqs(self, reqs: List[InferReq]):
-        idle_token_num = g_infer_context.get_can_alloc_token_num()
-        token_page_size = self.args.cpu_cache_token_page_size
-        all_page_list = []
-        is_master_in_dp = self.backend.is_master_in_dp
-        for req in reqs:
-            page_list = req.shm_req.cpu_cache_match_page_indexes.get_all()
-            match_tokens = len(page_list) * token_page_size
-            # Update the matched cpu kv cache length.
-            if is_master_in_dp:
-                req.shm_req.cpu_prompt_cache_len = match_tokens
-
-            need_token_num = match_tokens - req.cur_kv_len
-            # Only copy when enough extra tokens were matched; otherwise the copy is not efficient.
-            if need_token_num >= 64:
-                if need_token_num <= idle_token_num:
-                    if self.backend.radix_cache is not None:
-                        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num)
-
-                    # Compute the pages to load (only the part not already on the GPU).
-                    cur_kv_pages = req.cur_kv_len // token_page_size
-                    need_pages = page_list[cur_kv_pages:]  # take only the pages that are needed
-                    actual_need_tokens = len(need_pages) * token_page_size
-
-                    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=actual_need_tokens)
-
-                    # Copy the contents of the cpu pages into the gpu pages.
-                    load_cpu_kv_to_gpu(
-                        mem_indexes=mem_indexes,
-                        gpu_kv_cache=self.backend.model.mem_manager.kv_buffer,
-                        cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor,
-                        page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True),
-                    )
-
-                    torch.cuda.current_stream().synchronize()
-
-                    idle_token_num -= actual_need_tokens
-                    g_infer_context.req_manager.req_to_token_indexs[
-                        req.req_idx, req.cur_kv_len : (req.cur_kv_len + actual_need_tokens)
-                    ] = mem_indexes
-                    req.cur_kv_len = req.cur_kv_len + actual_need_tokens
-                    if self.backend.is_master_in_dp:
-                        req.shm_req.shm_cur_kv_len = req.cur_kv_len
-
-            all_page_list.extend(page_list)
-
-        dist.barrier(group=self.init_sync_group)
-
-        if self.backend.is_master_in_dp:
-            self.cpu_cache_client.lock.acquire_sleep1ms()
-            self.cpu_cache_client.deref_pages(page_list=all_page_list)
-            self.cpu_cache_client.lock.release()
-        return
-
 
 @dataclasses.dataclass
 class TransTask:
