44import time
55import threading
66import torch .distributed as dist
7- from queue import Queue
87from typing import List , Tuple , Callable , Optional
98from transformers .configuration_utils import PretrainedConfig
109from lightllm .utils .infer_utils import set_random_seed
@@ -329,6 +328,7 @@ def _get_classed_reqs(
329328 self ,
330329 req_ids : List [int ] = None ,
331330 no_decode : bool = False ,
331+ strict_prefill : bool = False ,
332332 recover_paused : bool = False ,
333333 ):
334334 """
@@ -337,6 +337,11 @@ def _get_classed_reqs(
337337 避免一些特殊情况,如 radix cache 命中后,只有1token需要prefill,这个判断
338338 条件和decode请求的分类条件相同。所以添加一个参数进行区分。
339339
340+ strict_prefill参数用于控制当 cur_kv_len + 1 == input_len 时,是否将请求
341+ 分为 prefill,当 strict_prefill 设置为True时,表示需要将这个请求分为 prefill,
342+ 为 False 时,将这个请求分为decode。 strict_prefill 主要是用于diverse mode
343+ 使用时,其他模式目前不使用。
344+
340345 将请求分类返回:
341346 1. wait_pause_reqs 因为推理资源不够,等待被暂停的请求。
342347 2. paused_reqs 已经被暂停的请求,可能会被恢复。
@@ -389,7 +394,7 @@ def _get_classed_reqs(
389394 is_decode = False
390395 else :
391396 is_decode = req_obj .cur_kv_len + 1 == req_obj .get_cur_total_len ()
392- if is_decode and req_obj .cur_kv_len + 1 == req_obj .shm_req .input_len :
397+ if is_decode and strict_prefill and req_obj .cur_kv_len + 1 == req_obj .shm_req .input_len :
393398 is_decode = False
394399
395400 if is_decode :
@@ -461,21 +466,18 @@ def _pre_post_handle(self, run_reqs: List[InferReq], is_chuncked_mode: bool) ->
461466 def _post_handle (
462467 self ,
463468 run_reqs : List [InferReq ],
469+ next_token_ids : List [int ],
470+ next_token_logprobs : List [float ],
464471 run_reqs_update_packs : List [InferReqUpdatePack ],
465472 extra_post_req_handle_func : Optional [Callable [[InferReq , int , float ], None ]] = None ,
466473 ):
467474 """
468475 extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于
469476 约束输出等模式,设置自己请求内部的状态机的状态,并添加额外的停止判定条件等。
470477 """
471- next_token_ids = []
472- for req_obj , pack in zip (run_reqs , run_reqs_update_packs ):
473- next_token_id = g_infer_context .req_manager .req_sampling_params_manager .req_to_next_token_ids_cpu [
474- req_obj .req_idx
475- ]
476- next_token_logprob = g_infer_context .req_manager .req_sampling_params_manager .req_to_next_token_probs_cpu [
477- req_obj .req_idx
478- ]
478+ for req_obj , next_token_id , next_token_logprob , pack in zip (
479+ run_reqs , next_token_ids , next_token_logprobs , run_reqs_update_packs
480+ ):
479481 req_obj : InferReq = req_obj
480482 pack : InferReqUpdatePack = pack
481483 pack .handle (
@@ -485,7 +487,6 @@ def _post_handle(
485487 extra_post_req_handle_func = extra_post_req_handle_func ,
486488 is_master_in_dp = self .is_master_in_dp ,
487489 )
488- next_token_ids .append (next_token_id )
489490
490491 g_infer_context .req_manager .req_sampling_params_manager .update_reqs_token_counter (
491492 req_objs = run_reqs , next_token_ids = next_token_ids
0 commit comments