@@ -438,6 +438,7 @@ async def _async_model_forward(
     ):
         """Model forward."""
         max_prefill_token_num = self.cache_config.max_prefill_token_num
+        strategy = self.agent_strategy
 
         class _OutputGather:
             """Output gather."""
@@ -469,7 +470,11 @@ def gather(self, output):
             def get_output(self):
                 """Get tmp_output."""
                 if not return_logits:
-                    return self._output[:, -1:]
+                    seqlen = torch.full((1, ),
+                                        self._output.numel() // self._output.size(-1),
+                                        device=self._output.device,
+                                        dtype=self._output.dtype)
+                    return strategy.slice_outputs(self._output, seqlen)
                 torch.cuda.synchronize()
                 return self._output.to(self._device)
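For context, a minimal sketch of the `slice_outputs` hook this hunk starts delegating to. The method name and call shape come from the diff; the class name and body below are assumptions that simply reproduce the old `[:, -1:]` behaviour, treating `seqlen` as the flattened token count computed above:

```python
import torch


class LastTokenStrategy:
    """Hypothetical strategy whose slice_outputs keeps only the last token."""

    def slice_outputs(self, output: torch.Tensor, seqlen: torch.Tensor) -> torch.Tensor:
        # `output` is [1, total_tokens, dim]; the last position is seqlen - 1.
        # Cast to long since the diff builds seqlen with `output`'s dtype.
        last = seqlen.long() - 1
        return output[:, last]
```

A strategy hook here lets other decoding modes keep more than one trailing position without touching `get_output` itself.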
@@ -562,17 +567,14 @@ def _push_output(self, output: BatchedOutputs):
         self._out_que.put_nowait((output, event))
 
     @contextmanager
-    def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None, enable: bool = True):
+    def _broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, enable: bool = True):
         if not enable:
             yield
             return
 
-        if dist_ctx is None:
-            dist_ctx = get_dist_manager().current_context()
-        tp_gpu_group = dist_ctx.tp_gpu_group
-        handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True)
-        yield
-        handle.wait()
+        dist_ctx = self.dist_ctx
+        with self.agent_strategy.broadcast_next_token(next_token_ids, extra_inputs, dist_ctx) as handle:
+            yield handle
 
     async def _async_step_background(
         self,
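The deleted inline broadcast shows what the delegated `agent_strategy.broadcast_next_token` has to provide. A hedged sketch reconstructed from those removed lines; accepting but ignoring `extra_inputs` is an assumption, on the reading that only some strategies have extra sampled tensors to synchronize:

```python
from contextlib import contextmanager

import torch
import torch.distributed as dist


@contextmanager
def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs, dist_ctx):
    """Async-broadcast sampled token ids from rank 0 to the TP group."""
    tp_gpu_group = dist_ctx.tp_gpu_group
    handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True)
    # Yield the in-flight handle so the caller can overlap work (e.g. the
    # debug log above) with the broadcast, then wait before leaving.
    yield handle
    handle.wait()
```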
@@ -698,6 +700,7 @@ async def __prepare_dp():
            seq_length = output.get('seq_length', inputs.seq_length)
            last_logits = self._slice_outs(logits, seq_length)  # [bs, 1, prob] -> [bs, prob]
            extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, seq_length)
+           model_metas = output.get('model_metas')
 
            # output empty for dummy inputs
            if is_dummy:
@@ -711,47 +714,40 @@ async def __prepare_dp():
                # sampling
                next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs)
 
-               with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next):
-                   logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids [{idx}]')
+               # post sampling
+               next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids,
+                                                                                extra_inputs)
 
-               # post sampling
-               next_token_ids, extra_inputs = self.agent_strategy.post_sampling(
-                   inputs, last_logits, next_token_ids, extra_inputs)
+               with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next):
+                   logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids [{idx}]')
 
                # stopping criteria
                stopped, stop_pos, stopping_criteria = stopping_criteria.step(next_token_ids,
                                                                              sampling_inputs.stop_words,
                                                                              inputs=inputs,
                                                                              extra_inputs=extra_inputs)
+
+               # send output
+               logger.debug(f'<ForwardTask> rank[{rank}]: Output [{idx}]')
+               extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)
+               self._push_output(
+                   BatchedOutputs(next_token_ids=next_token_ids,
+                                  logits=logits if return_logits else None,
+                                  stopped=stopped,
+                                  stop_pos=stop_pos,
+                                  model_metas=model_metas,
+                                  logprobs=logprobs,
+                                  extra_outputs=extra_outputs))
            else:
                # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`,
                # as it can trigger recompilation on different ranks when using torch.compile.
-               with torch.inference_mode():
-                   next_token_ids = inputs.input_ids.new_zeros(last_logits.size(0))
-               logprobs = None
+               next_token_ids, extra_inputs = self.agent_strategy.make_dummy_next_token(
+                   inputs, last_logits, extra_inputs)
 
                # broadcast next token for TP > 1
-               with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next):
+               with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next):
                    logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids [{idx}]')
 
-               # post sampling
-               next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids,
-                                                                                extra_inputs)
-
-               # send output
-               model_metas = output.get('model_metas')
-               if need_output:
-                   logger.debug(f'<ForwardTask> rank[{rank}]: Output [{idx}]')
-                   extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)
-                   self._push_output(
-                       BatchedOutputs(next_token_ids=next_token_ids,
-                                      logits=logits if return_logits else None,
-                                      stopped=stopped,
-                                      stop_pos=stop_pos,
-                                      model_metas=model_metas,
-                                      logprobs=logprobs,
-                                      extra_outputs=extra_outputs))
-
            # update for next loop
            if is_decoding and idx < loop_count - 1:
                inputs, extra_inputs = __update_inputs(next_token_ids, model_metas, extra_inputs)
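`make_dummy_next_token` replaces the dummy-token creation deleted in this hunk; a sketch of a default implementation reconstructed from those removed lines (the class name is an assumption, and the hook body mirrors the old inline code):

```python
import torch


class DefaultAgentStrategy:
    """Hypothetical default strategy; only the dummy-token hook is sketched."""

    def make_dummy_next_token(self, inputs, last_logits, extra_inputs):
        # Keep inference_mode so `next_token_ids` does not pick up the
        # ADInplaceOrView dispatch key, which could trigger torch.compile
        # recompilation on non-output ranks (see the comment in the diff).
        with torch.inference_mode():
            next_token_ids = inputs.input_ids.new_zeros(last_logits.size(0))
        return next_token_ids, extra_inputs
```

Moving post-sampling ahead of the broadcast also means every rank synchronizes the post-processed token ids, so the strategy hooks see identical inputs on all TP ranks.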