@@ -104,14 +104,9 @@ def prefill_normal(
             model_input.b_req_idx,
             model_input.b_mtp_index,
         )
-        next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_ids", next_token_ids.shape[0], next_token_ids.dtype
+        next_token_ids_cpu, next_token_logprobs_cpu = self._save_next_token_ids_and_logprobs(
+            next_token_ids, next_token_logprobs
         )
-        next_token_logprobs_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_logprobs", next_token_logprobs.shape[0], next_token_logprobs.dtype
-        )
-        next_token_ids_cpu.copy_(next_token_ids, non_blocking=True)
-        next_token_logprobs_cpu.copy_(next_token_logprobs, non_blocking=True)
         sync_event = torch.cuda.Event()
         sync_event.record()
 
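Both prefill_normal and decode_normal now route the pinned-memory bookkeeping through a single helper. The helper body is not part of this diff; below is a minimal sketch of what it plausibly does, reconstructed from the removed lines above (only the g_pin_mem_manager calls and the buffer names are taken from the diff, the rest is an assumption):

    # Sketch only -- inferred from the removed lines, not the actual helper body.
    def _save_next_token_ids_and_logprobs(self, next_token_ids, next_token_logprobs):
        # Grab (or reuse) pinned host buffers so the device-to-host copies can be
        # queued asynchronously on the current CUDA stream.
        next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
            "next_token_ids", next_token_ids.shape[0], next_token_ids.dtype
        )
        next_token_logprobs_cpu = g_pin_mem_manager.alloc_pin_tensor(
            "next_token_logprobs", next_token_logprobs.shape[0], next_token_logprobs.dtype
        )
        # Non-blocking copies: the buffers are only safe to read after the
        # sync_event recorded at the call site has been synchronized.
        next_token_ids_cpu.copy_(next_token_ids, non_blocking=True)
        next_token_logprobs_cpu.copy_(next_token_logprobs, non_blocking=True)
        return next_token_ids_cpu, next_token_logprobs_cpu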
@@ -152,14 +147,9 @@ def decode_normal(
             model_input.b_req_idx,
             model_input.b_mtp_index,
         )
-        next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_ids", next_token_ids.shape[0], next_token_ids.dtype
-        )
-        next_token_logprobs_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_logprobs", next_token_logprobs.shape[0], next_token_logprobs.dtype
+        next_token_ids_cpu, next_token_logprobs_cpu = self._save_next_token_ids_and_logprobs(
+            next_token_ids, next_token_logprobs
         )
-        next_token_ids_cpu.copy_(next_token_ids, non_blocking=True)
-        next_token_logprobs_cpu.copy_(next_token_logprobs, non_blocking=True)
         sync_event = torch.cuda.Event()
         sync_event.record()
 
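The copies issued by the helper are non-blocking, which is why both branches record sync_event immediately afterwards and only synchronize it later, right before _post_handle reads the CPU buffers. The ordering constraint, stripped down to generic PyTorch (an illustrative sketch, not code from this repo):

    import torch

    src = torch.randint(0, 100, (8,), device="cuda")
    dst = torch.empty(8, dtype=src.dtype, pin_memory=True)

    dst.copy_(src, non_blocking=True)  # async device-to-host copy into pinned memory
    evt = torch.cuda.Event()
    evt.record()                       # marks the point after the copy on the current stream

    # ... CPU-side work can overlap with the copy here ...

    evt.synchronize()                  # only now is `dst` guaranteed to hold the copied values
    print(dst.tolist())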
@@ -190,13 +180,44 @@ def prefill_mtp(
         model_input, run_reqs = prepare_prefill_inputs(
             prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
-        model_output = self.model.forward(model_input)
+        with torch.cuda.stream(g_infer_context.get_overlap_stream()):
+            model_output = self.model.forward(model_input)
+            next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id)
+
+            scatter_token(
+                next_token_ids,
+                self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+                model_input.b_req_idx,
+                model_input.b_mtp_index,
+            )
+            next_token_ids_cpu, next_token_logprobs_cpu = self._save_next_token_ids_and_logprobs(
+                next_token_ids, next_token_logprobs
+            )
+            # mtp kv fill
+            draft_next_token_ids_gpu = next_token_ids
+            draft_model_output = model_output
+            draft_model_input = model_input
+            # spec prefill: MTP. This only fills the draft model's kv cache; the generated token_ids are not used.
+            for draft_model_idx in range(self.mtp_step):
+                draft_model_input = prepare_mtp_prefill_inputs(
+                    model_input=draft_model_input,
+                    b_next_token_ids=draft_next_token_ids_gpu,
+                    deepseekv3_mtp_draft_input_hiddens=draft_model_output.deepseekv3_mtp_main_output_hiddens,
+                )
+                draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input)
+                draft_next_token_ids_gpu = self._gen_argmax_token_ids(draft_model_output)
 
-        next_token_ids_gpu, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
-        next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy()
-        next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy()
+        sync_event = torch.cuda.Event()
+        sync_event.record()
 
+        # stage 2
+        event_pack.notify_post_handle_and_wait_pre_post_handle()
         update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
+
+        # stage 3
+        event_pack.notify_forward_and_wait_post_handle()
+        sync_event.synchronize()
+
         self._post_handle(
             run_reqs=run_reqs,
             next_token_ids=next_token_ids_cpu,
@@ -205,20 +226,8 @@ def prefill_mtp(
             extra_post_req_handle_func=self.extra_post_req_handle_func,
         )
 
-        # mtp kv fill
-        draft_next_token_ids_gpu = next_token_ids_gpu
-        draft_model_output = model_output
-        draft_model_input = model_input
-        # spec prefill: MTP. This only fills the draft model's kv cache; the generated token_ids are not used.
-        for draft_model_idx in range(self.mtp_step):
-            draft_model_input = prepare_mtp_prefill_inputs(
-                model_input=draft_model_input,
-                b_next_token_ids=draft_next_token_ids_gpu,
-                deepseekv3_mtp_draft_input_hiddens=draft_model_output.deepseekv3_mtp_main_output_hiddens,
-            )
-
-            draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input)
-            draft_next_token_ids_gpu, draft_next_token_ids_cpu = self._gen_argmax_token_ids(draft_model_output)
+        # stage 4
+        event_pack.notify_pre_post_handle()
         return
 
     def decode_mtp(
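One call-site change worth noting: the old loop unpacked a (gpu, cpu) pair from _gen_argmax_token_ids, while the new code keeps a single device tensor and moves it to the host only where it is actually consumed (see decode_mtp below). The helper itself is not shown in this diff; assuming it simply takes the greedy token per position, it could reduce to something like:

    # Hypothetical sketch -- the real _gen_argmax_token_ids is not part of this diff.
    def _gen_argmax_token_ids(self, model_output):
        # Greedy draft tokens, kept on the GPU so the next draft step can
        # consume them without a host round trip.
        return torch.argmax(model_output.logits, dim=-1)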
@@ -227,47 +236,69 @@ def decode_mtp(
         decode_reqs: List[InferReq],
     ):
         model_input, run_reqs = prepare_decode_inputs(decode_reqs)
-        model_output = self.model.forward(model_input)
+        with torch.cuda.stream(g_infer_context.get_overlap_stream()):
+            model_output = self.model.forward(model_input)
 
-        next_token_ids_gpu, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
-        next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy()
-        next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy()
+            next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
+            scatter_token(
+                next_token_ids,
+                self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+                model_input.b_req_idx,
+                model_input.b_mtp_index,
+            )
+            next_token_ids_cpu, next_token_logprobs_cpu = self._save_next_token_ids_and_logprobs(
+                next_token_ids, next_token_probs
+            )
 
-        # verify
-        mem_indexes_cpu = model_input.mem_indexes.detach().cpu().numpy()
-        verify_ok_reqs, verify_ok_req_indexes, verify_ok_req_last_indexes, need_free_mem_indexes = self._verify_mtp(
-            run_reqs, next_token_ids_cpu, mem_indexes_cpu
-        )
+            # verify
+            mem_indexes_cpu = model_input.mem_indexes.detach().cpu().numpy()
+            verify_ok_reqs, verify_ok_req_indexes, verify_ok_req_last_indexes, need_free_mem_indexes = self._verify_mtp(
+                run_reqs, next_token_ids_cpu, mem_indexes_cpu
+            )
+
+            # share some inference info with the main model
+            draft_model_input = model_input
+            draft_model_output = model_output
+            draft_next_token_ids = next_token_ids
+            # process the draft model output
+            for draft_model_idx in range(self.mtp_step):
+
+                draft_model_input.input_ids = draft_next_token_ids
+                draft_model_input.deepseekv3_mtp_draft_input_hiddens = (
+                    draft_model_output.deepseekv3_mtp_main_output_hiddens
+                )
+                # spec decode: MTP
+                draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input)
+                draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output)
+
+            unique_reqs = [run_reqs[index] for index in verify_ok_req_last_indexes]
+            draft_next_token_ids_cpu = draft_next_token_ids.detach().cpu().numpy()
+            self._update_reqs_mtp_gen_token_ids(
+                reqs=unique_reqs, mtp_draft_next_token_ids=draft_next_token_ids_cpu[verify_ok_req_last_indexes]
+            )
+        sync_event = torch.cuda.Event()
+        sync_event.record()
 
+        # stage 2
+        event_pack.notify_post_handle_and_wait_pre_post_handle()
         update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False)
+
+        # stage 3
+        event_pack.notify_forward_and_wait_post_handle()
+        sync_event.synchronize()
+
         self._post_handle(
             run_reqs=verify_ok_reqs,
             next_token_ids=next_token_ids_cpu[verify_ok_req_indexes],
             next_token_logprobs=next_token_logprobs_cpu[verify_ok_req_indexes],
             run_reqs_update_packs=update_packs,
             extra_post_req_handle_func=self.extra_post_req_handle_func,
         )
-
-        # share some inference info with the main model
-        draft_model_input = model_input
-        draft_model_output = model_output
-        draft_next_token_ids = next_token_ids_gpu
-        # process the draft model output
-        for draft_model_idx in range(self.mtp_step):
-
-            draft_model_input.input_ids = draft_next_token_ids
-            draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens
-            # spec decode: MTP
-            draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input)
-            draft_next_token_ids, draft_next_token_ids_cpu = self._gen_argmax_token_ids(draft_model_output)
-
-        unique_reqs = [run_reqs[index] for index in verify_ok_req_last_indexes]
-        self._update_reqs_mtp_gen_token_ids(
-            reqs=unique_reqs, mtp_draft_next_token_ids=draft_next_token_ids_cpu[verify_ok_req_last_indexes]
-        )
-
         if need_free_mem_indexes:
             g_infer_state_lock.acquire()
             g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes)
             g_infer_state_lock.release()
+
+        # stage 4
+        event_pack.notify_pre_post_handle()
         return
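Both MTP paths now follow the same four-stage shape: stage 1 launches the forward pass, sampling, draft-model steps, and the async pinned-memory copies on the overlap stream and records an event; stage 2 hands off to _pre_post_handle; stage 3 waits on the event so the host buffers are valid before _post_handle; stage 4 signals that the next pre-post-handle may begin. The stream-and-event half of that pattern, without the event_pack coordination, is ordinary PyTorch (a generic sketch under those assumptions, not code from this repo):

    import torch

    overlap_stream = torch.cuda.Stream()

    def run_step(model, batch):
        out_cpu = torch.empty(batch.shape[0], dtype=torch.long, pin_memory=True)
        with torch.cuda.stream(overlap_stream):
            logits = model(batch)                        # stage 1: forward on the side stream
            token_ids = torch.argmax(logits, dim=-1)
            out_cpu.copy_(token_ids, non_blocking=True)  # async copy into pinned host memory
            done = torch.cuda.Event()
            done.record()                                # captures the side-stream work
        # stage 2: CPU-side bookkeeping can overlap with the GPU work here
        done.synchronize()                               # stage 3: host buffer is now valid
        return out_cpu                                   # in the real code, stage 4 then signals the next batch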