@@ -300,13 +300,18 @@ def _init_reqs(self, reqs: List[Tuple], init_req_obj=True):
300300 return req_ids
301301
302302 # 一些可以复用的通用功能函数
303- def _get_classed_reqs (self , req_ids : List [int ], no_decode : bool = False ):
303+ def _get_classed_reqs (self , req_ids : List [int ], no_decode : bool = False , strict_prefill : bool = False ):
304304 """
305305 当将参数 no_decode 设置为True后,返回的 decode_reqs 永远为空list,主要是
306306 PD 分离的某些backend需要用这个参数进行控制,因为P节点永远只进行Prefill,
307307 避免一些特殊情况,如 radix cache 命中后,只有1token需要prefill,这个判断
308308 条件和decode请求的分类条件相同。所以添加一个参数进行区分。
309309
310+ strict_prefill参数用于控制当 cur_kv_len + 1 == input_len 时,是否将请求
311+ 分为 prefill,当 strict_prefill 设置为True时,表示需要将这个请求分为 prefill,
312+ 为 False 时,将这个请求分为decode。 strict_prefill 主要是用于diverse mode
313+ 使用时,其他模式目前不使用。
314+
310315 将请求分类返回:
311316 1. unit reqs 还未完整初始化的请求
312317 2. aborted_reqs aborted 的请求
@@ -335,16 +340,20 @@ def _get_classed_reqs(self, req_ids: List[int], no_decode: bool = False):
335340 ok_finished_reqs .append (req_obj )
336341 continue
337342
338- is_decode = (
339- req_obj .cur_kv_len + 1 == req_obj .get_cur_total_len ()
340- and req_obj .cur_kv_len + 1 != req_obj .shm_req .input_len
341- )
343+ if no_decode :
344+ prefill_reqs .append (req_obj )
345+ continue
346+
347+ is_decode = req_obj .cur_kv_len + 1 == req_obj .get_cur_total_len ()
342348
343349 if not is_decode :
344350 prefill_reqs .append (req_obj )
345351 else :
346- if no_decode :
347- prefill_reqs .append (req_obj )
352+ if strict_prefill :
353+ if req_obj .cur_kv_len + 1 == req_obj .shm_req .input_len :
354+ prefill_reqs .append (req_obj )
355+ else :
356+ decode_reqs .append (req_obj )
348357 else :
349358 decode_reqs .append (req_obj )
350359
0 commit comments