23 changes: 14 additions & 9 deletions megatron/core/fusions/linear_cross_entropy/blackwell/entry.py
@@ -345,7 +345,8 @@ def backward(
and num_valid_tokens.dtype == torch.int64
)

d_hidden = torch.empty_like(global_hidden)
# Allocate d_hidden in float32 for better numerical stability
d_hidden = torch.empty_like(global_hidden, dtype=torch.float32)
d_weight = torch.empty_like(weight)
assert d_hidden.is_contiguous() and d_weight.is_contiguous()

@@ -435,14 +436,15 @@ def backward(
)
valid_d_logits = _d_logits[:, :vocab_right_bound]

torch.addmm(
input=d_hidden.view(-1, dim),
mat1=valid_d_logits,
mat2=weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :],
beta=(split_idx != 0),
alpha=1.0,
out=d_hidden.view(-1, dim),
)
_delta_hidden = torch.mm(
valid_d_logits,
weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :],
out_dtype=torch.float32,
).view_as(d_hidden)
if split_idx == 0:
d_hidden.copy_(_delta_hidden)
else:
d_hidden.add_(_delta_hidden)
torch.matmul(
valid_d_logits.T,
hidden_view,
@@ -466,6 +468,9 @@
]
d_hidden = d_hidden.view(partial_hidden_shape).clone()

# Convert d_hidden back to the original dtype of global_hidden
d_hidden = d_hidden.type_as(global_hidden)

return d_hidden, d_weight

except ImportError:
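
For context, here is a minimal self-contained sketch of the accumulation pattern this hunk switches to (hypothetical shapes and variable names, not code from this PR): each vocab split's contribution d_logits_split @ W_split is computed in float32 and copied into or added onto the d_hidden buffer explicitly, rather than relying on addmm's beta flag with a same-dtype output buffer. The hunk itself uses torch.mm's out_dtype argument; the sketch casts the operands up front so it also runs on older PyTorch.

import torch

num_tokens, dim, vocab_per_split, num_splits = 8, 16, 32, 4
d_logits = torch.randn(num_tokens, num_splits * vocab_per_split, dtype=torch.bfloat16)
weight = torch.randn(num_splits * vocab_per_split, dim, dtype=torch.bfloat16)

d_hidden = torch.empty(num_tokens, dim, dtype=torch.float32)
for split_idx in range(num_splits):
    lo, hi = split_idx * vocab_per_split, (split_idx + 1) * vocab_per_split
    # Per-split contribution to the hidden-state gradient, computed in float32.
    delta = d_logits[:, lo:hi].float() @ weight[lo:hi, :].float()
    if split_idx == 0:
        d_hidden.copy_(delta)  # first split initializes the buffer
    else:
        d_hidden.add_(delta)   # later splits accumulate into it
# Cast back to the activation dtype once all splits are accumulated,
# mirroring the type_as(global_hidden) conversion above.
d_hidden = d_hidden.to(d_logits.dtype)
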
65 changes: 1 addition & 64 deletions megatron/core/models/common/language_module/language_module.py
@@ -1,7 +1,7 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
import logging
import os
from typing import Any, Dict, Literal, Optional, Tuple
from typing import Optional, Tuple

import torch
from torch import Tensor
@@ -14,7 +14,6 @@
except:
te_parallel_cross_entropy = None
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.fusions.fused_linear_cross_entropy import linear_cross_entropy
from megatron.core.pipeline_parallel.utils import (
is_pp_first_stage,
is_pp_last_stage,
@@ -126,68 +125,6 @@ def check_and_set_env_variable(
check_and_set_env_variable("NVTE_FUSED_ATTN", 1, AttnBackend.auto)
check_and_set_env_variable("NVTE_UNFUSED_ATTN", 1, AttnBackend.auto)

def compute_output_layer_and_language_model_loss(
self,
hidden: Tensor,
labels: Optional[Tensor],
weight: Tensor = None,
sequence_parallel_enabled: bool = False,
column_parallel_linear: torch.nn.Module = None,
col_linear_kwargs: Dict[str, Any] = {},
reduction: Literal["none", "sum", "mean"] = "none",
ignore_index: int = -100,
) -> Tensor:
"""Computes the language model logits and loss (Cross entropy across vocabulary)

Args:
hidden (Tensor): The hidden states from the transformer model
labels (Optional[Tensor]): The labels of dimension [batch size, seq length]
weight (Tensor): The weight tensor of shape [vocab size, hidden size].
Required if using fused linear cross entropy.
column_parallel_linear (torch.nn.Module): The column parallel linear
layer to use for computing logits when not using fused linear cross entropy.
col_linear_kwargs (Dict[str, Any]): Additional kwargs for column parallel linear layer
reduction (Optional[str]): The reduction method. Defaults to "none", and can be
one of "none", "sum", "mean".
ignore_index (Optional[int]): The index to ignore in the loss calculation.
Defaults to -100.

Returns:
Tensor: Loss tensor of dimensions [batch size, sequence_length].
"""
if (
self.config.cross_entropy_loss_fusion
and self.config.cross_entropy_fusion_impl == 'linear'
):
assert (
weight is not None
), "weight cannot be None when using fused linear cross entropy."
assert (
labels is not None
), "labels cannot be None when using fused linear cross entropy."
# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
loss = linear_cross_entropy(
hidden,
weight,
labels,
tp_group=self.pg_collection.tp,
sequence_parallel=sequence_parallel_enabled,
reduction=reduction,
ignore_index=ignore_index,
)

# [s b] => [b, s]
loss = loss.view_as(labels).transpose(0, 1).contiguous()
return loss
else:
assert (
column_parallel_linear is not None
), "column_parallel_linear cannot be None when not using fused linear cross entropy."
logits, _ = column_parallel_linear(hidden, **col_linear_kwargs)

return self.compute_language_model_loss(labels, logits)

def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
"""Computes the language model loss (Cross entropy across vocabulary)

34 changes: 13 additions & 21 deletions megatron/core/models/gpt/gpt_model.py
@@ -6,7 +6,7 @@
import torch
from torch import Tensor

from megatron.core import parallel_state, tensor_parallel
from megatron.core import parallel_state
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.inference.contexts import BaseInferenceContext
@@ -25,6 +25,7 @@
from megatron.core.quantization.utils import get_quant_config_or_none
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.enums import CudaGraphScope, ModelType
from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule
from megatron.core.transformer.multi_token_prediction import (
MTPLossAutoScaler,
MTPLossLoggingHelper,
@@ -234,7 +235,7 @@ def __init__(
self.embedding_activation_buffer = None
self.grad_output_buffer = None

self.output_layer = tensor_parallel.ColumnParallelLinear(
self.output_layer = LinearCrossEntropyModule(
config.hidden_size,
self.vocab_size,
config=config,
@@ -614,16 +615,11 @@ def _postprocess(
)

# Compute mtp loss without storing logits to save memory.
mtp_loss = self.compute_output_layer_and_language_model_loss(
hidden_states_list[mtp_layer_number + 1],
labels=mtp_labels,
weight=self.shared_embedding_or_output_weight(),
sequence_parallel_enabled=self.output_layer.sequence_parallel,
column_parallel_linear=self.output_layer,
col_linear_kwargs={
'weight': output_weight,
'runtime_gather_output': runtime_gather_output,
},
mtp_loss = self.output_layer(
output_cross_entropy_loss=True,
input_=hidden_states_list[mtp_layer_number + 1],
labels=mtp_labels,
weight=output_weight,
runtime_gather_output=runtime_gather_output,
)

mtp_loss = loss_mask * mtp_loss
@@ -702,16 +698,12 @@ def _postprocess(
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()

loss = self.compute_output_layer_and_language_model_loss(
hidden_states,
loss = self.output_layer(
output_cross_entropy_loss=True,
input_=hidden_states,
labels=labels,
weight=self.shared_embedding_or_output_weight(),
sequence_parallel_enabled=self.output_layer.sequence_parallel,
column_parallel_linear=self.output_layer,
col_linear_kwargs={
'weight': output_weight,
'runtime_gather_output': runtime_gather_output,
},
weight=output_weight,
runtime_gather_output=runtime_gather_output,
)

return loss
20 changes: 8 additions & 12 deletions megatron/core/models/mamba/mamba_model.py
@@ -4,7 +4,6 @@

from torch import Tensor

from megatron.core import tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
@@ -15,6 +14,7 @@
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.utils import WrappedTensor, deprecate_inference_params

@@ -131,7 +131,7 @@ def __init__(

# Output
if post_process:
self.output_layer = tensor_parallel.ColumnParallelLinear(
self.output_layer = LinearCrossEntropyModule(
config.hidden_size,
self.vocab_size,
config=config,
@@ -285,16 +285,12 @@ def forward(
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()

loss = self.compute_output_layer_and_language_model_loss(
hidden_states,
labels,
weight=self.shared_embedding_or_output_weight(),
sequence_parallel_enabled=self.output_layer.sequence_parallel,
column_parallel_linear=self.output_layer,
col_linear_kwargs={
"weight": output_weight,
"runtime_gather_output": runtime_gather_output,
},
loss = self.output_layer(
output_cross_entropy_loss=True,
input_=hidden_states,
labels=labels,
weight=output_weight,
runtime_gather_output=runtime_gather_output,
)

return loss
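
As a reference for the two call sites above, a small plain-PyTorch sketch (no tensor or sequence parallelism, made-up sizes, not code from this PR) of the shape convention both GPTModel and MambaModel rely on: hidden states are [s, b, h], labels are [b, s], and with the default reduction="none" the per-token loss comes back as [b, s], ready to be multiplied by a [b, s] loss mask.

import torch
import torch.nn.functional as F

s, b, h, vocab = 6, 2, 16, 50
hidden = torch.randn(s, b, h)              # decoder output, [s, b, h]
weight = torch.randn(vocab, h)             # output projection (or shared embedding) weight
labels = torch.randint(0, vocab, (b, s))   # [b, s]
loss_mask = torch.ones(b, s)

logits = hidden @ weight.t()               # [s, b, vocab]
loss_sb = F.cross_entropy(
    logits.reshape(-1, vocab),
    labels.transpose(0, 1).reshape(-1),    # [b, s] -> [s, b], then flatten
    reduction="none",
    ignore_index=-100,
).view(s, b)
loss_bs = loss_sb.transpose(0, 1).contiguous()  # back to [b, s]
masked_loss = loss_mask * loss_bs          # per-token loss, maskable like mtp_loss above
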
133 changes: 133 additions & 0 deletions megatron/core/transformer/linear_cross_entropy.py
@@ -0,0 +1,133 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

from typing import Literal, Optional, Tuple, Union

import torch

from megatron.core import tensor_parallel
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.fusions.fused_linear_cross_entropy import linear_cross_entropy
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.utils import is_te_min_version

try:
from megatron.core.extensions.transformer_engine import te_parallel_cross_entropy
except ImportError:
te_parallel_cross_entropy = None


class LinearCrossEntropyModule(tensor_parallel.ColumnParallelLinear):
"""
A module that combines a ColumnParallelLinear layer with fused
linear + cross-entropy loss computation over a tensor-parallel vocabulary.
"""

def forward(
self,
input_: torch.Tensor,
weight: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
output_cross_entropy_loss: bool = False,
labels: Optional[torch.Tensor] = None,
reduction: Literal["none", "sum", "mean"] = "none",
ignore_index: int = -100,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Run either the plain ColumnParallelLinear or fused linear+cross-entropy."""
if output_cross_entropy_loss:
return self._compute_linear_and_cross_entropy_loss(
hidden=input_,
weight=weight if weight is not None else self.weight,
labels=labels,
runtime_gather_output=runtime_gather_output,
reduction=reduction,
ignore_index=ignore_index,
)

# Fall back to the standard ColumnParallelLinear forward, which
# returns an (output, bias) tuple.
return super().forward(input_, weight, runtime_gather_output)

def _compute_linear_and_cross_entropy_loss(
self,
hidden: torch.Tensor,
weight: torch.Tensor,
runtime_gather_output: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
reduction: Literal["none", "sum", "mean"] = "none",
ignore_index: int = -100,
) -> torch.Tensor:
"""Compute fused linear + cross-entropy over tensor-parallel vocab."""
if (
self.config.cross_entropy_loss_fusion
and self.config.cross_entropy_fusion_impl == 'linear'
):
assert (
weight is not None
), "weight cannot be None when using fused linear cross entropy."
assert (
labels is not None
), "labels cannot be None when using fused linear cross entropy."

# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
loss = linear_cross_entropy(
hidden,
weight,
labels,
sequence_parallel=self.sequence_parallel,
reduction=reduction,
ignore_index=ignore_index,
tp_group=self.tp_group,
)
# If reduction != "none" this will be a scalar; for "none" it should
# match [s, b] and can be reshaped back to [b, s].
if reduction == "none":
loss = loss.view_as(labels).transpose(0, 1).contiguous()
else:
logits, _ = super().forward(hidden, weight, runtime_gather_output)
loss = self._compute_cross_entropy_loss(labels, logits)

return loss

def _compute_cross_entropy_loss(
self, labels: torch.Tensor, logits: torch.Tensor
) -> Optional[torch.Tensor]:
"""Compute (possibly fused) vocab-parallel cross-entropy loss."""
loss = None

# [b s] => [s b]
labels = labels.transpose(0, 1).contiguous()
if self.config.cross_entropy_loss_fusion:
if self.config.cross_entropy_fusion_impl == 'te':
if te_parallel_cross_entropy is not None:
labels = torch.as_strided(labels, labels.size(), (labels.size()[1], 1))
# Use is_cg_capturable=True for full iteration CUDA graphs
# to avoid torch.equal checks
is_cg_capturable = (
hasattr(self.config, 'cuda_graph_scope')
and CudaGraphScope.full_iteration in self.config.cuda_graph_scope
)
if is_cg_capturable and not is_te_min_version("2.7.0"):
from megatron.core.utils import get_te_version

current_version = get_te_version()
raise AssertionError(
f"CUDA graph compatible cross entropy requires "
f"TransformerEngine >= 2.7.0, but found version {current_version}. "
"Please upgrade TransformerEngine "
f"or set cuda_graph_scope to a value other than 'full_iteration'."
)

loss = te_parallel_cross_entropy(
logits, labels, self.tp_group, is_cg_capturable
)
else:
raise RuntimeError("Trying to use a TE block when it's not present.")
elif self.config.cross_entropy_fusion_impl == 'native':
loss = fused_vocab_parallel_cross_entropy(logits, labels, self.tp_group)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)

# [s b] => [b, s]
loss = loss.transpose(0, 1).contiguous()
return loss
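
A hedged usage sketch of the two forward modes, assuming a GPTModel or MambaModel built as above with Megatron parallel state already initialized; model, hidden_states, labels, and output_weight are placeholder names rather than identifiers from this PR.

# Training path: fused linear + cross entropy, no full logits tensor is materialized.
loss = model.output_layer(
    output_cross_entropy_loss=True,
    input_=hidden_states,   # [s, b, h]
    labels=labels,          # [b, s]
    weight=output_weight,   # shared embedding weight, or None to fall back to self.weight
)                           # per-token loss, [b, s] with the default reduction="none"

# Logits path: behaves like ColumnParallelLinear and returns an (output, bias) tuple.
logits, _ = model.output_layer(input_=hidden_states, weight=output_weight)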