@@ -69,6 +69,9 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):
         # TODO: fa3 currently has to run in synchronous mode; remove this in the future.
         g_infer_context.get_overlap_stream().synchronize()
 
+        # TODO: use a more efficient allocation strategy.
+        grid_num = 16 if self.need_sync_compute_stream or (not self.args.enable_fa3) else 1
+
         # Copy the contents of the CPU pages into the GPU pages.
         load_cpu_kv_to_gpu(
             gpu_mem_indexes=mem_indexes.cuda(non_blocking=True),
@@ -77,7 +80,7 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):
             page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True),
             tp_index=self.backend.rank_in_dp,
             tp_world_size=self.backend.dp_world_size,
-            grid_num=1 if self.args.enable_fa3 else 16,  # TODO: use a more efficient allocation strategy.
+            grid_num=grid_num,
         )
 
         torch.cuda.current_stream().synchronize()
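
Both call sites previously hardcoded the grid size inline as `1 if self.args.enable_fa3 else 16`; this patch hoists the choice into a `grid_num` variable and additionally forces the larger grid whenever `need_sync_compute_stream` is set. A minimal sketch of the selection logic follows; the helper name and the rationale in the comments are assumptions, not part of the patch, which inlines the conditional at both sites rather than sharing a helper.

# Hypothetical helper illustrating the grid_num choice made by this patch.
def pick_copy_grid_num(need_sync_compute_stream: bool, enable_fa3: bool) -> int:
    if need_sync_compute_stream or not enable_fa3:
        # Assumed rationale: when compute has to wait on the copy (or the
        # fa3 path is disabled), spend 16 thread blocks so the page copy
        # finishes faster.
        return 16
    # Assumed rationale: with fa3 and an overlapped stream, a single block
    # leaves more SMs free for the concurrently running attention kernels.
    return 1

As the TODO in both hunks notes, this two-value heuristic is a stopgap; a more efficient allocation strategy is still planned.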
@@ -202,6 +205,10 @@ def _start_kv_cache_offload_task(
         move_token_num = item_size * self.args.cpu_cache_token_page_size
         assert req.cur_kv_len >= item_size * self.args.cpu_cache_token_page_size
         token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0:move_token_num]
+
+        # TODO: use a more efficient allocation strategy.
+        grid_num = 16 if self.need_sync_compute_stream or (not self.args.enable_fa3) else 1
+
         # assert max(page_list) < self.cpu_cache_client.cpu_kv_cache_tensor.shape[0]
         offload_gpu_kv_to_cpu(
             token_indexes=token_indexes,
@@ -211,7 +218,7 @@ def _start_kv_cache_offload_task(
             page_readies=page_readies,
             tp_index=self.backend.rank_in_dp,
             tp_world_size=self.backend.dp_world_size,
-            grid_num=1 if self.args.enable_fa3 else 16,  # TODO: use a more efficient allocation strategy.
+            grid_num=grid_num,
         )
 
         sync_event = torch.cuda.Event()
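
After launching the offload kernel, the method creates a CUDA event so that completion can later be checked without blocking the whole device. A hedged sketch of that pattern is below; the dedicated offload stream and the polling site are illustrative assumptions, not taken from this file.

import torch

# Illustrative only: record an event on the stream that ran the offload
# kernel, then poll it later instead of synchronizing the whole device.
offload_stream = torch.cuda.Stream()
with torch.cuda.stream(offload_stream):
    # ... offload_gpu_kv_to_cpu(...) would be launched here ...
    sync_event = torch.cuda.Event()
    sync_event.record(offload_stream)

# Elsewhere, the scheduler can check progress without blocking:
if sync_event.query():
    # The copy has finished; the CPU pages can be marked ready.
    pass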