@@ -38,6 +38,60 @@ def wait(self):
        if attach_shm_handle is not None:
            attach_shm_handle.wait()
 
+    def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):
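+        # Load the CPU KV cache pages that matched each request's prefix into
+        # the GPU KV buffer, advancing req.cur_kv_len over the loaded tokens.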
+        idle_token_num = g_infer_context.get_can_alloc_token_num()
+        token_page_size = self.args.cpu_cache_token_page_size
+        all_page_list = []
+        is_master_in_dp = self.backend.is_master_in_dp
+        for req in reqs:
+            page_list = req.shm_req.cpu_cache_match_page_indexes.get_all()
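+            # Matched prefix length in tokens; each matched page presumably
+            # holds exactly token_page_size tokens.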
+            match_tokens = len(page_list) * token_page_size
+            # Update the matched CPU KV cache length.
+            if is_master_in_dp:
+                req.shm_req.cpu_prompt_cache_len = match_tokens
+
+            need_token_num = match_tokens - req.cur_kv_len
+            # Only copy when a sufficient number of extra tokens were matched; otherwise the copy is not efficient.
+            if need_token_num >= 64:
+                if need_token_num <= idle_token_num:
+                    if self.backend.radix_cache is not None:
+                        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num)
+
+                    # Compute the pages that need loading (only the part not already covered by cur_kv_len).
+                    cur_kv_pages = req.cur_kv_len // token_page_size
+                    need_pages = page_list[cur_kv_pages:]  # keep only the pages still needed
+                    actual_need_tokens = len(need_pages) * token_page_size
+
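+                    # Reserve GPU token slots for the tokens about to be copied in.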
+                    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=actual_need_tokens)
+
+                    # Copy the contents of the CPU pages into the GPU pages.
+                    load_cpu_kv_to_gpu(
+                        mem_indexes=mem_indexes,
+                        gpu_kv_cache=self.backend.model.mem_manager.kv_buffer,
+                        cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor,
+                        page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True),
+                    )
+
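+                    # Wait for the async copy to finish before publishing the new
+                    # token slots to the request below.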
+                    torch.cuda.current_stream().synchronize()
+
+                    idle_token_num -= actual_need_tokens
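+                    # Record the new token slots and advance the request's KV length.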
+                    g_infer_context.req_manager.req_to_token_indexs[
+                        req.req_idx, req.cur_kv_len : (req.cur_kv_len + actual_need_tokens)
+                    ] = mem_indexes
+                    req.cur_kv_len = req.cur_kv_len + actual_need_tokens
+                    if self.backend.is_master_in_dp:
+                        req.shm_req.shm_cur_kv_len = req.cur_kv_len
+
+            all_page_list.extend(page_list)
+
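+        # Make sure every rank in the sync group has finished loading before the
+        # CPU page references are dropped.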
+        dist.barrier(group=self.init_sync_group)
+
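+        # Only the dp master dereferences the matched pages, under the shared lock.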
+        if self.backend.is_master_in_dp:
+            self.cpu_cache_client.lock.acquire_sleep1ms()
+            self.cpu_cache_client.deref_pages(page_list=all_page_list)
+            self.cpu_cache_client.lock.release()
+        return
+
    def handle_finished_reqs(self, finished_reqs: List[InferReq]) -> List[InferReq]:
        """
        Process the requests that satisfy the CPU KV cache offload conditions,
        and return the list of requests that actually need to exit.
@@ -181,60 +235,6 @@ def update_cpu_cache_task_states(self):
                task.req_obj.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED
        return
 
-    def fill_cpu_cache_to_reqs(self, reqs: List[InferReq]):
-        idle_token_num = g_infer_context.get_can_alloc_token_num()
-        token_page_size = self.args.cpu_cache_token_page_size
-        all_page_list = []
-        is_master_in_dp = self.backend.is_master_in_dp
-        for req in reqs:
-            page_list = req.shm_req.cpu_cache_match_page_indexes.get_all()
-            match_tokens = len(page_list) * token_page_size
-            # Update the matched CPU KV cache length.
-            if is_master_in_dp:
-                req.shm_req.cpu_prompt_cache_len = match_tokens
-
-            need_token_num = match_tokens - req.cur_kv_len
-            # Only copy when a sufficient number of extra tokens were matched; otherwise the copy is not efficient.
-            if need_token_num >= 64:
-                if need_token_num <= idle_token_num:
-                    if self.backend.radix_cache is not None:
-                        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num)
-
-                    # Compute the pages that need loading (only the part not already covered by cur_kv_len).
-                    cur_kv_pages = req.cur_kv_len // token_page_size
-                    need_pages = page_list[cur_kv_pages:]  # keep only the pages still needed
-                    actual_need_tokens = len(need_pages) * token_page_size
-
-                    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=actual_need_tokens)
-
-                    # Copy the contents of the CPU pages into the GPU pages.
-                    load_cpu_kv_to_gpu(
-                        mem_indexes=mem_indexes,
-                        gpu_kv_cache=self.backend.model.mem_manager.kv_buffer,
-                        cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor,
-                        page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True),
-                    )
-
-                    torch.cuda.current_stream().synchronize()
-
-                    idle_token_num -= actual_need_tokens
-                    g_infer_context.req_manager.req_to_token_indexs[
-                        req.req_idx, req.cur_kv_len : (req.cur_kv_len + actual_need_tokens)
-                    ] = mem_indexes
-                    req.cur_kv_len = req.cur_kv_len + actual_need_tokens
-                    if self.backend.is_master_in_dp:
-                        req.shm_req.shm_cur_kv_len = req.cur_kv_len
-
-            all_page_list.extend(page_list)
-
-        dist.barrier(group=self.init_sync_group)
-
-        if self.backend.is_master_in_dp:
-            self.cpu_cache_client.lock.acquire_sleep1ms()
-            self.cpu_cache_client.deref_pages(page_list=all_page_list)
-            self.cpu_cache_client.lock.release()
-        return
-
 
@dataclasses.dataclass
class TransTask: