4 changes: 3 additions & 1 deletion transformer_engine/pytorch/cross_entropy.py
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
Collaborator comment: There is no real change in this file. Let's keep this file intact, and then we don't need to add the AMD copyright statement.

# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -76,4 +78,4 @@ def backward(ctx, grad_output):
)


-parallel_cross_entropy = CrossEntropyFunction.apply
\ No newline at end of file
+parallel_cross_entropy = CrossEntropyFunction.apply
41 changes: 31 additions & 10 deletions transformer_engine/pytorch/triton/cross_entropy.py
@@ -17,7 +17,6 @@
import triton.language as tl
from torch.utils.cpp_extension import IS_HIP_EXTENSION


@triton.jit
def online_softmax_kernel(
X_ptr,
@@ -100,6 +99,7 @@ def cross_entropy_kernel(
ignore_idx,
n_cols,
n_non_ignore,
reduce_loss: tl.constexpr,
label_smoothing: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
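The new `reduce_loss` flag is declared `tl.constexpr`, so Triton treats it as a compile-time constant: each value of the flag gets its own specialized kernel, and the `if reduce_loss:` branches below are resolved during compilation rather than at runtime. A minimal standalone sketch of that mechanism (toy kernel, not the PR's code; assumes a GPU with Triton installed):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, n, scale, APPLY_SCALE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if APPLY_SCALE:  # constexpr: the dead branch is eliminated at compile time
        x = x * scale
    tl.store(x_ptr + offs, x, mask=mask)

x = torch.randn(1024, device="cuda")
# Each distinct APPLY_SCALE value compiles its own specialization of the kernel.
scale_kernel[(4,)](x, x.numel(), 0.5, APPLY_SCALE=True, BLOCK=256)
```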
@@ -117,7 +117,7 @@ def cross_entropy_kernel(
m_d_X_y_stride: The stride of m/d/X_y tensor.
rank (int): The rank of this device in the TP group.
world_size (int): The size of world involved in this distributed loss calculation.
-ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
+ignore_idx (int): Target value marking tokens to exclude from loss and gradient calculation (default: -100).
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
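reduce_loss (bool): If True, fold the 1/n_non_ignore factor of a mean-reduced loss into the stored gradients; if False, store unscaled per-token gradients.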
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -179,7 +183,13 @@ def cross_entropy_kernel(
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
-X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
+# Scale gradients based on reduction mode:
+# reduce_loss=True: the forward divides the summed loss by n_non_ignore, so the
+# same 1/n_non_ignore factor is folded into the per-token gradients here.
+# reduce_loss=False: the loss is returned per token and grad_output arrives per
+# token, so no extra scaling is applied.
+if reduce_loss:
+    X_block = (tl.exp(X_block - m) / d - eps) / n_non_ignore
+else:
+    X_block = tl.exp(X_block - m) / d - eps
tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written
@@ -207,7 +213,11 @@ def cross_entropy_kernel(
if y >= vocab_start_idx:
if y < vocab_end_idx:
X_y = tl.load(X_ptr + y - vocab_start_idx)
-X_y += -(1 - label_smoothing) / (n_non_ignore)
+# Apply the same conditional scaling logic for the target token
+if reduce_loss:
+    X_y += -(1 - label_smoothing) / n_non_ignore
+else:
+    X_y += -(1 - label_smoothing)
tl.store(X_ptr + y - vocab_start_idx, X_y)

tl.store(loss_ptr, loss)
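The scaling split above can be sanity-checked against PyTorch's reference cross entropy: with `reduction="mean"` the per-logit gradients carry a `1/n_non_ignore` factor that `reduction="none"` gradients lack. A quick illustrative check (CPU-only sketch, not part of the PR):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)
target = torch.tensor([1, 2, -100, 3])  # one token ignored via ignore_index
n_non_ignore = (target != -100).sum()

# Mean reduction: gradients are scaled by 1/n_non_ignore.
F.cross_entropy(logits, target, ignore_index=-100, reduction="mean").backward()
g_mean = logits.grad.clone()
logits.grad = None

# No reduction: per-token gradients are left unscaled.
F.cross_entropy(logits, target, ignore_index=-100, reduction="none").sum().backward()
g_none = logits.grad

assert torch.allclose(g_mean, g_none / n_non_ignore)
```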
@@ -220,11 +230,13 @@ def cross_entropy_kernel(
else:
NUM_WARPS = 32


@triton.jit
def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
grad_output_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
@@ -247,6 +259,7 @@ def element_mul_kernel(
X_ptr += program_id * X_stride

# Load the gradient output value
grad_output_ptr += program_id * grad_output_stride
grad_output = tl.load(grad_output_ptr)

# Perform the element-wise multiplication
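The new `grad_output_stride` argument lets a single kernel serve both reduction modes: stride 0 makes every program re-read the same scalar gradient, while stride 1 gives each row its own element (see the `1 if grad_output.numel() > 1 else 0` at the call site in `cross_entropy_backward`). A host-side sketch of the same addressing (hypothetical helper, for illustration only):

```python
import torch

def element_mul_rows(X: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    """Mimics element_mul_kernel: row i reads its gradient at offset i * stride."""
    stride = 1 if grad_output.numel() > 1 else 0
    g = grad_output.flatten()
    for row in range(X.shape[0]):
        X[row] *= g[row * stride]  # stride 0 broadcasts the scalar grad to all rows
    return X

X = torch.ones(3, 4)
element_mul_rows(X, torch.tensor([2.0]))            # reduced loss: scalar grad
element_mul_rows(X, torch.tensor([1.0, 2.0, 3.0]))  # unreduced loss: per-row grads
```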
@@ -268,6 +281,8 @@ def cross_entropy_forward(

B, SQ, V = _input.shape
n_rows = B * SQ
valid_token_count = int((target != ignore_idx).sum().item())
denom = max(1, valid_token_count)

assert reduce(mul, list(target.size())) == (B * SQ), "Each token needs a target token ID."

@@ -323,24 +338,29 @@
world_size=world_size,
ignore_idx=ignore_idx,
n_cols=V,
-n_non_ignore=n_rows,
+n_non_ignore=denom,
reduce_loss=reduce_loss,
label_smoothing=label_smoothing,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=NUM_WARPS,
)

-loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows)
+loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / denom)

return loss, _input
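Replacing `n_rows` with `denom` matches the convention of `torch.nn.functional.cross_entropy`, whose mean reduction divides by the number of non-ignored targets rather than by the batch size; the `max(1, valid_token_count)` guard keeps an all-ignored batch from dividing by zero. A short parity sketch (illustrative, not from the PR's tests):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, 7)
target = torch.tensor([0, -100, 3, -100, 6])  # 3 valid tokens out of 5
valid = (target != -100).sum()

mean_loss = F.cross_entropy(logits, target, ignore_index=-100, reduction="mean")
per_token = F.cross_entropy(logits, target, ignore_index=-100, reduction="none")
assert torch.allclose(mean_loss, per_token.sum() / valid)  # divides by 3, not 5
```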


-def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
+def cross_entropy_backward(
+    _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
+):
"""Backward implementation of cross entropy loss kernel"""

# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+# Skip the torch.equal fast path when the backward must be CUDA-graph
+# capturable: comparing tensors syncs with the host, which capture forbids.
+if not is_cg_capturable and torch.equal(
+    grad_output, torch.tensor(1.0, device=grad_output.device)
+):
pass

else:
B, SQ, V = _input.shape
n_rows = B * SQ
@@ -350,9 +370,10 @@ def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
_input,
_input.stride(-2),
grad_output,
1 if grad_output.numel() > 1 else 0,
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=NUM_WARPS,
)

-return _input
\ No newline at end of file
+return _input
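`is_cg_capturable` exists because `torch.equal` on device tensors copies the comparison result back to the host, and any such sync raises an error while a CUDA graph is being captured; passing `True` trades the fast path for an unconditional (but cheap) `element_mul_kernel` launch. A usage sketch (hypothetical capture setup; assumes a CUDA device with Triton available):

```python
import torch

grads = torch.randn(2, 4, 32, device="cuda")  # (B, SQ, V) written by the forward kernel
grad_output = torch.ones(1, device="cuda")

# Warm up outside capture so the Triton kernel is already compiled.
cross_entropy_backward(grads, grad_output, is_cg_capturable=True)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # torch.equal here would sync with the host and abort the capture;
    # is_cg_capturable=True skips that check entirely.
    cross_entropy_backward(grads, grad_output, is_cg_capturable=True)
graph.replay()
```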