Skip to content

Commit 92bd068

Browse files
committed
fix
1 parent 5f995ae commit 92bd068

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,10 @@ def init_all(self):
283283

284284
if self.paused or not self.initialized:
285285
# 如果是具有 prompt_cache 的使用特性则需要进行提前的填充和恢复操作。
286-
if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 2:
286+
if g_infer_context.radix_cache is not None and self.get_cur_total_len() > 1:
287287
input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]
288288
key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")
289-
key = key[0 : len(key) - 2] # 最后两个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值
289+
key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值
290290
share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True)
291291
if share_node is not None:
292292
self.shared_kv_node = share_node

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,10 @@ def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False):
335335
ok_finished_reqs.append(req_obj)
336336
continue
337337

338-
is_decode = req_obj.cur_kv_len + 1 == req_obj.get_cur_total_len()
338+
is_decode = (
339+
req_obj.cur_kv_len + 1 == req_obj.get_cur_total_len()
340+
and req_obj.cur_kv_len + 1 != req_obj.shm_req.input_len
341+
)
339342

340343
if not is_decode:
341344
prefill_reqs.append(req_obj)

0 commit comments

Comments
 (0)