@@ -311,127 +311,82 @@ def _decode(
         return self._token_forward(model_input.input_ids, infer_state)
 
     @torch.no_grad()
-    def microbatch_overlap_decode(self, batch: DecodeMicroBatch, batch1: DecodeMicroBatch):
-        assert batch.batch_size == batch1.batch_size
-        assert batch.mem_indexes.is_cuda
-        assert batch1.mem_indexes.is_cuda
-        input_ids, input_ids1 = batch.input_ids, batch1.input_ids
-
-        def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
-            infer_state = self.infer_state_class()
-            infer_state.is_prefill = False
-            infer_state.batch_size = cur_batch.batch_size
-            infer_state.total_token_num = cur_batch.total_token_num
-            infer_state.max_len_in_batch = cur_batch.max_len_in_batch
-            infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
-            assert cur_batch.b_req_idx.shape[0] == cur_batch.b_seq_len.shape[0]
-            infer_state.b_req_idx = cur_batch.b_req_idx
-            infer_state.b_seq_len = cur_batch.b_seq_len
-            infer_state.multimodal_params = None
-            infer_state.microbatch_index = batch_index
-
-            infer_state.mem_manager = self.mem_manager
-            infer_state.req_manager = self.req_manager
-
-            infer_state.mem_index = cur_batch.mem_indexes
-            infer_state.kv_buffer_shapedtype = (
-                (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-                self.data_type,
-            )
-            infer_state.dist_group = dist_group_manager.get_group(batch_index)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs, cur_batch.b_req_idx, cur_batch.b_seq_len, infer_state.mem_index
-            )
-            return infer_state
+    def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
+        assert model_input0.batch_size == model_input1.batch_size
+        assert model_input0.mem_indexes.is_cuda
+        assert model_input1.mem_indexes.is_cuda
+        input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
-        infer_state = create_inferstate(batch, 0)
-        infer_state1 = create_inferstate(batch1, 1)
+        infer_state0 = self._create_inferstate(model_input0, 0)
+        copy_kv_index_to_req(
+            self.req_manager.req_to_token_indexs, model_input0.b_req_idx, model_input0.b_seq_len, infer_state0.mem_index
+        )
+        infer_state0.init_some_extra_state(self, input_ids0)
 
-        infer_state.init_some_extra_state(self, input_ids)
+        infer_state1 = self._create_inferstate(model_input1, 1)
+        copy_kv_index_to_req(
+            self.req_manager.req_to_token_indexs, model_input1.b_req_idx, model_input1.b_seq_len, infer_state1.mem_index
+        )
         infer_state1.init_some_extra_state(self, input_ids1)
 
-        batch_size = batch.batch_size
-        max_len_in_batch = max(batch.max_len_in_batch, batch1.max_len_in_batch)
+        batch_size = model_input0.batch_size
+        max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch)
 
         if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch):
             if self.graph.need_capture(batch_size):
-                infer_state.is_cuda_graph = True
+                infer_state0.is_cuda_graph = True
                 infer_state1.is_cuda_graph = True
 
-                predict_logits, predict_logits1 = self.graph.capture_decode(
+                model_output0, model_output1 = self.graph.capture_decode(
                     self._overlap_tpsp_token_forward,
-                    input_ids,
-                    infer_state,
+                    input_ids0,
+                    infer_state0,
                     input_ids1=input_ids1,
                     infer_state1=infer_state1,
                 )
             else:
-                predict_logits, predict_logits1 = self.graph.replay(
-                    input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+                model_output0, model_output1 = self.graph.replay(
+                    input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
                 )
         else:
-            predict_logits, predict_logits1 = self._overlap_tpsp_token_forward(
-                input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+            model_output0, model_output1 = self._overlap_tpsp_token_forward(
+                input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
             )
-        return predict_logits, predict_logits1
+        return model_output0, model_output1
 
     @torch.no_grad()
-    def microbatch_overlap_prefill(self, batch: PrefillMicroBatch, batch1: PrefillMicroBatch):
-        assert batch.mem_indexes.is_cuda
-        assert batch1.mem_indexes.is_cuda
-        input_ids, input_ids1 = batch.input_ids, batch1.input_ids
-
-        def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
-            infer_state = self.infer_state_class()
-            infer_state.is_prefill = True
-            infer_state.is_token_healing = self.is_token_healing
-            infer_state.return_all_prompt_logics = self.return_all_prompt_logics
-            infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
-            infer_state.batch_size = cur_batch.batch_size
-            infer_state.total_token_num = cur_batch.total_token_num
-            infer_state.max_len_in_batch = cur_batch.max_len_in_batch
-            assert cur_batch.b_req_idx.shape[0] == cur_batch.b_seq_len.shape[0]
-            infer_state.b_req_idx = cur_batch.b_req_idx
-            infer_state.b_seq_len = cur_batch.b_seq_len
-            if cur_batch.b_ready_cache_len is not None:
-                infer_state.b_ready_cache_len = cur_batch.b_ready_cache_len
-            else:
-                infer_state.b_ready_cache_len = torch.zeros_like(
-                    cur_batch.b_seq_len, dtype=cur_batch.b_seq_len.dtype, device=cur_batch.b_seq_len.device
-                )
-            infer_state.multimodal_params = cur_batch.multimodal_params
-            infer_state.microbatch_index = batch_index
+    def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):
+        assert model_input0.mem_indexes.is_cuda
+        assert model_input1.mem_indexes.is_cuda
+        input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
-            infer_state.mem_manager = self.mem_manager
-            infer_state.req_manager = self.req_manager
-
-            infer_state.mem_index = cur_batch.mem_indexes
-            infer_state.kv_buffer_shapedtype = (
-                (cur_batch.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-                self.data_type,
-            )
-            infer_state.dist_group = dist_group_manager.get_group(batch_index)
-            init_req_to_token_indexes(
-                self.req_manager.req_to_token_indexs,
-                cur_batch.b_req_idx,
-                cur_batch.b_seq_len,
-                infer_state.b_ready_cache_len,
-                cur_batch.max_len_in_batch,
-                infer_state.mem_index,
-            )
-            return infer_state
-
-        infer_state = create_inferstate(batch, 0)
-        infer_state1 = create_inferstate(batch1, 1)
-
-        infer_state.init_some_extra_state(self, input_ids)
+        infer_state0 = self._create_inferstate(model_input0, 0)
+        init_req_to_token_indexes(
+            self.req_manager.req_to_token_indexs,
+            model_input0.b_req_idx,
+            model_input0.b_seq_len,
+            infer_state0.b_ready_cache_len,
+            model_input0.max_len_in_batch,
+            infer_state0.mem_index,
+        )
+        infer_state0.init_some_extra_state(self, input_ids0)
+
+        infer_state1 = self._create_inferstate(model_input1, 1)
+        init_req_to_token_indexes(
+            self.req_manager.req_to_token_indexs,
+            model_input1.b_req_idx,
+            model_input1.b_seq_len,
+            infer_state1.b_ready_cache_len,
+            model_input1.max_len_in_batch,
+            infer_state1.mem_index,
+        )
        infer_state1.init_some_extra_state(self, input_ids1)
 
-        predict_logits, predict_logits1 = self._overlap_tpsp_context_forward(
-            input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+        model_output0, model_output1 = self._overlap_tpsp_context_forward(
+            input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
         )
         dist_group_manager.clear_deepep_buffer()
-        return predict_logits, predict_logits1
+        return model_output0, model_output1
 
     @final
     def _context_forward(self, input_ids, infer_state: InferStateInfo):
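The two per-method `create_inferstate` closures deleted above are folded into one shared `self._create_inferstate(model_input, microbatch_index)` helper, and the KV-index bookkeeping (`copy_kv_index_to_req` for decode, `init_req_to_token_indexes` for prefill) moves out to the call sites, which removes the last prefill/decode asymmetry from the helper. The helper's body is not part of this hunk; the sketch below is a plausible reconstruction from the fields the deleted closures populated. The `is_prefill` flag and the other attributes assumed on `ModelInput` are inferred, not confirmed by the diff.

```python
# Plausible reconstruction of the shared helper on the model class; the real
# body lives outside this hunk. Assumes ModelInput exposes is_prefill,
# batch_size, total_token_num, max_len_in_batch, b_req_idx, b_seq_len,
# b_ready_cache_len, multimodal_params, mem_indexes, and input_ids.
def _create_inferstate(self, model_input: "ModelInput", microbatch_index: int = 0):
    infer_state = self.infer_state_class()
    infer_state.is_prefill = model_input.is_prefill
    infer_state.is_token_healing = self.is_token_healing
    infer_state.return_all_prompt_logics = self.return_all_prompt_logics
    infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
    infer_state.batch_size = model_input.batch_size
    infer_state.total_token_num = model_input.total_token_num
    infer_state.max_len_in_batch = model_input.max_len_in_batch
    assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0]
    infer_state.b_req_idx = model_input.b_req_idx
    infer_state.b_seq_len = model_input.b_seq_len
    if model_input.is_prefill:
        # Prefill needs the per-request count of already-cached prompt tokens.
        if model_input.b_ready_cache_len is not None:
            infer_state.b_ready_cache_len = model_input.b_ready_cache_len
        else:
            infer_state.b_ready_cache_len = torch.zeros_like(model_input.b_seq_len)
    # The old decode closure hardcoded multimodal_params = None, so decode
    # inputs are expected to carry None here.
    infer_state.multimodal_params = model_input.multimodal_params
    infer_state.microbatch_index = microbatch_index
    infer_state.mem_manager = self.mem_manager
    infer_state.req_manager = self.req_manager
    infer_state.mem_index = model_input.mem_indexes
    # input_ids.shape[0] equals batch_size in decode (one token per request),
    # so this single shape expression covers both old closures' variants.
    infer_state.kv_buffer_shapedtype = (
        (model_input.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
        self.data_type,
    )
    infer_state.dist_group = dist_group_manager.get_group(microbatch_index)
    return infer_state
```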
@@ -508,9 +463,21 @@ def _overlap_tpsp_token_forward(
         predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward(
             input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
         )
-
+
         g_cache_manager.cache_env_out()
-        return predict_logits, predict_logits1
+        is_return_hidden_states = self.spec_algo.is_mtp() or (
+            self.spec_algo.is_mtp_module() and not self.last_mtp_module
+        )
+        model_output = ModelOutput(
+            logits=predict_logits,
+            hidden_states=input_embs if is_return_hidden_states else None,
+        )
+
+        model_output1 = ModelOutput(
+            logits=predict_logits1,
+            hidden_states=input_embs1 if is_return_hidden_states else None,
+        )
+        return model_output, model_output1
 
     @final
     def _overlap_tpsp_context_forward(
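Both overlap forward paths now package their results in `ModelOutput` instead of returning bare logits tensors. The class definition is outside this diff; judging from the keyword arguments used above, it is presumably a small dataclass along these lines (field names match the call sites, the types are assumptions):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ModelOutput:
    # LM-head logits for the microbatch.
    logits: torch.Tensor
    # Last-layer hidden states, populated only when a following MTP
    # (multi-token prediction) module needs them as its input.
    hidden_states: Optional[torch.Tensor] = None
```

Carrying the hidden states alongside the logits lets MTP draft modules chain off the main model's activations without a second forward pass; when no MTP module follows, `hidden_states` stays `None` so the activation tensors can be freed.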
@@ -528,7 +495,21 @@ def _overlap_tpsp_context_forward(
             input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
         )
         g_cache_manager.cache_env_out()
-        return predict_logits, predict_logits1
+
+        is_return_hidden_states = self.spec_algo.is_mtp() or (
+            self.spec_algo.is_mtp_module() and not self.last_mtp_module
+        )
+        model_output = ModelOutput(
+            logits=predict_logits,
+            hidden_states=input_embs if is_return_hidden_states else None,
+        )
+
+        model_output1 = ModelOutput(
+            logits=predict_logits1,
+            hidden_states=input_embs1 if is_return_hidden_states else None,
+        )
+
+        return model_output, model_output1
 
     @final
     @torch.no_grad()
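A minimal, hypothetical call site for the reworked decode entry point (the `model`, `model_input0`, and `model_input1` names are assumed; both inputs must already carry CUDA-resident `mem_indexes` per the asserts above):

```python
# Run two decode microbatches with overlapped execution, then greedily
# pick the next token of each from the returned ModelOutput objects.
model_output0, model_output1 = model.microbatch_overlap_decode(model_input0, model_input1)

next_tokens0 = torch.argmax(model_output0.logits, dim=-1)
next_tokens1 = torch.argmax(model_output1.logits, dim=-1)

# hidden_states is non-None only when an MTP draft module runs next.
if model_output0.hidden_states is not None:
    draft_hidden = model_output0.hidden_states
```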