@@ -288,6 +288,7 @@ def _preprocess(
         decoder_input: Tensor = None,
         inference_context: BaseInferenceContext = None,
         packed_seq_params: PackedSeqParams = None,
+        padding_mask: Optional[Tensor] = None,
     ):
         """Preprocesses inputs for the transformer decoder.

@@ -304,7 +305,20 @@ def _preprocess(
         if decoder_input is not None:
             pass
         elif self.pre_process:
+            if padding_mask is not None:
+                assert padding_mask.shape == input_ids.shape, (
+                    f"padding_mask shape {padding_mask.shape} does not match "
+                    f"input_ids shape {input_ids.shape}"
+                )
             decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
+            if padding_mask is not None and self.config.sequence_parallel:
+                padding_mask = (
+                    tensor_parallel.scatter_to_sequence_parallel_region(
+                        padding_mask.transpose(0, 1).contiguous()
+                    )
+                    .transpose(0, 1)
+                    .contiguous()
+                )
         else:
             # intermediate stage of pipeline
             # decoder will get hidden_states from encoder.input_tensor
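
For context on the sequence-parallel branch above, here is a minimal single-process sketch of the shape bookkeeping. torch.chunk stands in for tensor_parallel.scatter_to_sequence_parallel_region, which splits dim 0 of its input across tensor-parallel ranks; tp_size and rank are illustrative values, not part of the patch.

import torch

bsz, seq_length, tp_size, rank = 2, 8, 4, 1
padding_mask = torch.zeros(bsz, seq_length, dtype=torch.bool)
padding_mask[:, -3:] = True  # mark trailing pad tokens as padding

# [bsz, seq] -> [seq, bsz] so the sequence dim leads, split it across ranks,
# then restore [bsz, seq/tp]; each rank keeps the slice aligned with its
# local activations.
shard = torch.chunk(padding_mask.transpose(0, 1).contiguous(), tp_size, dim=0)[rank]
local_mask = shard.transpose(0, 1).contiguous()
assert local_mask.shape == (bsz, seq_length // tp_size)
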
@@ -423,6 +437,7 @@ def _preprocess(
             rotary_pos_cos,
             rotary_pos_sin,
             sequence_len_offset,
+            padding_mask,
         )
         if rotary_pos_cos_sin is not None:
             # only in the case of flashinfer fused rope will we
@@ -466,6 +481,7 @@ def forward(
         *,
         inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
+        padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """Forward function of the GPT Model This function passes the input tensors
         through the embedding layer, and then the decoder and finally into the post
@@ -476,6 +492,9 @@ def forward(
         Args:
             runtime_gather_output (bool): Gather output at runtime. Default None means
                 `parallel_output` arg in the constructor will be used.
+            padding_mask (Tensor, optional): Padding mask for MoE routing.
+                Shape [bsz, seq_length]. True = padding (exclude), False = valid (include).
+                Only used for MoE layers to exclude padding tokens from routing computations.
         """
         if self.config.fine_grained_activation_offloading:
             self.preprocess_for_fine_grained_offloading()
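
The docstring fixes the convention (True = padding, False = valid), so a caller can derive the mask directly from the input ids. A minimal sketch, assuming a known pad_token_id (illustrative, not part of the patch):

import torch

pad_token_id = 0  # illustrative; use the tokenizer's actual pad id
input_ids = torch.tensor([[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]])
padding_mask = input_ids.eq(pad_token_id)  # bool, shape [bsz, seq_length]

This shape matches the assertion added in _preprocess above.
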
@@ -488,13 +507,19 @@ def forward(
             decoder_input=decoder_input,
             inference_context=inference_context,
             packed_seq_params=packed_seq_params,
+            padding_mask=padding_mask,
         )

-        (decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset) = (
-            preproc_output[:5]
-        )
+        (
+            decoder_input,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
+            sequence_len_offset,
+            padding_mask,
+        ) = preproc_output[:6]

-        rotary_pos_cos_sin = preproc_output[5] if len(preproc_output) == 6 else None
+        rotary_pos_cos_sin = preproc_output[6] if len(preproc_output) == 7 else None

         # Run decoder.
         hidden_states = self.decoder(
@@ -507,6 +532,7 @@ def forward(
             rotary_pos_cos_sin=rotary_pos_cos_sin,
             packed_seq_params=packed_seq_params,
             sequence_len_offset=sequence_len_offset,
+            padding_mask=padding_mask,
             **(extra_block_kwargs or {}),
         )

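
The index changes work because padding_mask is now always appended as the sixth element of preproc_output, which pushes the optional flashinfer rotary_pos_cos_sin (present only when fused rope is used) from index 5 to index 6. A toy illustration with placeholder values:

# Fixed six-element prefix; the optional seventh element shifts with it.
preproc_output = ("dec_in", "rope", None, None, 0, "pad_mask", "cos_sin")
(decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin,
 sequence_len_offset, padding_mask) = preproc_output[:6]
rotary_pos_cos_sin = preproc_output[6] if len(preproc_output) == 7 else None
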
@@ -723,6 +749,7 @@ def build_schedule_plan(
         runtime_gather_output: Optional[bool] = None,
         inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
+        padding_mask: Optional[Tensor] = None,
     ):
         """Builds a computation schedule plan for the model.

@@ -748,6 +775,7 @@ def build_schedule_plan(
             inference_params (InferenceParams, optional):
                 Parameters for inference. Defaults to None.
             loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
+            padding_mask (Optional[Tensor], optional): Padding mask. Defaults to None.

         Returns:
             TransformerModelChunkSchedulePlan: The model chunk schedule plan.
@@ -769,6 +797,7 @@ def build_schedule_plan(
             extra_block_kwargs,
             runtime_gather_output,
             loss_mask,
+            padding_mask,
         )

     def sharded_state_dict(
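
Taken together, an illustrative call site after this change might look like the following; model is an already-constructed GPTModel and the other tensors follow the usual Megatron conventions (all assumed here, not part of the patch):

padding_mask = input_ids.eq(pad_token_id)  # True = padding, per the new docstring
logits = model(
    input_ids=input_ids,
    position_ids=position_ids,
    attention_mask=attention_mask,
    padding_mask=padding_mask,  # threaded through _preprocess into the decoder's MoE routing
)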