Skip to content

Commit dde6f18

Browse files
committed
fix
1 parent 204f6fa commit dde6f18

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import time
55
import threading
66
import torch.distributed as dist
7-
from queue import Queue
87
from typing import List, Tuple, Callable, Optional
98
from transformers.configuration_utils import PretrainedConfig
109
from lightllm.utils.infer_utils import set_random_seed
@@ -329,6 +328,7 @@ def _get_classed_reqs(
329328
self,
330329
req_ids: List[int] = None,
331330
no_decode: bool = False,
331+
strict_prefill: bool = False,
332332
recover_paused: bool = False,
333333
):
334334
"""
@@ -337,6 +337,11 @@ def _get_classed_reqs(
337337
避免一些特殊情况,如 radix cache 命中后,只有1token需要prefill,这个判断
338338
条件和decode请求的分类条件相同。所以添加一个参数进行区分。
339339
340+
strict_prefill参数用于控制当 cur_kv_len + 1 == input_len 时,是否将请求
341+
分为 prefill,当 strict_prefill 设置为True时,表示需要将这个请求分为 prefill,
342+
为 False 时,将这个请求分为decode。 strict_prefill 主要是用于diverse mode
343+
使用时,其他模式目前不使用。
344+
340345
将请求分类返回:
341346
1. wait_pause_reqs 因为推理资源不够,等待被暂停的请求。
342347
2. paused_reqs 已经被暂停的请求,可能会被恢复。
@@ -389,7 +394,7 @@ def _get_classed_reqs(
389394
is_decode = False
390395
else:
391396
is_decode = req_obj.cur_kv_len + 1 == req_obj.get_cur_total_len()
392-
if is_decode and req_obj.cur_kv_len + 1 == req_obj.shm_req.input_len:
397+
if is_decode and strict_prefill and req_obj.cur_kv_len + 1 == req_obj.shm_req.input_len:
393398
is_decode = False
394399

395400
if is_decode:
@@ -461,21 +466,18 @@ def _pre_post_handle(self, run_reqs: List[InferReq], is_chuncked_mode: bool) ->
461466
def _post_handle(
462467
self,
463468
run_reqs: List[InferReq],
469+
next_token_ids: List[int],
470+
next_token_logprobs: List[float],
464471
run_reqs_update_packs: List[InferReqUpdatePack],
465472
extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
466473
):
467474
"""
468475
extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于
469476
约束输出等模式,设置自己请求内部的状态机的状态,并添加额外的停止判定条件等。
470477
"""
471-
next_token_ids = []
472-
for req_obj, pack in zip(run_reqs, run_reqs_update_packs):
473-
next_token_id = g_infer_context.req_manager.req_sampling_params_manager.req_to_next_token_ids_cpu[
474-
req_obj.req_idx
475-
]
476-
next_token_logprob = g_infer_context.req_manager.req_sampling_params_manager.req_to_next_token_probs_cpu[
477-
req_obj.req_idx
478-
]
478+
for req_obj, next_token_id, next_token_logprob, pack in zip(
479+
run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs
480+
):
479481
req_obj: InferReq = req_obj
480482
pack: InferReqUpdatePack = pack
481483
pack.handle(
@@ -485,7 +487,6 @@ def _post_handle(
485487
extra_post_req_handle_func=extra_post_req_handle_func,
486488
is_master_in_dp=self.is_master_in_dp,
487489
)
488-
next_token_ids.append(next_token_id)
489490

490491
g_infer_context.req_manager.req_sampling_params_manager.update_reqs_token_counter(
491492
req_objs=run_reqs, next_token_ids=next_token_ids

0 commit comments

Comments
 (0)