@@ -6,7 +6,7 @@
 import torch
 from torch import Tensor

-from megatron.core import tensor_parallel
+from megatron.core import parallel_state, tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.inference.contexts import BaseInferenceContext
@@ -26,9 +26,11 @@
 from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
 from megatron.core.transformer.enums import CudaGraphScope, ModelType
 from megatron.core.transformer.multi_token_prediction import (
+    MTPLossAutoScaler,
+    MTPLossLoggingHelper,
     MultiTokenPredictionBlock,
-    mtp_on_this_rank,
-    process_mtp_loss,
+    roll_tensor,
+    tie_word_embeddings_state_dict,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_block import TransformerBlock
@@ -142,9 +144,7 @@ def __init__(
         self.rotary_base = rotary_base
         self.rotary_scaling = rope_scaling
         self.mtp_block_spec = mtp_block_spec
-        self.mtp_process = mtp_block_spec is not None and mtp_on_this_rank(
-            self.config, ignore_virtual=False, vp_stage=vp_stage
-        )
+        self.mtp_process = mtp_block_spec is not None

         if self.pre_process or self.mtp_process:
             self.embedding = LanguageModelEmbedding(
@@ -609,19 +609,56 @@ def _postprocess(
             return hidden_states

         if self.config.mtp_num_layers is not None:
-            hidden_states = process_mtp_loss(
-                hidden_states=hidden_states,
-                labels=labels,
-                loss_mask=loss_mask,
-                output_layer=self.output_layer,
-                output_weight=output_weight,
-                runtime_gather_output=runtime_gather_output,
-                is_training=self.training,
-                compute_language_model_loss=self.compute_language_model_loss,
-                config=self.config,
-                cp_group=self.pg_collection.cp,
-                packed_seq_params=packed_seq_params,
-            )
+            mtp_labels = labels.clone()
+            hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
+            hidden_states = hidden_states_list[0]
+            if loss_mask is None:
+                # If no loss_mask is provided, use all ones.
+                loss_mask = torch.ones_like(mtp_labels)
+            for mtp_layer_number in range(self.config.mtp_num_layers):
+                # Compute logits for the current Multi-Token Prediction (MTP) layer.
+                mtp_logits, _ = self.output_layer(
+                    hidden_states_list[mtp_layer_number + 1],
+                    weight=output_weight,
+                    runtime_gather_output=runtime_gather_output,
+                )
+                # Roll labels and loss mask to align them with this MTP layer's predictions.
+                mtp_labels, _ = roll_tensor(
+                    mtp_labels,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+                loss_mask, num_tokens = roll_tensor(
+                    loss_mask,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+                mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
+                mtp_loss = loss_mask * mtp_loss
+                if self.training:
+                    # TODO(shifangx): remove the use of parallel_state here
+                    # after moving loss logging to loss_func in pretrain_gpt.py
+                    MTPLossLoggingHelper.save_loss_to_tracker(
+                        torch.sum(mtp_loss) / num_tokens,
+                        mtp_layer_number,
+                        self.config.mtp_num_layers,
+                        avg_group=parallel_state.get_data_parallel_group(
+                            with_context_parallel=True
+                        ),
+                    )
+                mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+                if self.config.calculate_per_token_loss:
+                    hidden_states = MTPLossAutoScaler.apply(
+                        hidden_states, mtp_loss_scale * mtp_loss
+                    )
+                else:
+                    hidden_states = MTPLossAutoScaler.apply(
+                        hidden_states, mtp_loss_scale * mtp_loss / num_tokens
+                    )
         sequence_parallel_override = False

         if in_inference_mode and inference_context.materialize_only_last_token_logits:
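The key mechanism above is `MTPLossAutoScaler.apply`: it returns `hidden_states` unchanged in the forward pass and injects the scaled gradient of the auxiliary MTP loss during backward, so each MTP head trains the shared trunk without a separate backward pass. A minimal sketch of that autograd pattern, assuming a scale attribute set by the training loop (the real class lives in `megatron/core/transformer/multi_token_prediction.py`; details here are illustrative):

```python
import torch


class AuxLossAutoScaler(torch.autograd.Function):
    """Illustrative stand-in for MTPLossAutoScaler: identity on the
    activations in forward, auxiliary-loss gradient added in backward."""

    # Assumed hook for matching the main loss's backward scaling.
    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(aux_loss)
        return output  # activations pass through unchanged

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (aux_loss,) = ctx.saved_tensors
        # d(aux_loss)/d(aux_loss) == 1, scaled to match the main loss.
        aux_loss_grad = torch.ones_like(aux_loss) * AuxLossAutoScaler.main_loss_backward_scale
        return grad_output, aux_loss_grad
```

With this pattern, calling `backward()` on the main loss alone also propagates the MTP gradients, because they ride along on `hidden_states`.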
@@ -678,6 +715,27 @@ def _postprocess(

         return loss

+    def shared_embedding_or_output_weight(self) -> Tensor:
+        """Gets the embedding weight or output logit weight when the input embedding and
+        output weights are shared, or when the Multi-Token Prediction (MTP) feature is used.
+
+        Returns:
+            Tensor: During pre-processing or MTP processing, returns the input embedding
+            weight. Otherwise, during post-processing, returns the final output layer weight.
+        """
+        if self.pre_process or self.mtp_process:
+            # Multi-Token Prediction (MTP) needs both the embedding layer and the output
+            # layer, so both exist in the MTP process stage. In that case, if
+            # share_embeddings_and_output_weights is True, the shared weights are stored
+            # in the embedding layer and the output layer has no weight of its own.
+            assert hasattr(
+                self, 'embedding'
+            ), "embedding is needed in this pipeline stage, but it is not initialized."
+            return self.embedding.word_embeddings.weight
+        elif self.post_process:
+            return self.output_layer.weight
+        return None
+
     def build_schedule_plan(
         self,
         input_ids: Tensor,
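For reference, the new `shared_embedding_or_output_weight` is consumed the way Megatron's usual tied-weight helper is: the caller fetches the shared weight and passes it to the output layer explicitly. A hypothetical call site (everything except the method name is an assumption):

```python
# Hypothetical call site; `model` is an instance of this GPT model class.
output_weight = None
if model.share_embeddings_and_output_weights or model.mtp_process:
    # On pre-process/MTP stages the shared weight lives in the embedding;
    # on a pure post-process stage it lives in the output layer.
    output_weight = model.shared_embedding_or_output_weight()
logits, _ = model.output_layer(hidden_states, weight=output_weight)
```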
@@ -768,4 +826,20 @@ def sharded_state_dict(
                 output_extra_state and output_extra_state.data
             ), f'Expected output layer extra state to be empty, got: {output_extra_state}'

+        # Multi-Token Prediction (MTP) needs the embedding layer in the MTP process stage.
+        # If MTP is not placed in the pre-processing stage, we must keep a copy of the
+        # embedding layer in the MTP process stage and tie it to the embedding in the
+        # pre-processing stage. The MTP loss is now computed in the post-processing
+        # stage, so the output_layer is not needed in the MTP process stage.
+        if self.mtp_process and not self.pre_process:
+            emb_weight_key = f'{prefix}embedding.word_embeddings.weight'
+            emb_weight = self.embedding.word_embeddings.weight
+            tie_word_embeddings_state_dict(
+                sharded_state_dict,
+                emb_weight,
+                emb_weight_key,
+                tp_group=self.tp_group,
+                dp_cp_group=metadata['dp_cp_group'],
+            )
+
         return sharded_state_dict
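Conceptually, tying the state dict means the MTP stage's embedding copy is re-registered under the canonical embedding key with a replica id, so the checkpoint stores a single shared tensor. A rough sketch of that idea, assuming Megatron's `make_tp_sharded_tensor_for_checkpoint` utility (the replica-id layout is an assumption; `tie_word_embeddings_state_dict` in `multi_token_prediction.py` is the real helper):

```python
from torch import Tensor

from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint


def tie_word_embeddings_sketch(
    sharded_state_dict: ShardedStateDict, weight: Tensor, weight_key: str
) -> None:
    """Illustrative only: point this stage's embedding entry at the canonical key."""
    assert weight_key in sharded_state_dict
    del sharded_state_dict[weight_key]
    # Register the local copy under the canonical key. A nonzero replica id
    # marks it as a duplicate, so only one stage's copy is written (assumption).
    sharded_state_dict[weight_key] = make_tp_sharded_tensor_for_checkpoint(
        tensor=weight,
        key='embedding.word_embeddings.weight',
        replica_id=(1, 0, 0),
    )
```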