diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py
index fdb9b7f0b..c8164908e 100644
--- a/tests/pytorch/test_parallel_cross_entropy.py
+++ b/tests/pytorch/test_parallel_cross_entropy.py
@@ -3,10 +3,11 @@
 # See LICENSE for license information.
 
 import random
-import pytest
 import torch
 
 from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
 
+from utils import dtype_tols
+
 
 class TestParallelCrossEntropy:
@@ -19,19 +20,25 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float):
             label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
         )
 
-    def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
-
+    def generate_input(
+        self,
+        dtype: torch.dtype,
+        swap_dim: bool,
+        ignore_idx: bool,
+        device: torch.device = "cuda",
+    ):
         SQ = random.choice([64, 128])
         batch = random.choice([1, 2])
         vocab = random.choice([64000, 128000])
         ignore = random.sample(range(0, SQ - 1), 5)
 
+        # Generate random data
         if swap_dim:
-            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
-            self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda()
+            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device)
+            self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device)
         else:
-            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
-            self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
+            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device)
+            self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device)
 
         if ignore_idx:
             for i in ignore:
@@ -41,9 +48,14 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
             else:
                 self.tar_test[0][i] = -100
 
+        # Make copy of data for reference implementation
         self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
         self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))
 
+        # Enable autograd
+        self.input_test.requires_grad_()
+        self.input_ref.requires_grad_()
+
     def one_iteration_test(
         self,
         dtype: torch.dtype,
@@ -53,31 +65,39 @@ def one_iteration_test(
         ignore_idx: bool = False,
     ):
 
+        # Random data
         self.generate_input(dtype, swap_dim, ignore_idx)
 
-        self.input_test.requires_grad_(True)
-        self.input_ref.requires_grad_(True)
-
+        # Forward pass
         test_loss = self.test_loss_func(
             self.input_test, self.tar_test, label_smoothing, reduce_loss, None
         )
-        if reduce_loss:
-            test_loss.backward()
-
         ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
-        if reduce_loss:
-            ref_loss.backward()
 
-        test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
+        # Compute square to avoid trivial backward pass
+        test_loss = torch.square(test_loss)
+        ref_loss = torch.square(ref_loss)
 
-        torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
-        if ignore_idx:
-            print(test_loss, ref_loss)
+        # Backward pass
         if reduce_loss:
-            torch.testing.assert_close(
-                torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
-            )
-
+            test_loss.backward()
+            ref_loss.backward()
+        else:
+            test_loss.sum().backward()
+            ref_loss.sum().backward()
+
+        # Check that loss and grad input match
+        tols = dtype_tols(dtype)
+        test_loss = test_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.reshape(test_loss.size())
+        test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
+        ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
+        ref_grad_input = ref_grad_input.reshape(test_grad_input.size())
+        torch.testing.assert_close(test_loss, ref_loss, **tols)
+        torch.testing.assert_close(test_grad_input, ref_grad_input, **tols)
+
+        # Reset data
         self.input_test = None
         self.input_ref = None
         self.tar_test = None
@@ -133,4 +153,4 @@ def test_ignore_idx(self):
             label_smoothing=0,
             reduce_loss=False,
             ignore_idx=True,
-        )
+        )
\ No newline at end of file
diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py
index 75b5de37b..0d05babb6 100644
--- a/transformer_engine/pytorch/cross_entropy.py
+++ b/transformer_engine/pytorch/cross_entropy.py
@@ -1,3 +1,4 @@
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -76,4 +77,4 @@ def backward(ctx, grad_output):
         )
 
 
-parallel_cross_entropy = CrossEntropyFunction.apply
+parallel_cross_entropy = CrossEntropyFunction.apply
\ No newline at end of file
diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py
index f015c8871..a0431fe18 100644
--- a/transformer_engine/pytorch/triton/cross_entropy.py
+++ b/transformer_engine/pytorch/triton/cross_entropy.py
@@ -17,7 +17,6 @@
 import triton.language as tl
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
-
 @triton.jit
 def online_softmax_kernel(
     X_ptr,
@@ -100,6 +99,7 @@ def cross_entropy_kernel(
     ignore_idx,
     n_cols,
     n_non_ignore,
+    reduce_loss: tl.constexpr,
     label_smoothing: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -117,7 +117,7 @@ def cross_entropy_kernel(
     m_d_X_y_stride: The stride of m/d/X_y tensor.
     rank (int): The rank of this device in the TP group.
     world_size (int): The size of world involved in this distributed loss calculation.
-    ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
+    ignore_idx (int): Tokens to be ignored for loss and gradient calculation. (default -100)
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (int): The number of non-ignored elements in the batch.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -179,7 +179,13 @@ def cross_entropy_kernel(
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
-        X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
+        # Scale gradients based on the reduction mode
+        # reduce_loss=True: the loss is averaged over non-ignored tokens, so scale gradients by 1/n_non_ignore
+        # reduce_loss=False: per-token losses are returned unreduced, so no extra scaling is applied here
+        if reduce_loss:
+            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
+        else:
+            X_block = tl.exp(X_block - m) / d - eps
         tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)
 
     # We need tl.debug_barrier() to ensure the new result of X_ptr is written
@@ -207,7 +213,11 @@ def cross_entropy_kernel(
     if y >= vocab_start_idx:
         if y < vocab_end_idx:
             X_y = tl.load(X_ptr + y - vocab_start_idx)
-            X_y += -(1 - label_smoothing) / (n_non_ignore)
+            # Apply the same conditional scaling logic for the target token
+            if reduce_loss:
+                X_y += -(1 - label_smoothing) / (n_non_ignore)
+            else:
+                X_y += -(1 - label_smoothing)
             tl.store(X_ptr + y - vocab_start_idx, X_y)
 
     tl.store(loss_ptr, loss)
@@ -220,11 +230,13 @@ def cross_entropy_kernel(
 else:
     NUM_WARPS = 32
 
+
 @triton.jit
 def element_mul_kernel(
     X_ptr,
     X_stride,
     grad_output_ptr,
+    grad_output_stride,
     n_cols,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -247,6 +259,7 @@ def element_mul_kernel(
     X_ptr += program_id * X_stride
 
     # Load the gradient output value
+    grad_output_ptr += program_id * grad_output_stride
     grad_output = tl.load(grad_output_ptr)
 
     # Perform the element-wise multiplication
@@ -268,6 +281,8 @@ def cross_entropy_forward(
     B, SQ, V = _input.shape
     n_rows = B * SQ
+    valid_token_count = int((target != ignore_idx).sum().item())
+    denom = max(1, valid_token_count)
 
     assert reduce(mul, list(target.size())) == (B * SQ), "Each token needs a target token ID."
 
@@ -323,24 +338,29 @@ def cross_entropy_forward(
         world_size=world_size,
         ignore_idx=ignore_idx,
         n_cols=V,
-        n_non_ignore=n_rows,
+        n_non_ignore=denom,
+        reduce_loss=reduce_loss,
         label_smoothing=label_smoothing,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=NUM_WARPS,
     )
 
-    loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows)
+    loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / denom)
 
     return loss, _input
 
 
-def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
+def cross_entropy_backward(
+    _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
+):
     """Backward implementation of cross entropy loss kernel"""
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+    # Only check torch.equal when not in CUDA graph capturable mode
+    if not is_cg_capturable and torch.equal(
+        grad_output, torch.tensor(1.0, device=grad_output.device)
+    ):
         pass
-
     else:
         B, SQ, V = _input.shape
         n_rows = B * SQ
@@ -350,9 +370,10 @@ def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
             _input,
             _input.stride(-2),
             grad_output,
+            1 if grad_output.numel() > 1 else 0,
             V,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=NUM_WARPS,
         )
 
-    return _input
+    return _input
\ No newline at end of file
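For context, below is a minimal usage sketch (not part of the patch) of how the two reduction modes are expected to behave after this change. It follows the calling convention used in the test above, `parallel_cross_entropy(input, target, label_smoothing, reduce_loss, dist_process_group)`; the tensor sizes are hypothetical.

```python
import torch

from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

# Hypothetical sizes; the test draws SQ from {64, 128}, batch from {1, 2}, vocab from {64000, 128000}.
batch, seq, vocab = 2, 64, 64000
logits = torch.rand((batch, seq, vocab), dtype=torch.float32, device="cuda", requires_grad=True)
target = torch.randint(0, vocab, (batch, seq), device="cuda")

# reduce_loss=True: the loss is a scalar mean over non-ignored tokens, and the kernel
# pre-scales the input gradients by 1/n_non_ignore.
mean_loss = parallel_cross_entropy(logits, target, 0.0, True, None)
mean_loss.backward()

# reduce_loss=False: per-token losses are returned with no kernel-side scaling;
# the caller applies its own reduction before backward (the test uses .sum()).
logits = logits.detach().clone().requires_grad_()
token_loss = parallel_cross_entropy(logits, target, 0.0, False, None)
token_loss.sum().backward()
```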