@@ -102,8 +102,15 @@ def prefill_normal(
             prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-            _, next_token_ids_cpu, next_token_logprobs_cpu, _ = self._main_model_forward(
-                model_input, run_reqs, self.prefill_mask_func
+            model_output = self.model.forward(model_input)
+            next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                logits=model_output.logits,
+                b_req_idx=model_input.b_req_idx,
+                b_mtp_index=model_input.b_mtp_index,
+                run_reqs=run_reqs,
+                is_prefill=True,
+                b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu,
+                mask_func=self.prefill_mask_func,
             )
             sync_event = torch.cuda.Event()
             sync_event.record()
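Note: both the forward pass and the sampling/scatter work above run under torch.cuda.stream(g_infer_context.get_overlap_stream()), and the recorded sync_event is what lets the consumer stream wait for those results. A minimal standalone sketch of that pattern, with hypothetical names (run_on_side_stream is not part of lightllm):

# Sketch of the overlap-stream pattern, assumed names only: run work on a
# side stream and record an event the consumer stream can wait on before
# touching the outputs.
import torch

def run_on_side_stream(fn, *args):
    side_stream = torch.cuda.Stream()  # stands in for g_infer_context.get_overlap_stream()
    with torch.cuda.stream(side_stream):
        out = fn(*args)                # e.g. model forward + sampling
        sync_event = torch.cuda.Event()
        sync_event.record()            # recorded on the side stream
    return out, sync_event

# consumer side, e.g. on the default stream:
#   torch.cuda.current_stream().wait_event(sync_event)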
@@ -133,8 +140,14 @@ def decode_normal(
     ):
         model_input, run_reqs = prepare_decode_inputs(decode_reqs)
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-            _, next_token_ids_cpu, next_token_logprobs_cpu, _ = self._main_model_forward(
-                model_input, run_reqs, self.decode_mask_func
+            model_output = self.model.forward(model_input)
+            next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                logits=model_output.logits,
+                b_req_idx=model_input.b_req_idx,
+                b_mtp_index=model_input.b_mtp_index,
+                run_reqs=run_reqs,
+                is_prefill=False,
+                mask_func=self.decode_mask_func,
             )
             sync_event = torch.cuda.Event()
             sync_event.record()
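Note: the decode path hands back next_token_ids_cpu and next_token_logprobs_cpu, i.e. host copies staged while the overlap stream is still running. A minimal sketch of the non-blocking device-to-host copy into pinned buffers that presumably backs _async_copy_next_token_infos_to_pin_mem (hypothetical helper name async_copy_to_pinned; not the lightllm implementation):

# Sketch only, assumed buffer handling: copy the sampled ids/logprobs into
# pinned host tensors with non_blocking=True so the transfer overlaps with
# other GPU work and is complete once the recorded sync event has fired.
import torch

def async_copy_to_pinned(next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor):
    ids_cpu = torch.empty(next_token_ids.shape, dtype=next_token_ids.dtype, pin_memory=True)
    logprobs_cpu = torch.empty(next_token_logprobs.shape, dtype=next_token_logprobs.dtype, pin_memory=True)
    ids_cpu.copy_(next_token_ids, non_blocking=True)
    logprobs_cpu.copy_(next_token_logprobs, non_blocking=True)
    return ids_cpu, logprobs_cpu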
@@ -167,8 +180,15 @@ def prefill_mtp(
             prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-            next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu, model_output = self._main_model_forward(
-                model_input, run_reqs, self.prefill_mask_func
+            model_output = self.model.forward(model_input)
+            next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                logits=model_output.logits,
+                b_req_idx=model_input.b_req_idx,
+                b_mtp_index=model_input.b_mtp_index,
+                run_reqs=run_reqs,
+                is_prefill=True,
+                b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu,
+                mask_func=self.prefill_mask_func,
             )
             # mtp kv fill
             self._draft_prefill_forward(model_input, model_output, self.prefill_mtp_step, next_token_ids)
@@ -201,7 +221,7 @@ def decode_mtp(
         decode_reqs: List[InferReq],
     ):
         if self.is_mtp_eagle:
-            draft_model_input, _, eagle_mem_indexes_cpu = prepare_eagle_decode_inputs(decode_reqs, self.mtp_step)
+            draft_model_input, eagle_mem_indexes_cpu = prepare_eagle_decode_inputs(decode_reqs, self.mtp_step)
             self._decode_mtp_common(
                 event_pack=event_pack,
                 decode_reqs=decode_reqs,
@@ -218,39 +238,6 @@ def decode_mtp(
             )
         return

-    def _main_model_forward(
-        self, model_input: ModelInput, run_reqs: List[InferReq], mask_func: Optional[Callable] = None
-    ):
-        model_output = self.model.forward(model_input)
-        logits = model_output.logits
-
-        if mask_func is not None:
-            mask_func(run_reqs, logits)
-
-        next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)
-        b_has_out = None
-        if model_input.is_prefill:
-            b_has_out = g_pin_mem_manager.gen_from_list(
-                key="b_has_out", data=model_input.b_prefill_has_output_cpu, dtype=torch.bool
-            ).cuda(non_blocking=True)
-
-        scatter_token(
-            next_token_ids=next_token_ids,
-            req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
-            b_req_idx=model_input.b_req_idx,
-            b_mtp_index=model_input.b_mtp_index,
-            b_has_out=b_has_out,
-        )
-        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
-            b_req_idx=model_input.b_req_idx,
-            next_token_ids=next_token_ids,
-            mask=b_has_out,
-        )
-        next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
-            next_token_ids, next_token_logprobs
-        )
-        return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu, model_output
-
     def _draft_prefill_forward(
         self, model_input: ModelInput, model_output: ModelOutput, mtp_step: int, next_token_ids: torch.Tensor
     ):
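Note: the _sample_and_scatter_token helper that the call sites above now use is not shown in this commit excerpt. Based on the removed _main_model_forward body, it presumably bundles the optional logits masking, sampling, token scatter, out-token counter update, and pinned-memory copy into one routine keyed off an explicit is_prefill flag. A sketch reconstructed from the removed code follows; the actual helper introduced by this commit may differ.

# Reconstruction sketch only; relies on the same module-level helpers
# (sample, scatter_token, g_pin_mem_manager, g_infer_context, torch) as the
# removed _main_model_forward.
def _sample_and_scatter_token(
    self,
    logits,
    b_req_idx,
    b_mtp_index,
    run_reqs,
    is_prefill,
    b_prefill_has_output_cpu=None,
    mask_func=None,
):
    if mask_func is not None:
        mask_func(run_reqs, logits)

    next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)

    # prefill batches can contain chunks that produce no output token
    b_has_out = None
    if is_prefill:
        b_has_out = g_pin_mem_manager.gen_from_list(
            key="b_has_out", data=b_prefill_has_output_cpu, dtype=torch.bool
        ).cuda(non_blocking=True)

    # write the sampled ids back into the per-request next-token table
    scatter_token(
        next_token_ids=next_token_ids,
        req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
        b_req_idx=b_req_idx,
        b_mtp_index=b_mtp_index,
        b_has_out=b_has_out,
    )
    g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
        b_req_idx=b_req_idx,
        next_token_ids=next_token_ids,
        mask=b_has_out,
    )
    next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
        next_token_ids, next_token_logprobs
    )
    return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu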