Skip to content

Commit d6fd1a2

Browse files
committed
fix
1 parent 3de061d commit d6fd1a2

File tree

4 files changed

+5
-6
lines changed

4 files changed

+5
-6
lines changed

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,14 +395,15 @@ def diverse_copy(self, req_manager, is_prefill):
395395
prefix_len = prev_req.shared_kv_node.node_prefix_total_len
396396
else:
397397
prefix_len = 0
398-
pre_input_token_ids = prev_req.get_input_token_ids()
398+
prefix_len = max(prefix_len, prev_req.cur_kv_len)
399+
pre_input_token_ids = prev_req.get_chuncked_input_token_ids()
399400
cache_token_id = req_manager.req_to_token_indexs[prev_req.req_idx][prefix_len : len(pre_input_token_ids)]
400401
# update the InferReq status and mem_manager status for cache sharing
401402
for req_id in self.req_ids_group[:]:
402403
if req_id == convert_sub_id_to_group_id(req_id):
403404
continue
404405
req = g_infer_context.requests_mapping[req_id]
405406
req.finish_status.set_status(FinishStatus.NO_FINISH)
406-
input_token_ids = req.get_input_token_ids()
407+
input_token_ids = req.get_chuncked_input_token_ids()
407408
req_manager.req_to_token_indexs[req.req_idx][prefix_len : len(input_token_ids)] = cache_token_id
408409
assert len(input_token_ids) == len(pre_input_token_ids)

lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ def diverse_copy(self, groups: List[InferReqGroup]):
3838
for i in range(len(groups)):
3939
req_group = groups[i]
4040
best_of = req_group.best_of()
41-
_0_req_obj = req_group.get_req(0)
42-
if best_of > 1 and _0_req_obj.get_chuncked_input_token_len() == _0_req_obj.get_cur_total_len():
41+
if best_of > 1:
4342
req_group.diverse_copy(g_infer_context.req_manager, is_prefill=True)
4443
batch_idx.extend([i for _ in range(best_of)])
4544
else:
@@ -58,7 +57,6 @@ def decode(self):
5857

5958
if aborted_reqs:
6059
g_infer_context.filter_reqs(aborted_reqs)
61-
6260
if prefill_reqs:
6361
group_reqs = [
6462
g_infer_context.requests_mapping[req.req_id]

lightllm/server/router/req_queue/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from .continues_batch.impl import ContinuesBatchQueue
2-
from .continues_batch.beam_impl import BeamContinuesBatchQueue
32
from .continues_batch.impl_for_pd_decode import QueueForPDDecode
43
from .chunked_prefill.impl_for_pd_prefill import QueueForPDChunkedPrefill
54
from .chunked_prefill.impl import ChunkedPrefillQueue
5+
from .chunked_prefill.beam_impl import BeamContinuesBatchQueue
66
from .dp_base_queue import DpQueue
77

88

lightllm/server/router/req_queue/continues_batch/beam_impl.py renamed to lightllm/server/router/req_queue/chunked_prefill/beam_impl.py

File renamed without changes.

0 commit comments

Comments
 (0)