
Commit 23ee96e

Author: wangzaijun
Commit message: fix
1 parent a99315c commit 23ee96e


1 file changed: +9 -2 lines changed


lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py

Lines changed: 9 additions & 2 deletions
@@ -69,6 +69,9 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):
         # TODO fa3 currently has to use synchronous mode; this should be removed in the future.
         g_infer_context.get_overlap_stream().synchronize()
 
+        # TODO a more effective allocation strategy.
+        grid_num = 16 if self.need_sync_compute_stream or (not self.args.enable_fa3) else 1
+
         # Copy the contents of the cpu pages into the gpu pages.
         load_cpu_kv_to_gpu(
             gpu_mem_indexes=mem_indexes.cuda(non_blocking=True),
@@ -77,7 +80,7 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):
             page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True),
             tp_index=self.backend.rank_in_dp,
             tp_world_size=self.backend.dp_world_size,
-            grid_num=1 if self.args.enable_fa3 else 16,  # TODO a more effective allocation strategy.
+            grid_num=grid_num,
         )
 
         torch.cuda.current_stream().synchronize()
@@ -202,6 +205,10 @@ def _start_kv_cache_offload_task(
         move_token_num = item_size * self.args.cpu_cache_token_page_size
         assert req.cur_kv_len >= item_size * self.args.cpu_cache_token_page_size
         token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0:move_token_num]
+
+        # TODO a more effective allocation strategy.
+        grid_num = 16 if self.need_sync_compute_stream or (not self.args.enable_fa3) else 1
+
         # assert max(page_list) < self.cpu_cache_client.cpu_kv_cache_tensor.shape[0]
         offload_gpu_kv_to_cpu(
             token_indexes=token_indexes,
@@ -211,7 +218,7 @@ def _start_kv_cache_offload_task(
             page_readies=page_readies,
             tp_index=self.backend.rank_in_dp,
             tp_world_size=self.backend.dp_world_size,
-            grid_num=1 if self.args.enable_fa3 else 16,  # TODO a more effective allocation strategy.
+            grid_num=grid_num,
         )
 
         sync_event = torch.cuda.Event()
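
In short, the fix replaces the per-call-site choice `1 if self.args.enable_fa3 else 16` with a single grid_num computed up front, so the 16-block grid is also used whenever the backend must synchronize with the compute stream (need_sync_compute_stream), not only when fa3 is disabled. A minimal standalone sketch of that selection logic, assuming only the two flags visible in the diff (the pick_grid_num helper itself is illustrative, not part of lightllm):

def pick_grid_num(need_sync_compute_stream: bool, enable_fa3: bool) -> int:
    # Same rule the commit introduces: 16 thread blocks when the copy must
    # run synchronously with the compute stream or when fa3 is disabled,
    # otherwise a single block.
    return 16 if need_sync_compute_stream or (not enable_fa3) else 1


if __name__ == "__main__":
    # Only the fa3-enabled, no-forced-sync case keeps the small grid.
    assert pick_grid_num(need_sync_compute_stream=False, enable_fa3=True) == 1
    assert pick_grid_num(need_sync_compute_stream=True, enable_fa3=True) == 16
    assert pick_grid_num(need_sync_compute_stream=False, enable_fa3=False) == 16

The same grid_num value is then passed to both load_cpu_kv_to_gpu and offload_gpu_kv_to_cpu in place of the old inline expression.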
