@@ -45,7 +45,7 @@ def __init__(self) -> None:
             if self.enable_decode_microbatch_overlap:
                 self.decode = self.decode_overlap_mtp
             else:
-                self.decode = self.decode_mtp_eagle
+                self.decode = self.decode_mtp
         else:
             if self.enable_prefill_microbatch_overlap:
                 self.prefill = self.prefill_overlap
@@ -396,7 +396,12 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
     def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
         model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(decode_reqs)
         b_mtp_index_cpu = model_input.b_mtp_index
+        eagle_mem_indexes_cpu = None
         req_num = len(run_reqs)
+        if self.is_mtp_eagle:
+            draft_model_input, eagle_mem_indexes_cpu = padded_prepare_eagle_decode_inputs(
+                decode_reqs, padded_req_num=padded_req_num, mtp_step=self.mtp_step
+            )
 
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
             model_output = self.model.forward(model_input)
@@ -436,16 +441,25 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
 
             verify_event = torch.cuda.Event()
             verify_event.record()
-
-            self._draft_decode_vanilla(
-                model_input=model_input,
-                model_output=model_output,
-                draft_next_token_ids_gpu=draft_next_token_ids_gpu,
-                b_req_mtp_start_loc=b_req_mtp_start_loc[:real_req_num],
-                mtp_accept_len=mtp_accept_len[:real_req_num],
-                req_num=req_num,
-            )
-
+            if self.is_mtp_eagle:
+                self._draft_decode_eagle(
+                    model_input=model_input,
+                    model_output=model_output,
+                    draft_next_token_ids_gpu=draft_next_token_ids_gpu,
+                    mtp_accept_len=mtp_accept_len,
+                    eagle_mem_indexes_cpu=eagle_mem_indexes_cpu,
+                    draft_model_input=draft_model_input,
+                    padded_req_num=padded_req_num,
+                )
+            else:
+                self._draft_decode_vanilla(
+                    model_input=model_input,
+                    model_output=model_output,
+                    draft_next_token_ids_gpu=draft_next_token_ids_gpu,
+                    b_req_mtp_start_loc=b_req_mtp_start_loc[:real_req_num],
+                    mtp_accept_len=mtp_accept_len[:real_req_num],
+                    req_num=req_num,
+                )
             if req_num > 0:
                 g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
                     b_req_idx=model_input.b_req_idx[:req_num],
@@ -468,6 +482,8 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
             event_pack.notify_forward_and_wait_post_handle()
             sync_event.synchronize()
             need_free_mem_indexes = model_input.mem_indexes_cpu[0:req_num][accepted_index_cpu == 0]
+            if self.is_mtp_eagle:
+                need_free_mem_indexes = torch.cat([need_free_mem_indexes, eagle_mem_indexes_cpu], dim=0)
 
             self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu)
             select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu")
@@ -528,122 +544,14 @@ def _draft_decode_vanilla(
             )
         return all_next_token_ids
 
-    def decode_mtp_eagle(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
-        model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(decode_reqs)
-        draft_model_input, eagle_mem_indexes_cpu = padded_prepare_eagle_decode_inputs(
-            decode_reqs, padded_req_num=padded_req_num, mtp_step=self.mtp_step
-        )
-        b_mtp_index_cpu = model_input.b_mtp_index
-        req_num = len(run_reqs)
-
-        with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-            model_output = self.model.forward(model_input)
-            draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda")
-
-            if req_num > 0:
-                logits = model_output.logits[0:req_num, :]
-                next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)
-                next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
-                    next_token_ids, next_token_logprobs
-                )
-                draft_next_token_ids_gpu[0:req_num].copy_(next_token_ids)
-
-            # verify the next_token_ids
-            b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0]
-            b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list(
-                key="b_req_mtp_start_loc",
-                data=b_req_mtp_start_loc,
-                dtype=torch.int32,
-            ).cuda(non_blocking=True)
-            # number of real requests, excluding the mtp-padded part
-            real_req_num = b_req_mtp_start_loc.shape[0] - padded_req_num
-
-            mtp_accept_len, accepted_index = self._verify_mtp_v2(
-                new_next_token_ids=draft_next_token_ids_gpu,
-                b_req_idx=model_input.b_req_idx,
-                b_req_mtp_start_loc=b_req_mtp_start_loc,
-            )
-            accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
-                key="accepted_index",
-                gpu_tensor=accepted_index[:req_num],
-            )
-            mtp_accept_len_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
-                key="mtp_accept_len",
-                gpu_tensor=mtp_accept_len[:real_req_num],
-            )
-
-            verify_event = torch.cuda.Event()
-            verify_event.record()
-
-            self._draft_decode_eagle(
-                model_input=model_input,
-                model_output=model_output,
-                draft_next_token_ids_gpu=draft_next_token_ids_gpu,
-                b_req_mtp_start_loc=b_req_mtp_start_loc,
-                mtp_accept_len=mtp_accept_len,
-                eagle_mem_indexes_cpu=eagle_mem_indexes_cpu,
-                draft_model_input=draft_model_input,
-                req_num=req_num,
-                padded_req_num=padded_req_num,
-            )
-
-            if req_num > 0:
-                g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
-                    b_req_idx=model_input.b_req_idx[:req_num],
-                    next_token_ids=next_token_ids,
-                    mask=accepted_index[:req_num] == 1,
-                )
-
-            sync_event = torch.cuda.Event()
-            sync_event.record()
-
-        if req_num > 0:
-            # phase two
-            accepted_index_cpu = accepted_index_cpu[:req_num]
-            event_pack.notify_post_handle_and_wait_pre_post_handle()
-            verify_event.synchronize()
-            verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1]
-            update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False)
-
-            # phase three
-            event_pack.notify_forward_and_wait_post_handle()
-            sync_event.synchronize()
-            need_free_mem_indexes = torch.cat(
-                [model_input.mem_indexes_cpu[0:req_num][accepted_index_cpu == 0], eagle_mem_indexes_cpu], dim=0
-            )
-
-            self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu)
-            select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu")
-            self._post_handle(
-                run_reqs=verify_ok_reqs,
-                next_token_ids=next_token_ids_cpu[select_mask],
-                next_token_logprobs=next_token_logprobs_cpu[select_mask],
-                run_reqs_update_packs=update_packs,
-                extra_post_req_handle_func=self.extra_post_req_handle_func,
-            )
-            if len(need_free_mem_indexes) > 0:
-                g_infer_state_lock.acquire()
-                g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes)
-                g_infer_state_lock.release()
-
-            # phase four
-            event_pack.notify_pre_post_handle()
-        else:
-            event_pack.notify_post_handle_and_wait_pre_post_handle()
-            event_pack.notify_forward_and_wait_post_handle()
-            event_pack.notify_pre_post_handle()
-        return
-
     def _draft_decode_eagle(
         self,
         model_input: ModelInput,
         model_output: ModelOutput,
         draft_next_token_ids_gpu: torch.Tensor,
-        b_req_mtp_start_loc: torch.Tensor,
         mtp_accept_len: torch.Tensor,
         eagle_mem_indexes_cpu: torch.Tensor,
         draft_model_input: ModelInput,
-        req_num: int,
         padded_req_num: int,
     ):
         all_next_token_ids = []
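
For context, the net effect of this diff is that the dedicated decode_mtp_eagle entry point is removed and decode_mtp becomes the single decode path: it prepares the eagle-specific draft inputs only when needed and then chooses between _draft_decode_eagle and _draft_decode_vanilla based on self.is_mtp_eagle. Below is a minimal, self-contained sketch of that dispatch pattern; the class and helper names are hypothetical simplifications for illustration, not the project's actual backend code.

from typing import List


class MtpBackendSketch:
    """Hypothetical, simplified stand-in for the real MTP backend."""

    def __init__(self, is_mtp_eagle: bool) -> None:
        self.is_mtp_eagle = is_mtp_eagle
        # Single unified decode entry point; the eagle/vanilla split now
        # happens inside decode_mtp instead of at method-binding time.
        self.decode = self.decode_mtp

    def decode_mtp(self, decode_reqs: List[str]) -> str:
        # Eagle-only draft inputs are prepared up front only when the flag
        # is set, so the vanilla path pays no extra cost.
        eagle_inputs = None
        if self.is_mtp_eagle:
            eagle_inputs = self._prepare_eagle_inputs(decode_reqs)

        # Branch on the flag to pick the draft-decode implementation.
        if self.is_mtp_eagle:
            return self._draft_decode_eagle(decode_reqs, eagle_inputs)
        return self._draft_decode_vanilla(decode_reqs)

    def _prepare_eagle_inputs(self, decode_reqs: List[str]) -> dict:
        # Stand-in for the eagle input preparation step.
        return {"num_reqs": len(decode_reqs)}

    def _draft_decode_eagle(self, decode_reqs: List[str], eagle_inputs: dict) -> str:
        return f"eagle draft for {eagle_inputs['num_reqs']} reqs"

    def _draft_decode_vanilla(self, decode_reqs: List[str]) -> str:
        return f"vanilla draft for {len(decode_reqs)} reqs"


if __name__ == "__main__":
    # Both configurations go through the same `decode` attribute.
    print(MtpBackendSketch(is_mtp_eagle=True).decode(["r0", "r1"]))
    print(MtpBackendSketch(is_mtp_eagle=False).decode(["r0"]))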