Commit ac80a73

Author: niushengxiao
feat: use lighter synchronize
1 parent 18064c6 commit ac80a73

File tree

1 file changed: +8 −4 lines


lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py

Lines changed: 8 additions & 4 deletions
@@ -64,6 +64,7 @@ def handle_finished_reqs(self, finished_reqs: List[InferReq]) -> List[InferReq]:
         # If the CPU cache is enabled, requests that reach the finished state start offloading their GPU KV cache to the CPU cache.
         # A request is truly released only after its KV cache offload completes.
         true_finished_reqs = []
+        cpu_stream = g_infer_context.get_cpu_kv_cache_stream()
         for req in finished_reqs:
             # Only requests whose group_req_id equals their request_id are offloaded to the CPU cache.
             # This restriction keeps request handling compatible with diverse mode.
@@ -87,11 +88,13 @@ def handle_finished_reqs(self, finished_reqs: List[InferReq]) -> List[InferReq]:
                 # Launch the task that offloads this request's KV cache to the CPU cache.
                 # if self.backend.is_master_in_dp:
                 #     mark_start("blueswhen offload_kv_to_cpu")
-                torch.cuda.synchronize()
+                if g_infer_context.overlap_stream is not None:
+                    cpu_stream.wait_stream(g_infer_context.overlap_stream)
+                else:
+                    cpu_stream.wait_stream(torch.cuda.current_stream())
                 trans_task = self._start_kv_cache_offload_task(
-                    req=req, cpu_kv_cache_stream=g_infer_context.get_cpu_kv_cache_stream()
+                    req=req, cpu_kv_cache_stream=cpu_stream
                 )
-                torch.cuda.synchronize()
                 # if self.backend.is_master_in_dp:
                 #     mark_end("blueswhen offload_kv_to_cpu")

@@ -100,6 +103,7 @@ def handle_finished_reqs(self, finished_reqs: List[InferReq]) -> List[InferReq]:
             else:
                 true_finished_reqs.append(req)

+        cpu_stream.synchronize()
         return true_finished_reqs
     else:
         return finished_reqs
@@ -217,7 +221,7 @@ def fill_cpu_cache_to_reqs(self, reqs: List[InferReq]):

            need_token_num = match_tokens - req.cur_kv_len
            # Only copy when enough extra tokens were matched; otherwise the copy is not worth its overhead.
-           if need_token_num > 128:
+           if need_token_num >= 64:
                if need_token_num <= idle_token_num:
                    if self.backend.radix_cache is not None:
                        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num)
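In short, the commit replaces two device-wide torch.cuda.synchronize() calls per finished request with device-side stream ordering via wait_stream(), plus a single per-stream synchronize() after the loop; the last hunk independently lowers the reuse threshold so CPU-cached prefixes are copied back once at least 64 extra tokens match instead of more than 128. The sketch below illustrates the synchronization pattern in isolation. It is a minimal reconstruction, not lightllm code: the names compute_stream, offload_stream, gpu_kv, and cpu_kv are hypothetical stand-ins.

import torch

# torch.cuda.synchronize() blocks the host until every stream on the device
# is idle. The lighter pattern below only orders the offload stream behind
# the producing stream (a device-side dependency) and blocks the host once,
# on the offload stream alone, after all copies are enqueued.

assert torch.cuda.is_available()

compute_stream = torch.cuda.current_stream()
offload_stream = torch.cuda.Stream()  # stands in for the cpu kv cache stream

gpu_kv = torch.randn(4, 1024, device="cuda")  # hypothetical per-request KV block
cpu_kv = torch.empty(gpu_kv.shape, dtype=gpu_kv.dtype, pin_memory=True)

# ... kernels that produce gpu_kv would be enqueued on compute_stream here ...

# Device-side ordering: offload_stream waits for work already queued on
# compute_stream; the host thread is not blocked by this call.
offload_stream.wait_stream(compute_stream)

with torch.cuda.stream(offload_stream):
    # Async device-to-host copy into pinned memory, running on offload_stream.
    cpu_kv.copy_(gpu_kv, non_blocking=True)

# One host-side wait, on the offload stream only, e.g. once per batch of
# finished requests rather than twice per request.
offload_stream.synchronize()

The else branch in the diff (waiting on torch.cuda.current_stream()) covers the case where no separate overlap stream exists, and the single cpu_stream.synchronize() before returning guarantees that every offload launched in the loop has completed before the requests are truly released.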
