Skip to content

Commit 10c6f01

Browse files
HaochenYuan and Phlip79 authored
Remove calculation of padding token in moe routing loss (#2142)
Co-authored-by: Philip Petrakian <ppetrakian@nvidia.com>
1 parent 03e0915 commit 10c6f01

File tree

16 files changed

+639
-88
lines changed

16 files changed

+639
-88
lines changed

megatron/core/extensions/transformer_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2161,7 +2161,7 @@ def forward_post_hook(module, *_) -> None:
21612161
"TEFusedMLP module does not support submodules with post-backward hooks"
21622162
)
21632163

2164-
def forward(self, hidden_states: torch.Tensor) -> Tuple[Tensor, Optional[Tensor]]:
2164+
def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Optional[Tensor]]:
21652165
"""Forward."""
21662166

21672167
# Construct fused impl if needed

megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def __init__(
281281
extra_block_kwargs=None,
282282
runtime_gather_output: Optional[bool] = None,
283283
loss_mask: Optional[Tensor] = None,
284+
padding_mask=None,
284285
):
285286
"""Initialize the schedule plan of all Transformer layers' sub-modules.
286287
@@ -323,6 +324,7 @@ def __init__(
323324
self._model_chunk_state.mtp_hidden_states = None
324325
self._model_chunk_state.loss_mask = loss_mask
325326
self._model_chunk_state.packed_seq_params = packed_seq_params
327+
self._model_chunk_state.padding_mask = padding_mask
326328
self._model_chunk_state.extra_block_kwargs = extra_block_kwargs
327329
self._model_chunk_state.runtime_gather_output = runtime_gather_output
328330
self._model_chunk_state.model = model

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,19 @@ def forward_impl(self):
131131
if not self.gpt_model.pre_process:
132132
self.chunk_state.decoder_input = self.gpt_model.decoder.input_tensor
133133
# Run GPTModel._preprocess
134-
decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
135-
self.gpt_model._preprocess(
136-
input_ids=self.chunk_state.input_ids,
137-
position_ids=self.chunk_state.position_ids,
138-
decoder_input=self.chunk_state.decoder_input,
139-
packed_seq_params=self.chunk_state.packed_seq_params,
140-
)
134+
(
135+
decoder_input,
136+
rotary_pos_emb,
137+
rotary_pos_cos,
138+
rotary_pos_sin,
139+
sequence_len_offset,
140+
padding_mask,
141+
) = self.gpt_model._preprocess(
142+
input_ids=self.chunk_state.input_ids,
143+
position_ids=self.chunk_state.position_ids,
144+
decoder_input=self.chunk_state.decoder_input,
145+
packed_seq_params=self.chunk_state.packed_seq_params,
146+
padding_mask=self.chunk_state.padding_mask,
141147
)
142148

143149
# Saved for later use
@@ -146,6 +152,7 @@ def forward_impl(self):
146152
self.chunk_state.rotary_pos_cos = rotary_pos_cos
147153
self.chunk_state.rotary_pos_sin = rotary_pos_sin
148154
self.chunk_state.sequence_len_offset = sequence_len_offset
155+
self.chunk_state.padding_mask = padding_mask
149156
return decoder_input
150157

151158

megatron/core/models/gpt/gpt_model.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def _preprocess(
288288
decoder_input: Tensor = None,
289289
inference_context: BaseInferenceContext = None,
290290
packed_seq_params: PackedSeqParams = None,
291+
padding_mask: Optional[Tensor] = None,
291292
):
292293
"""Preprocesses inputs for the transformer decoder.
293294
@@ -304,7 +305,20 @@ def _preprocess(
304305
if decoder_input is not None:
305306
pass
306307
elif self.pre_process:
308+
if padding_mask is not None:
309+
assert padding_mask.shape == input_ids.shape, (
310+
f"padding_mask shape {padding_mask.shape} does not match "
311+
f"input_ids shape {input_ids.shape}"
312+
)
307313
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
314+
if padding_mask is not None and self.config.sequence_parallel:
315+
padding_mask = (
316+
tensor_parallel.scatter_to_sequence_parallel_region(
317+
padding_mask.transpose(0, 1).contiguous()
318+
)
319+
.transpose(0, 1)
320+
.contiguous()
321+
)
308322
else:
309323
# intermediate stage of pipeline
310324
# decoder will get hidden_states from encoder.input_tensor
@@ -423,6 +437,7 @@ def _preprocess(
423437
rotary_pos_cos,
424438
rotary_pos_sin,
425439
sequence_len_offset,
440+
padding_mask,
426441
)
427442
if rotary_pos_cos_sin is not None:
428443
# only in the case of flashinfer fused rope will we
@@ -466,6 +481,7 @@ def forward(
466481
*,
467482
inference_params: Optional[BaseInferenceContext] = None,
468483
loss_mask: Optional[Tensor] = None,
484+
padding_mask: Optional[Tensor] = None,
469485
) -> Tensor:
470486
"""Forward function of the GPT Model This function passes the input tensors
471487
through the embedding layer, and then the decoder and finally into the post
@@ -476,6 +492,9 @@ def forward(
476492
Args:
477493
runtime_gather_output (bool): Gather output at runtime. Default None means
478494
`parallel_output` arg in the constructor will be used.
495+
padding_mask (Tensor, optional): Padding mask for MoE routing.
496+
Shape [bsz, seq_length]. True = padding (exclude), False = valid (include).
497+
Only used for MoE layers to exclude padding tokens from routing computations.
479498
"""
480499
if self.config.fine_grained_activation_offloading:
481500
self.preprocess_for_fine_grained_offloading()
@@ -488,13 +507,19 @@ def forward(
488507
decoder_input=decoder_input,
489508
inference_context=inference_context,
490509
packed_seq_params=packed_seq_params,
510+
padding_mask=padding_mask,
491511
)
492512

493-
(decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset) = (
494-
preproc_output[:5]
495-
)
513+
(
514+
decoder_input,
515+
rotary_pos_emb,
516+
rotary_pos_cos,
517+
rotary_pos_sin,
518+
sequence_len_offset,
519+
padding_mask,
520+
) = preproc_output[:6]
496521

497-
rotary_pos_cos_sin = preproc_output[5] if len(preproc_output) == 6 else None
522+
rotary_pos_cos_sin = preproc_output[6] if len(preproc_output) == 7 else None
498523

499524
# Run decoder.
500525
hidden_states = self.decoder(
@@ -507,6 +532,7 @@ def forward(
507532
rotary_pos_cos_sin=rotary_pos_cos_sin,
508533
packed_seq_params=packed_seq_params,
509534
sequence_len_offset=sequence_len_offset,
535+
padding_mask=padding_mask,
510536
**(extra_block_kwargs or {}),
511537
)
512538

@@ -723,6 +749,7 @@ def build_schedule_plan(
723749
runtime_gather_output: Optional[bool] = None,
724750
inference_params: Optional[BaseInferenceContext] = None,
725751
loss_mask: Optional[Tensor] = None,
752+
padding_mask: Optional[Tensor] = None,
726753
):
727754
"""Builds a computation schedule plan for the model.
728755
@@ -748,6 +775,7 @@ def build_schedule_plan(
748775
inference_params (InferenceParams, optional):
749776
Parameters for inference. Defaults to None.
750777
loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
778+
padding_mask (Optional[Tensor], optional): Padding mask. Defaults to None.
751779
752780
Returns:
753781
TransformerModelChunkSchedulePlan: The model chunk schedule plan.
@@ -769,6 +797,7 @@ def build_schedule_plan(
769797
extra_block_kwargs,
770798
runtime_gather_output,
771799
loss_mask,
800+
padding_mask,
772801
)
773802

774803
def sharded_state_dict(

megatron/core/models/mamba/mamba_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def forward(
185185
*,
186186
inference_params: Optional[BaseInferenceContext] = None,
187187
packed_seq_params: Optional[PackedSeqParams] = None,
188+
padding_mask: Optional[Tensor] = None,
188189
) -> Tensor:
189190
"""Forward function of the Mamba model. This function passes the input tensors
190191
through the embedding layer, and then the decoder and finally into the post
@@ -254,6 +255,7 @@ def forward(
254255
inference_context=inference_context,
255256
rotary_pos_emb=rotary_pos_emb,
256257
packed_seq_params=packed_seq_params,
258+
padding_mask=padding_mask,
257259
)
258260

259261
if not self.post_process:

megatron/core/ssm/mamba_block.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def forward(
211211
*,
212212
inference_params: Optional[BaseInferenceContext] = None,
213213
packed_seq_params: Optional[PackedSeqParams] = None,
214+
padding_mask=None,
214215
):
215216
"""
216217
Forward function of the MambaStack class.
@@ -293,6 +294,7 @@ def forward(
293294
rotary_pos_emb=rotary_pos_emb,
294295
sequence_len_offset=sequence_len_offset,
295296
packed_seq_params=packed_seq_params,
297+
padding_mask=padding_mask,
296298
)
297299
else: # MambaLayer
298300
hidden_states = layer(

megatron/core/transformer/mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def __init__(
148148
tp_group=tp_group,
149149
)
150150

151-
def forward(self, hidden_states, per_token_scale=None):
151+
def forward(self, hidden_states, per_token_scale=None, **kwargs):
152152
"""Perform the forward pass through the MLP block."""
153153
# [s, b, 4 * h/p]
154154
nvtx_range_push(suffix="linear_fc1")

megatron/core/transformer/moe/moe_layer.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,13 @@ def __init__(
239239
self.cudagraph_tensor_store = MoECudaGraphTensorStore()
240240

241241
@maybe_skip_or_early_return_by_cudagraph("route")
242-
def route(self, hidden_states: torch.Tensor):
242+
def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
243243
"""Compute token routing for preprocessing.
244244
245245
This method uses the router to determine which experts to send each token to,
246246
producing routing probabilities and a mapping.
247247
"""
248-
probs, routing_map = apply_module(self.router)(hidden_states)
248+
probs, routing_map = apply_module(self.router)(hidden_states, padding_mask)
249249
return probs, routing_map
250250

251251
@maybe_skip_or_early_return_by_cudagraph("preprocess")
@@ -346,7 +346,7 @@ def router_and_preprocess(self, hidden_states: torch.Tensor):
346346
hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map)
347347
return hidden_states, probs, residual
348348

349-
def forward(self, hidden_states: torch.Tensor):
349+
def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
350350
"""Forward pass for the MoE layer.
351351
352352
The forward pass comprises four main steps:
@@ -356,8 +356,10 @@ def forward(self, hidden_states: torch.Tensor):
356356
4. Combine: The outputs from the experts are combined and returned.
357357
358358
Args:
359-
hidden_states (torch.Tensor): The input tensor to the MoE layer.
360-
359+
hidden_states (torch.Tensor): The input tensor shape [seq_length, bsz, hidden_size].
360+
padding_mask (torch.Tensor, optional): Boolean mask indicating padding tokens.
361+
Shape [bsz, seq_length] (transposed to [seq_length, bsz] internally). True for
362+
padding tokens, False for valid tokens. Defaults to None.
361363
Returns:
362364
A tuple containing the output tensor and the MLP bias, if any.
363365
"""
@@ -366,12 +368,15 @@ def forward(self, hidden_states: torch.Tensor):
366368
"During training, performance may degrade if MoE and tensor parallelism"
367369
"are enabled without also enabling sequence parallelism."
368370
)
371+
# Transpose from [bsz, seq_length] to [seq_length, bsz] to align with hidden_states
372+
if padding_mask is not None:
373+
padding_mask = padding_mask.transpose(0, 1).bool()
369374

370375
# MoE forward: route -> dispatch -> compute -> combine
371376
def custom_forward(hidden_states):
372377
try:
373378
shared_expert_output = self.shared_experts_compute(hidden_states)
374-
probs, routing_map = self.route(hidden_states)
379+
probs, routing_map = self.route(hidden_states, padding_mask)
375380
hidden_states, probs = self.preprocess(hidden_states, probs, routing_map)
376381
except MoECudaGraphPartialCaptureSignal as e:
377382
# This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator.
@@ -398,7 +403,9 @@ def custom_forward(hidden_states):
398403
hidden_states,
399404
)
400405
else:
401-
outputs = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
406+
outputs = tensor_parallel.checkpoint(
407+
custom_forward, False, hidden_states, padding_mask
408+
)
402409
else:
403410
outputs = custom_forward(hidden_states)
404411

0 commit comments

Comments
 (0)