68 changes: 44 additions & 24 deletions tests/pytorch/test_parallel_cross_entropy.py
@@ -3,10 +3,11 @@
# See LICENSE for license information.

import random
import pytest
import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

from utils import dtype_tols


class TestParallelCrossEntropy:

@@ -19,19 +20,25 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float):
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)

def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):

def generate_input(
self,
dtype: torch.dtype,
swap_dim: bool,
ignore_idx: bool,
device: torch.device = "cuda",
):
SQ = random.choice([64, 128])
batch = random.choice([1, 2])
vocab = random.choice([64000, 128000])
ignore = random.sample(range(0, SQ - 1), 5)

# Generate random data
if swap_dim:
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda()
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device)
self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device)
else:
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device)
self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device)

if ignore_idx:
for i in ignore:
@@ -41,9 +48,14 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
else:
self.tar_test[0][i] = -100

# Make copy of data for reference implementation
self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))

# Enable autograd
self.input_test.requires_grad_()
self.input_ref.requires_grad_()

def one_iteration_test(
self,
dtype: torch.dtype,
@@ -53,31 +65,39 @@ def one_iteration_test(
ignore_idx: bool = False,
):

# Random data
self.generate_input(dtype, swap_dim, ignore_idx)

self.input_test.requires_grad_(True)
self.input_ref.requires_grad_(True)

# Forward pass
test_loss = self.test_loss_func(
self.input_test, self.tar_test, label_smoothing, reduce_loss, None
)
if reduce_loss:
test_loss.backward()

ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
if reduce_loss:
ref_loss.backward()

test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
# Compute square to avoid trivial backward pass
test_loss = torch.square(test_loss)
ref_loss = torch.square(ref_loss)

torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
if ignore_idx:
print(test_loss, ref_loss)
# Backward pass
if reduce_loss:
torch.testing.assert_close(
torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
)

test_loss.backward()
ref_loss.backward()
else:
test_loss.sum().backward()
ref_loss.sum().backward()

# Check that loss and grad input match
tols = dtype_tols(dtype)
test_loss = test_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.reshape(test_loss.size())
test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
ref_grad_input = ref_grad_input.reshape(test_grad_input.size())
torch.testing.assert_close(test_loss, ref_loss, **tols)
torch.testing.assert_close(test_grad_input, ref_grad_input, **tols)

# Reset data
self.input_test = None
self.input_ref = None
self.tar_test = None
@@ -133,4 +153,4 @@ def test_ignore_idx(self):
label_smoothing=0,
reduce_loss=False,
ignore_idx=True,
)
)
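For context, dtype_tols comes from the shared test utils module and returns comparison tolerances for a given dtype, which the test unpacks into torch.testing.assert_close. A hedged sketch of the pattern (the tolerance values and helper name below are illustrative assumptions, not the repo's actual implementation):

import torch

def dtype_tols_sketch(dtype: torch.dtype) -> dict:
    """Illustrative only: looser tolerances for lower-precision dtypes."""
    if dtype in (torch.float16, torch.bfloat16):
        return {"rtol": 1e-2, "atol": 1e-3}
    return {"rtol": 1e-6, "atol": 1e-6}

# Comparing in float64 on CPU keeps the tolerances about the computation under
# test rather than about the comparison itself.
tols = dtype_tols_sketch(torch.bfloat16)
torch.testing.assert_close(
    torch.tensor(1.0, dtype=torch.float64),
    torch.tensor(1.0005, dtype=torch.float64),
    **tols,
)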
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/cross_entropy.py
@@ -1,3 +1,4 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -76,4 +77,4 @@ def backward(ctx, grad_output):
)


parallel_cross_entropy = CrossEntropyFunction.apply
parallel_cross_entropy = CrossEntropyFunction.apply
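For reference, a hedged usage sketch of the public entry point, with the argument order inferred from how the updated test calls it (the final argument is the distributed process group, None for single-GPU use). This is an illustration, not part of the diff:

import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

logits = torch.rand((2, 128, 64000), dtype=torch.float32, device="cuda", requires_grad=True)
targets = torch.randint(0, 64000, (2, 128), device="cuda")

# (input, target, label_smoothing, reduce_loss, dist_process_group) -- mirrors the test call
loss = parallel_cross_entropy(logits, targets, 0.0, True, None)
loss.backward()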
41 changes: 31 additions & 10 deletions transformer_engine/pytorch/triton/cross_entropy.py
@@ -17,7 +17,6 @@
import triton.language as tl
from torch.utils.cpp_extension import IS_HIP_EXTENSION


@triton.jit
def online_softmax_kernel(
X_ptr,
@@ -100,6 +99,7 @@ def cross_entropy_kernel(
ignore_idx,
n_cols,
n_non_ignore,
reduce_loss: tl.constexpr,
label_smoothing: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
@@ -117,7 +117,7 @@
m_d_X_y_stride: The stride of m/d/X_y tensor.
rank (int): The rank of this device in the TP group.
world_size (int): The size of world involved in this distributed loss calculation.
ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
ignore_idx (int): Target token ID to be ignored in loss and gradient calculation (default: -100).
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -179,7 +179,13 @@ def cross_entropy_kernel(
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
# Scale gradients based on reduction mode
# reduce_loss=True: the reduced loss is a mean over the non-ignored tokens, so each
# valid token's gradient carries a 1/n_non_ignore factor here
# reduce_loss=False: autograd supplies a per-token grad_output, so no extra scaling is applied here
if reduce_loss:
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
else:
X_block = tl.exp(X_block - m) / d - eps
tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written
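The scaling rule in the hunk above can be sanity-checked against stock PyTorch. A minimal standalone sketch (independent of this kernel) showing that reduction='mean' folds a 1/n_non_ignore factor into each valid token's gradient, while reduction='none' leaves the per-token gradient unscaled:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)
targets = torch.tensor([1, 2, 3, -100])  # one token ignored (default ignore_index is -100)

F.cross_entropy(logits, targets, reduction="mean").backward()
mean_grad = logits.grad.clone()  # each valid token: (softmax - one_hot) / 3

logits.grad = None
F.cross_entropy(logits, targets, reduction="none").sum().backward()
unscaled_grad = logits.grad.clone()  # each valid token: (softmax - one_hot)

torch.testing.assert_close(mean_grad * 3, unscaled_grad)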
@@ -207,7 +213,11 @@ def cross_entropy_kernel(
if y >= vocab_start_idx:
if y < vocab_end_idx:
X_y = tl.load(X_ptr + y - vocab_start_idx)
X_y += -(1 - label_smoothing) / (n_non_ignore)
# Apply the same conditional scaling logic for the target token
if reduce_loss:
X_y += -(1 - label_smoothing) / (n_non_ignore)
else:
X_y += -(1 - label_smoothing)
tl.store(X_ptr + y - vocab_start_idx, X_y)

tl.store(loss_ptr, loss)
@@ -220,11 +230,13 @@ def cross_entropy_kernel(
else:
NUM_WARPS = 32


@triton.jit
def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
grad_output_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
@@ -247,6 +259,7 @@ def element_mul_kernel(
X_ptr += program_id * X_stride

# Load the gradient output value
grad_output_ptr += program_id * grad_output_stride
grad_output = tl.load(grad_output_ptr)

# Perform the element-wise multiplication
@@ -268,6 +281,8 @@ def cross_entropy_forward(

B, SQ, V = _input.shape
n_rows = B * SQ
valid_token_count = int((target != ignore_idx).sum().item())
denom = max(1, valid_token_count)

assert reduce(mul, list(target.size())) == (B * SQ), "Each token needs a target token ID."

@@ -323,24 +338,29 @@ def cross_entropy_forward(
world_size=world_size,
ignore_idx=ignore_idx,
n_cols=V,
n_non_ignore=n_rows,
n_non_ignore=denom,
reduce_loss=reduce_loss,
label_smoothing=label_smoothing,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=NUM_WARPS,
)

loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows)
loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / denom)

return loss, _input
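The new denom matches how the reference implementation treats ignored tokens under mean reduction. A short hedged check against stock PyTorch, separate from this diff:

import torch
import torch.nn.functional as F

logits = torch.randn(6, 32)
targets = torch.randint(0, 32, (6,))
targets[0] = -100  # ignored token

per_token = F.cross_entropy(logits, targets, reduction="none")  # 0.0 at the ignored position
mean_loss = F.cross_entropy(logits, targets, reduction="mean")

# The mean divides by the number of non-ignored targets (5), not the row count (6).
torch.testing.assert_close(mean_loss, per_token.sum() / 5)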


def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
def cross_entropy_backward(
_input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
):
"""Backward implementation of cross entropy loss kernel"""

# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
# Only check torch.equal when not in CUDA graph capturable mode
if not is_cg_capturable and torch.equal(
grad_output, torch.tensor(1.0, device=grad_output.device)
):
pass

else:
B, SQ, V = _input.shape
n_rows = B * SQ
@@ -350,9 +370,10 @@ def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
_input,
_input.stride(-2),
grad_output,
1 if grad_output.numel() > 1 else 0,
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=NUM_WARPS,
)

return _input
return _input
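A hedged usage sketch of the new is_cg_capturable flag (input values below are illustrative): skipping the torch.equal fast path matters under CUDA graph capture because that check synchronizes the device and introduces data-dependent control flow.

import torch
from transformer_engine.pytorch.triton.cross_entropy import cross_entropy_backward

# `saved` stands in for the tensor returned by the forward pass; gradients are
# written into it in place. A scalar grad_output of 1.0 mimics a mean-reduced loss.
saved = torch.rand((1, 64, 32000), device="cuda")
grad_output = torch.ones((), device="cuda")

# With is_cg_capturable=True the torch.equal() check is skipped, so the call stays
# safe to record inside a CUDA graph (at the cost of one extra element-wise multiply).
grad_input = cross_entropy_backward(saved, grad_output, is_cg_capturable=True)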