diff --git a/megatron/core/fusions/linear_cross_entropy/blackwell/entry.py b/megatron/core/fusions/linear_cross_entropy/blackwell/entry.py
index dc369a7c558..07e018b51ff 100644
--- a/megatron/core/fusions/linear_cross_entropy/blackwell/entry.py
+++ b/megatron/core/fusions/linear_cross_entropy/blackwell/entry.py
@@ -345,7 +345,8 @@ def backward(
             and num_valid_tokens.dtype == torch.int64
         )
 
-        d_hidden = torch.empty_like(global_hidden)
+        # Allocate d_hidden in float32 for better numerical stability
+        d_hidden = torch.empty_like(global_hidden, dtype=torch.float32)
         d_weight = torch.empty_like(weight)
         assert d_hidden.is_contiguous() and d_weight.is_contiguous()
 
@@ -435,14 +436,15 @@ def backward(
             )
             valid_d_logits = _d_logits[:, :vocab_right_bound]
 
-            torch.addmm(
-                input=d_hidden.view(-1, dim),
-                mat1=valid_d_logits,
-                mat2=weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :],
-                beta=(split_idx != 0),
-                alpha=1.0,
-                out=d_hidden.view(-1, dim),
-            )
+            _delta_hidden = torch.mm(
+                valid_d_logits,
+                weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :],
+                out_dtype=torch.float32,
+            ).view_as(d_hidden)
+            if split_idx == 0:
+                d_hidden.copy_(_delta_hidden)
+            else:
+                d_hidden.add_(_delta_hidden)
             torch.matmul(
                 valid_d_logits.T,
                 hidden_view,
@@ -466,6 +468,9 @@ def backward(
             ]
             d_hidden = d_hidden.view(partial_hidden_shape).clone()
 
+        # convert d_hidden to the original dtype
+        d_hidden = d_hidden.type_as(global_hidden)
+
         return d_hidden, d_weight
 
 except ImportError:
diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py
index 13d74aa5271..259bb716a93 100644
--- a/megatron/core/models/common/language_module/language_module.py
+++ b/megatron/core/models/common/language_module/language_module.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 import logging
 import os
-from typing import Any, Dict, Literal, Optional, Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -14,7 +14,6 @@
 except:
     te_parallel_cross_entropy = None
 from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
-from megatron.core.fusions.fused_linear_cross_entropy import linear_cross_entropy
 from megatron.core.pipeline_parallel.utils import (
     is_pp_first_stage,
     is_pp_last_stage,
@@ -126,68 +125,6 @@ def check_and_set_env_variable(
         check_and_set_env_variable("NVTE_FUSED_ATTN", 1, AttnBackend.auto)
         check_and_set_env_variable("NVTE_UNFUSED_ATTN", 1, AttnBackend.auto)
 
-    def compute_output_layer_and_language_model_loss(
-        self,
-        hidden: Tensor,
-        labels: Optional[Tensor],
-        weight: Tensor = None,
-        sequence_parallel_enabled: bool = False,
-        column_parallel_linear: torch.nn.Module = None,
-        col_linear_kwargs: Dict[str, Any] = {},
-        reduction: Literal["none", "sum", "mean"] = "none",
-        ignore_index: int = -100,
-    ) -> Tensor:
-        """Computes the language model logits and loss (Cross entropy across vocabulary)
-
-        Args:
-            hidden (Tensor): The hidden states from the transformer model
-            labels (Optional[Tensor]): The labels of dimension [batch size, seq length]
-            weight (Tensor): The weight tensor of shape [vocab size, hidden size].
-                Required if using fused linear cross entropy.
-            column_parallel_linear (torch.nn.Module): The column parallel linear
-                layer to use for computing logits when not using fused linear cross entropy.
-            col_linear_kwargs (Dict[str, Any]): Additional kwargs for column parallel linear layer
-            reduction (Optional[str]): The reduction method. Defaults to "none", and can be
-                one of "none", "sum", "mean".
-            ignore_index (Optional[int]): The index to ignore in the loss calculation.
-                Defaults to -100.
-
-        Returns:
-            Tensor: Loss tensor of dimensions [batch size, sequence_length].
-        """
-        if (
-            self.config.cross_entropy_loss_fusion
-            and self.config.cross_entropy_fusion_impl == 'linear'
-        ):
-            assert (
-                weight is not None
-            ), "weight cannot be None when using fused linear cross entropy."
-            assert (
-                labels is not None
-            ), "labels cannot be None when using fused linear cross entropy."
-            # [b s] => [s b]
-            labels = labels.transpose(0, 1).contiguous()
-            loss = linear_cross_entropy(
-                hidden,
-                weight,
-                labels,
-                tp_group=self.pg_collection.tp,
-                sequence_parallel=sequence_parallel_enabled,
-                reduction=reduction,
-                ignore_index=ignore_index,
-            )
-
-            # [s b] => [b, s]
-            loss = loss.view_as(labels).transpose(0, 1).contiguous()
-            return loss
-        else:
-            assert (
-                column_parallel_linear is not None
-            ), "column_parallel_linear cannot be None when not using fused linear cross entropy."
-            logits, _ = column_parallel_linear(hidden, **col_linear_kwargs)
-
-            return self.compute_language_model_loss(labels, logits)
-
     def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
         """Computes the language model loss (Cross entropy across vocabulary)
 
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
index 16462d6e426..2caa29bf062 100644
--- a/megatron/core/models/gpt/gpt_model.py
+++ b/megatron/core/models/gpt/gpt_model.py
@@ -6,7 +6,7 @@
 import torch
 from torch import Tensor
 
-from megatron.core import parallel_state, tensor_parallel
+from megatron.core import parallel_state
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.inference.contexts import BaseInferenceContext
@@ -25,6 +25,7 @@
 from megatron.core.quantization.utils import get_quant_config_or_none
 from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
 from megatron.core.transformer.enums import CudaGraphScope, ModelType
+from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule
 from megatron.core.transformer.multi_token_prediction import (
     MTPLossAutoScaler,
     MTPLossLoggingHelper,
@@ -234,7 +235,7 @@ def __init__(
             self.embedding_activation_buffer = None
             self.grad_output_buffer = None
 
-            self.output_layer = tensor_parallel.ColumnParallelLinear(
+            self.output_layer = LinearCrossEntropyModule(
                 config.hidden_size,
                 self.vocab_size,
                 config=config,
@@ -614,16 +615,12 @@ def _postprocess(
                 )
 
                 # Compute mtp loss without storing logits to save memory.
-                mtp_loss = self.compute_output_layer_and_language_model_loss(
-                    hidden_states_list[mtp_layer_number + 1],
-                    labels=mtp_labels,
-                    weight=self.shared_embedding_or_output_weight(),
-                    sequence_parallel_enabled=self.output_layer.sequence_parallel,
-                    column_parallel_linear=self.output_layer,
-                    col_linear_kwargs={
-                        'weight': output_weight,
-                        'runtime_gather_output': runtime_gather_output,
-                    },
+                mtp_loss = self.output_layer(
+                    output_cross_entropy_loss=True,
+                    input_=hidden_states_list[mtp_layer_number + 1],
+                    labels=mtp_labels,
+                    weight=output_weight,
+                    runtime_gather_output=runtime_gather_output,
                 )
                 mtp_loss = loss_mask * mtp_loss
 
@@ -702,16 +698,12 @@ def _postprocess(
             # [s b h] => [b s h]
             return logits.transpose(0, 1).contiguous()
 
-        loss = self.compute_output_layer_and_language_model_loss(
-            hidden_states,
+        loss = self.output_layer(
+            output_cross_entropy_loss=True,
+            input_=hidden_states,
             labels=labels,
-            weight=self.shared_embedding_or_output_weight(),
-            sequence_parallel_enabled=self.output_layer.sequence_parallel,
-            column_parallel_linear=self.output_layer,
-            col_linear_kwargs={
-                'weight': output_weight,
-                'runtime_gather_output': runtime_gather_output,
-            },
+            weight=output_weight,
+            runtime_gather_output=runtime_gather_output,
         )
 
         return loss
diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py
index e4074eda806..6026b6275ba 100644
--- a/megatron/core/models/mamba/mamba_model.py
+++ b/megatron/core/models/mamba/mamba_model.py
@@ -4,7 +4,6 @@
 
 from torch import Tensor
 
-from megatron.core import tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
@@ -15,6 +14,7 @@
 from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
 from megatron.core.transformer import TransformerConfig
 from megatron.core.transformer.enums import ModelType
+from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
 from megatron.core.utils import WrappedTensor, deprecate_inference_params
 
@@ -131,7 +131,7 @@ def __init__(
 
         # Output
         if post_process:
-            self.output_layer = tensor_parallel.ColumnParallelLinear(
+            self.output_layer = LinearCrossEntropyModule(
                 config.hidden_size,
                 self.vocab_size,
                 config=config,
@@ -285,16 +285,12 @@ def forward(
             # [s b h] => [b s h]
             return logits.transpose(0, 1).contiguous()
 
-        loss = self.compute_output_layer_and_language_model_loss(
-            hidden_states,
-            labels,
-            weight=self.shared_embedding_or_output_weight(),
-            sequence_parallel_enabled=self.output_layer.sequence_parallel,
-            column_parallel_linear=self.output_layer,
-            col_linear_kwargs={
-                "weight": output_weight,
-                "runtime_gather_output": runtime_gather_output,
-            },
+        loss = self.output_layer(
+            output_cross_entropy_loss=True,
+            input_=hidden_states,
+            labels=labels,
+            weight=output_weight,
+            runtime_gather_output=runtime_gather_output,
         )
 
         return loss
diff --git a/megatron/core/transformer/linear_cross_entropy.py b/megatron/core/transformer/linear_cross_entropy.py
new file mode 100644
index 00000000000..e2f151cfdf1
--- /dev/null
+++ b/megatron/core/transformer/linear_cross_entropy.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+
+from typing import Literal, Optional, Tuple, Union
+
+import torch
+
+from megatron.core import tensor_parallel
+from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
+from megatron.core.fusions.fused_linear_cross_entropy import linear_cross_entropy
+from megatron.core.transformer.enums import CudaGraphScope
+from megatron.core.utils import is_te_min_version
+
+try:
+    from megatron.core.extensions.transformer_engine import te_parallel_cross_entropy
+except:
+    te_parallel_cross_entropy = None
+
+
+class LinearCrossEntropyModule(tensor_parallel.ColumnParallelLinear):
+    """
+    A module that combines a ColumnParallelLinear layer with fused
+    linear + cross-entropy loss computation over a tensor-parallel vocabulary.
+    """
+
+    def forward(
+        self,
+        input_: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        runtime_gather_output: Optional[bool] = None,
+        output_cross_entropy_loss: bool = False,
+        labels: Optional[torch.Tensor] = None,
+        reduction: Literal["none", "sum", "mean"] = "none",
+        ignore_index: int = -100,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+        """Run either the plain ColumnParallelLinear or fused linear+cross-entropy."""
+        if output_cross_entropy_loss:
+            return self._compute_linear_and_cross_entropy_loss(
+                hidden=input_,
+                weight=weight if weight is not None else self.weight,
+                runtime_gather_output=runtime_gather_output,
+                labels=labels,
+                reduction=reduction,
+                ignore_index=ignore_index,
+            )
+
+        # Fall back to the standard ColumnParallelLinear forward, which returns
+        # (output, bias) or just the output depending on configuration; hence
+        # the Union return annotation.
+        return super().forward(input_, weight, runtime_gather_output)
+
+    def _compute_linear_and_cross_entropy_loss(
+        self,
+        hidden: torch.Tensor,
+        weight: torch.Tensor,
+        runtime_gather_output: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+        reduction: Literal["none", "sum", "mean"] = "none",
+        ignore_index: int = -100,
+    ) -> torch.Tensor:
+        """Compute fused linear + cross-entropy over tensor-parallel vocab."""
+        if (
+            self.config.cross_entropy_loss_fusion
+            and self.config.cross_entropy_fusion_impl == 'linear'
+        ):
+            assert (
+                weight is not None
+            ), "weight cannot be None when using fused linear cross entropy."
+            assert (
+                labels is not None
+            ), "labels cannot be None when using fused linear cross entropy."
+
+            # [b s] => [s b]
+            labels = labels.transpose(0, 1).contiguous()
+            loss = linear_cross_entropy(
+                hidden,
+                weight,
+                labels,
+                sequence_parallel=self.sequence_parallel,
+                reduction=reduction,
+                ignore_index=ignore_index,
+                tp_group=self.tp_group,
+            )
+            # If reduction != "none" this will be a scalar; for "none" it should
+            # match [s, b] and can be reshaped back to [b, s].
+            if reduction == "none":
+                loss = loss.view_as(labels).transpose(0, 1).contiguous()
+        else:
+            logits, _ = super().forward(hidden, weight, runtime_gather_output)
+            loss = self._compute_cross_entropy_loss(labels, logits)
+
+        return loss
+
+    def _compute_cross_entropy_loss(
+        self, labels: torch.Tensor, logits: torch.Tensor
+    ) -> Optional[torch.Tensor]:
+        """Compute (possibly fused) vocab-parallel cross-entropy loss."""
+        loss = None
+
+        # [b s] => [s b]
+        labels = labels.transpose(0, 1).contiguous()
+        if self.config.cross_entropy_loss_fusion:
+            if self.config.cross_entropy_fusion_impl == 'te':
+                if te_parallel_cross_entropy is not None:
+                    labels = torch.as_strided(labels, labels.size(), (labels.size()[1], 1))
+                    # Use is_cg_capturable=True for full iteration CUDA graphs
+                    # to avoid torch.equal checks
+                    is_cg_capturable = (
+                        hasattr(self.config, 'cuda_graph_scope')
+                        and CudaGraphScope.full_iteration in self.config.cuda_graph_scope
+                    )
+                    if is_cg_capturable and not is_te_min_version("2.7.0"):
+                        from megatron.core.utils import get_te_version
+
+                        current_version = get_te_version()
+                        raise AssertionError(
+                            f"CUDA graph compatible cross entropy requires "
+                            f"TransformerEngine >= 2.7.0, but found version {current_version}. "
+                            "Please upgrade TransformerEngine "
+                            f"or set cuda_graph_scope to a value other than 'full_iteration'."
+                        )
+
+                    loss = te_parallel_cross_entropy(
+                        logits, labels, self.tp_group, is_cg_capturable
+                    )
+                else:
+                    raise RuntimeError("Trying to use a TE block when it's not present.")
+            elif self.config.cross_entropy_fusion_impl == 'native':
+                loss = fused_vocab_parallel_cross_entropy(logits, labels, self.tp_group)
+        else:
+            loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
+
+        # [s b] => [b, s]
+        loss = loss.transpose(0, 1).contiguous()
+        return loss
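
For reference, the semantics the fused `cross_entropy_fusion_impl == 'linear'` path is expected to match can be written as a plain single-GPU PyTorch function. This is a sketch only: no tensor or sequence parallelism, `reduction="none"`, and the sizes below are made up; the actual fused kernel avoids ever materializing the full logits tensor.

```python
import torch
import torch.nn.functional as F

def reference_linear_cross_entropy(hidden, weight, labels, ignore_index=-100):
    # hidden: [s, b, h], weight: [vocab, h], labels: [b, s] -> loss: [b, s]
    s, b, h = hidden.shape
    logits = hidden.view(s * b, h) @ weight.t()       # [s*b, vocab]; the fused path never stores this
    labels_sb = labels.transpose(0, 1).contiguous()   # [b s] => [s b], as in the module above
    loss = F.cross_entropy(
        logits.float(), labels_sb.view(-1), ignore_index=ignore_index, reduction="none"
    )
    return loss.view(s, b).transpose(0, 1).contiguous()  # [s b] => [b, s]

hidden = torch.randn(16, 2, 64)            # [s, b, h]
weight = torch.randn(1000, 64)             # [vocab, h]
labels = torch.randint(0, 1000, (2, 16))   # [b, s]
print(reference_linear_cross_entropy(hidden, weight, labels).shape)  # torch.Size([2, 16])
```

The `entry.py` change replaces the in-place `addmm` into a bf16/fp16 `d_hidden` with an explicit float32 accumulator that is cast back once via `type_as`: low-precision accumulation rounds after every vocab split, while the fp32 accumulator rounds only at the final cast. A toy illustration of the effect (sizes and split count illustrative):

```python
import torch

# Simulate per-split accumulation of d_hidden across 8 vocab splits.
splits = [torch.randn(4096, 512, dtype=torch.bfloat16) for _ in range(8)]
acc_bf16 = torch.zeros(4096, 512, dtype=torch.bfloat16)
acc_fp32 = torch.zeros(4096, 512, dtype=torch.float32)
for s in splits:
    acc_bf16 += s          # rounds to bf16 at every step
    acc_fp32 += s.float()  # full-precision accumulation
final = acc_fp32.to(torch.bfloat16)  # single rounding at the end, as in the patch
print((acc_bf16.float() - acc_fp32).abs().max())  # accumulated per-step rounding error
print((final.float() - acc_fp32).abs().max())     # error from one rounding step only
```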