Commit a0cc8ca

Revert "Add MTP support for hybrid models (#2363)"
This reverts commit 300d1b6.
1 parent 31d0c87 commit a0cc8ca

23 files changed: +205 −1170 lines

mamba_builders.py

Lines changed: 2 additions & 6 deletions
@@ -8,18 +8,15 @@
 from megatron.training.arguments import core_transformer_config_from_args
 from megatron.core.models.mamba.mamba_layer_specs import mamba_inference_stack_spec
 
-
 def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None):
     print_rank_0('building MAMBA model ...')
     if config is None:
         config = core_transformer_config_from_args(args, TransformerConfig)
     assert args.use_legacy_models is False, "Mamba only supported in Mcore!"
 
     if config.transformer_impl == "inference_optimized":
-        mamba_stack_spec = mamba_inference_stack_spec
-        assert (
-            not config.inference_fuse_tp_communication
-        ), "inference_fuse_tp_communication is not supported for Mamba"
+        mamba_stack_spec = mamba_inference_stack_spec
+        assert not config.inference_fuse_tp_communication, "inference_fuse_tp_communication is not supported for Mamba"
     elif args.spec is not None:
         mamba_stack_spec = import_module(args.spec)
     else:
@@ -42,7 +39,6 @@ def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None, p
         rotary_percent=args.rotary_percent,
         rotary_base=args.rotary_base,
         pg_collection=pg_collection,
-        vp_stage=vp_stage,
     )
 
     for l in range(model.decoder.num_layers_per_pipeline_rank):

megatron/core/models/common/language_module/language_module.py

Lines changed: 3 additions & 27 deletions
@@ -23,7 +23,6 @@
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.enums import AttnBackend, CudaGraphScope
 from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.multi_token_prediction import tie_word_embeddings_state_dict
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.utils import ensure_metadata_has_dp_cp_group
 from megatron.core.utils import (
@@ -256,20 +255,12 @@ def setup_embeddings_and_output_layer(self) -> None:
             LanguageModule.embedding_warning_printed = True
 
     def shared_embedding_or_output_weight(self) -> Tensor:
-        """Gets the embedding weight or output logit weights when share embedding and output weights set to True
-        or when use Multi-Token Prediction (MTP).
+        """Gets the emedding weight or output logit weights when share embedding and output weights set to True.
 
         Returns:
-            Tensor: During pre processing or MTP process it returns the input embeddings weight while during post processing it returns the final output layers weight
+            Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight
         """
-        if self.pre_process or getattr(self, 'mtp_process', False):
-            # Multi-Token Prediction (MTP) need both embedding layer and output layer.
-            # So there will be both embedding layer and output layer in the mtp process stage.
-            # When share_embeddings_and_output_weights is True, the embedding weight is the
-            # canonical shared weight and is passed to the output layer during forward.
-            assert hasattr(
-                self, 'embedding'
-            ), f"embedding is needed in this pipeline stage, but it is not initialized."
+        if self.pre_process:
             return self.embedding.word_embeddings.weight
         elif self.post_process:
             return self.output_layer.weight
@@ -302,21 +293,6 @@ def sharded_state_dict(
         output_layer_weight_key = f'{prefix}output_layer.weight'
         output_layer_bias_key = f'{prefix}output_layer.bias'
 
-        # Multi-Token Prediction (MTP) needs embedding layer in mtp process stage.
-        # If MTP is not placed in the pre processing stage, we need to maintain a copy of
-        # embedding layer in the mtp process stage and tie it to the embedding in the pre
-        # processing stage.
-        # Note: MTP loss is computed at post_process stage, so the output_layer on mtp_process
-        # rank doesn't need special tying - it's not used for loss computation.
-        if getattr(self, 'mtp_process', False) and not self.pre_process:
-            emb_weight = self.embedding.word_embeddings.weight
-            tie_word_embeddings_state_dict(
-                sharded_state_dict,
-                emb_weight,
-                first_stage_word_emb_key,
-                tp_group=self.tp_group,
-                dp_cp_group=metadata['dp_cp_group'],
-            )
         if self.share_embeddings_and_output_weights:
             self.tie_embeddings_and_output_weights_state_dict(
                 sharded_state_dict, output_layer_weight_key, first_stage_word_emb_key, metadata
megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
         # get flags for latter use
         is_mtp = isinstance(self.layer, MultiTokenPredictionLayer)
         is_moe = (
-            isinstance(self.layer.mtp_model_layer.mlp, MoELayer)
+            isinstance(self.layer.transformer_layer.mlp, MoELayer)
             if is_mtp
             else isinstance(self.layer.mlp, MoELayer)
         )

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 2 additions & 2 deletions
@@ -613,9 +613,9 @@ def build_mtp_layer_callables(layer):
         multi-token prediction layer nodes (attention, MLP, etc.)
     """
 
-    forward_funcs, backward_dw = build_transformer_layer_callables(layer.mtp_model_layer)
+    forward_funcs, backward_dw = build_transformer_layer_callables(layer.transformer_layer)
     attn_forward, dispatch_forward, mlp_forward, combine_forward, _ = forward_funcs
-    is_moe = isinstance(layer.mtp_model_layer.mlp, MoELayer)
+    is_moe = isinstance(layer.transformer_layer.mlp, MoELayer)
     assert is_moe, "MTP layer in a2a overlap only supports MoE layer for now."
 
     def submodule_mtp_attn_forward(node, hidden_states):

megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 1 addition & 1 deletion
@@ -704,7 +704,7 @@ def get_gpt_mtp_block_spec_for_backend(
         raise ValueError(f"Invalid spec: {spec}")
 
     mtp_layer_spec = get_mtp_layer_spec_for_backend(
-        mtp_model_layer_spec=transformer_layer_spec, backend=backend
+        transformer_layer_spec=transformer_layer_spec, backend=backend
     )
     mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
     mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers

megatron/core/models/gpt/gpt_model.py

Lines changed: 93 additions & 19 deletions
@@ -6,7 +6,7 @@
 import torch
 from torch import Tensor
 
-from megatron.core import tensor_parallel
+from megatron.core import parallel_state, tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.inference.contexts import BaseInferenceContext
@@ -26,9 +26,11 @@
 from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
 from megatron.core.transformer.enums import CudaGraphScope, ModelType
 from megatron.core.transformer.multi_token_prediction import (
+    MTPLossAutoScaler,
+    MTPLossLoggingHelper,
     MultiTokenPredictionBlock,
-    mtp_on_this_rank,
-    process_mtp_loss,
+    roll_tensor,
+    tie_word_embeddings_state_dict,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_block import TransformerBlock
@@ -142,9 +144,7 @@ def __init__(
         self.rotary_base = rotary_base
         self.rotary_scaling = rope_scaling
         self.mtp_block_spec = mtp_block_spec
-        self.mtp_process = mtp_block_spec is not None and mtp_on_this_rank(
-            self.config, ignore_virtual=False, vp_stage=vp_stage
-        )
+        self.mtp_process = mtp_block_spec is not None
 
         if self.pre_process or self.mtp_process:
             self.embedding = LanguageModelEmbedding(
@@ -609,19 +609,56 @@ def _postprocess(
             return hidden_states
 
         if self.config.mtp_num_layers is not None:
-            hidden_states = process_mtp_loss(
-                hidden_states=hidden_states,
-                labels=labels,
-                loss_mask=loss_mask,
-                output_layer=self.output_layer,
-                output_weight=output_weight,
-                runtime_gather_output=runtime_gather_output,
-                is_training=self.training,
-                compute_language_model_loss=self.compute_language_model_loss,
-                config=self.config,
-                cp_group=self.pg_collection.cp,
-                packed_seq_params=packed_seq_params,
-            )
+            mtp_labels = labels.clone()
+            hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
+            hidden_states = hidden_states_list[0]
+            if loss_mask is None:
+                # if loss_mask is not provided, use all ones as loss_mask
+                loss_mask = torch.ones_like(mtp_labels)
+            for mtp_layer_number in range(self.config.mtp_num_layers):
+                # output
+                mtp_logits, _ = self.output_layer(
+                    hidden_states_list[mtp_layer_number + 1],
+                    weight=output_weight,
+                    runtime_gather_output=runtime_gather_output,
+                )
+                # Calc loss for the current Multi-Token Prediction (MTP) layers.
+                mtp_labels, _ = roll_tensor(
+                    mtp_labels,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+                loss_mask, num_tokens = roll_tensor(
+                    loss_mask,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+                mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
+                mtp_loss = loss_mask * mtp_loss
+                if self.training:
+                    # TODO(shifangx): remove the use of parallel_state here
+                    # after moving loss logging to loss_func in pretrain_gpt.py
+                    MTPLossLoggingHelper.save_loss_to_tracker(
+                        torch.sum(mtp_loss) / num_tokens,
+                        mtp_layer_number,
+                        self.config.mtp_num_layers,
+                        avg_group=parallel_state.get_data_parallel_group(
+                            with_context_parallel=True
+                        ),
+                    )
+                mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+                if self.config.calculate_per_token_loss:
+                    hidden_states = MTPLossAutoScaler.apply(
+                        hidden_states, mtp_loss_scale * mtp_loss
+                    )
+                else:
+                    hidden_states = MTPLossAutoScaler.apply(
+                        hidden_states, mtp_loss_scale * mtp_loss / num_tokens
+                    )
         sequence_parallel_override = False
 
         if in_inference_mode and inference_context.materialize_only_last_token_logits:
@@ -678,6 +715,27 @@ def _postprocess(
 
         return loss
 
+    def shared_embedding_or_output_weight(self) -> Tensor:
+        """Gets the embedding weight or output logit weights when share input embedding and
+        output weights set to True or when use Multi-Token Prediction (MTP) feature.
+
+        Returns:
+            Tensor: During pre processing or MTP process it returns the input embeddings weight.
+            Otherwise, during post processing it returns the final output layers weight.
+        """
+        if self.pre_process or self.mtp_process:
+            # Multi-Token Prediction (MTP) need both embedding layer and output layer.
+            # So there will be both embedding layer and output layer in the mtp process stage.
+            # In this case, if share_embeddings_and_output_weights is True, the shared weights
+            # will be stored in embedding layer, and output layer will not have any weight.
+            assert hasattr(
+                self, 'embedding'
+            ), f"embedding is needed in this pipeline stage, but it is not initialized."
+            return self.embedding.word_embeddings.weight
+        elif self.post_process:
+            return self.output_layer.weight
+        return None
+
     def build_schedule_plan(
         self,
         input_ids: Tensor,
@@ -768,4 +826,20 @@ def sharded_state_dict(
             output_extra_state and output_extra_state.data
         ), f'Expected output layer extra state to be empty, got: {output_extra_state}'
 
+        # Multi-Token Prediction (MTP) need embedding layer in mtp process stage.
+        # If MTP is not placed in the pre processing stage, we need to maintain a copy of
+        # embedding layer in the mtp process stage and tie it to the embedding in the pre
+        # processing stage.
+        # Now MTP loss is computed in post processing stage, so the output_layer is not needed.
+        if self.mtp_process and not self.pre_process:
+            emb_weight_key = f'{prefix}embedding.word_embeddings.weight'
+            emb_weight = self.embedding.word_embeddings.weight
+            tie_word_embeddings_state_dict(
+                sharded_state_dict,
+                emb_weight,
+                emb_weight_key,
+                tp_group=self.tp_group,
+                dp_cp_group=metadata['dp_cp_group'],
+            )
+
         return sharded_state_dict
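
The revert moves the MTP loss computation back inline into _postprocess. In outline: the decoder output stacks one hidden-state segment per prediction depth; each extra segment is projected to logits, scored against labels rolled one position further left, and the scaled auxiliary loss is attached to the main path through a custom autograd function so it surfaces only in backward. A self-contained toy sketch of that pattern follows; toy_roll and ToyLossScaler are illustrative stand-ins for megatron's roll_tensor and MTPLossAutoScaler, and a plain sum replaces the real cross-entropy:

import torch

def toy_roll(t, shifts=-1, dims=-1):
    """Shift a labels/mask tensor left, zeroing the vacated tail positions."""
    rolled = torch.roll(t, shifts=shifts, dims=dims)
    rolled[..., shifts:] = 0
    return rolled, rolled.sum()

class ToyLossScaler(torch.autograd.Function):
    """Identity on hidden_states in forward; backward also emits a unit
    gradient for the auxiliary loss so it trains without being returned."""

    @staticmethod
    def forward(ctx, hidden_states, aux_loss):
        ctx.save_for_backward(aux_loss)
        return hidden_states

    @staticmethod
    def backward(ctx, grad_output):
        (aux_loss,) = ctx.saved_tensors
        return grad_output, torch.ones_like(aux_loss)

num_mtp_layers, seq, batch, hidden = 2, 8, 1, 4
# The decoder output carries (1 + mtp_num_layers) stacked segments along dim 0.
hidden_states = torch.randn((1 + num_mtp_layers) * seq, batch, hidden, requires_grad=True)
labels = torch.randint(0, 10, (batch, seq))
loss_mask = torch.ones(batch, seq)

chunks = torch.chunk(hidden_states, 1 + num_mtp_layers, dim=0)
main_hidden = chunks[0]
for k in range(num_mtp_layers):
    labels, _ = toy_roll(labels)            # depth k+1 predicts token t+(k+1)
    loss_mask, num_tokens = toy_roll(loss_mask)
    fake_loss = loss_mask * chunks[k + 1].sum(-1).t()  # stand-in for the CE loss
    scale = 1.0 / num_mtp_layers            # cf. mtp_loss_scaling_factor / mtp_num_layers
    main_hidden = ToyLossScaler.apply(main_hidden, scale * fake_loss / num_tokens)

main_hidden.sum().backward()                # MTP gradients reach hidden_states
assert hidden_states.grad is not None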

megatron/core/models/mamba/mamba_layer_specs.py

Lines changed: 0 additions & 33 deletions
@@ -1,7 +1,6 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 from megatron.core.extensions.transformer_engine import (
-    TEColumnParallelLinear,
     TEDotProductAttention,
     TELayerNormColumnParallelLinear,
     TENorm,
@@ -20,49 +19,20 @@
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
-from megatron.core.transformer.multi_token_prediction import (
-    MultiTokenPredictionBlock,
-    MultiTokenPredictionBlockSubmodules,
-    MultiTokenPredictionLayer,
-    MultiTokenPredictionLayerSubmodules,
-)
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_layer import (
     MoETransformerLayer,
     TransformerLayer,
     TransformerLayerSubmodules,
 )
 
-# This should be private and should not be used outside of this file.
 moe = get_moe_module_spec(
     use_te=True,
     num_experts=8,  # Can be any positive integer (must not be None).
     moe_grouped_gemm=True,
     moe_use_legacy_grouped_gemm=False,
 )
 
-
-# MTP block spec for Mamba - provides norms and projection only.
-# Inner layers are built by MultiTokenPredictionLayer using nested MambaStack
-_mamba_mtp_block_spec = ModuleSpec(
-    module=MultiTokenPredictionBlock,
-    submodules=MultiTokenPredictionBlockSubmodules(
-        layer_specs=[
-            ModuleSpec(
-                module=MultiTokenPredictionLayer,
-                submodules=MultiTokenPredictionLayerSubmodules(
-                    enorm=TENorm,
-                    hnorm=TENorm,
-                    eh_proj=TEColumnParallelLinear,
-                    mtp_model_layer=None,  # Built via pattern + mamba_submodules
-                    layer_norm=TENorm,
-                ),
-            )
-        ]
-    ),
-)
-
-
 mamba_stack_spec = ModuleSpec(
     module=MambaStack,
     submodules=MambaStackSubmodules(
@@ -117,11 +87,9 @@
                 pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add
             ),
         ),
-        mtp_block_spec=_mamba_mtp_block_spec,
     ),
 )
 
-
 mamba_inference_stack_spec = ModuleSpec(
     module=MambaStack,
     submodules=MambaStackSubmodules(
@@ -179,6 +147,5 @@
                 pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add
            ),
        ),
-        mtp_block_spec=_mamba_mtp_block_spec,
     ),
 )
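
The deleted _mamba_mtp_block_spec followed megatron's declarative ModuleSpec pattern: a spec pairs a module class with the submodule classes (or nested specs) to instantiate inside it. A self-contained toy sketch of that pattern, with illustrative stand-in classes rather than megatron's own:

from dataclasses import dataclass, field

@dataclass
class ToySpec:
    """Toy analogue of megatron's ModuleSpec: a class plus its submodule wiring."""
    module: type
    submodules: dict = field(default_factory=dict)

def build(spec, **kwargs):
    # Recursively instantiate nested specs first, then the top-level module.
    built = {
        name: build(sub) if isinstance(sub, ToySpec) else sub()
        for name, sub in spec.submodules.items()
    }
    return spec.module(**built, **kwargs)

class Norm: ...
class ColumnParallelLinear: ...

class MTPLayer:
    def __init__(self, enorm, hnorm, eh_proj):
        self.enorm, self.hnorm, self.eh_proj = enorm, hnorm, eh_proj

# Shaped like the deleted spec: the two norms plus the eh projection,
# with the inner model layer elided.
mtp_layer = build(
    ToySpec(MTPLayer, {"enorm": Norm, "hnorm": Norm, "eh_proj": ColumnParallelLinear})
)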
