@@ -57,15 +57,14 @@ def decode(self):
5757
5858 # 再 prefill
5959 if self .chunked_prefill_state .need_prefill (prefill_reqs = prefill_reqs , decode_reqs = decode_reqs ):
60- self ._topk_repair (run_reqs = prefill_reqs )
6160 ContinuesBatchBackend .normal_prefill_reqs (
61+ self ,
6262 prefill_reqs = prefill_reqs ,
6363 uninit_reqs = uninit_reqs ,
6464 ok_finished_reqs = ok_finished_reqs ,
6565 mask_func = self ._prefill_mask_callback ,
6666 extra_post_req_handle_func = self ._update_tokenhealing_req_prefix_str ,
6767 )
68- self ._topk_recover (run_reqs = prefill_reqs )
6968
7069 self ._overlap_req_init_and_filter (uninit_reqs = uninit_reqs , ok_finished_reqs = ok_finished_reqs , clear_list = True )
7170 return
@@ -142,21 +141,6 @@ def _mask_decode_not_prefix_token(self, i, run_obj: InferReq, mask):
142141 mask [i , :] = False
143142 return
144143
145- def _topk_repair (self , run_reqs : list [InferReq ]):
146- for req_obj in run_reqs :
147- if len (req_obj .prefix_str ) != 0 :
148- req_obj .origin_topk = req_obj .sampling_param .shm_param .top_k
149- req_obj .sampling_param .shm_param .top_k = 1
150- else :
151- req_obj .origin_topk = req_obj .sampling_param .shm_param .top_k
152- return
153-
154- def _topk_recover (self , run_reqs : list [InferReq ]):
155- for req_obj in run_reqs :
156- if hasattr (req_obj , "origin_topk" ):
157- req_obj .sampling_param .shm_param .top_k = req_obj .origin_topk
158- return
159-
160144 def _init_prefix_infos (self , run_reqs : List [InferReq ]):
161145 for i , run_obj in enumerate (run_reqs ):
162146 if not hasattr (run_obj , "prefix_str" ):
0 commit comments