37 changes: 37 additions & 0 deletions tests/cpp/operator/test_cast_current_scaling.cu
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -196,6 +198,41 @@ TEST_P(CastCSTestSuite, TestCastCS) {
}


TEST(AmaxConsistencyTest, AtomicVsWorkspace) {
using namespace transformer_engine;
using namespace test;

std::vector<size_t> shape{256, 1024};
const size_t N = product(shape);

// Input: FP32, Output: FP8 (E4M3) with per-tensor scaling
Tensor input("input", shape, DType::kFloat32);
Tensor out_atomic("out_atomic", shape, DType::kFloat8E4M3, true, false);
Tensor out_ws("out_ws", shape, DType::kFloat8E4M3, true, false);

fillUniform(&input);

// Path 1: atomic-based amax (no workspace)
nvte_compute_amax(input.data(), out_atomic.data(), 0);

// Path 2: two-stage amax using workspace
// Use a workspace capacity >= number of blocks
std::vector<size_t> ws_shape{N};
Tensor workspace("workspace", ws_shape, DType::kFloat32);
nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0);

cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

// Compare the resulting amax values
float amax_atomic = out_atomic.amax();
float amax_ws = out_ws.amax();

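// Both paths compute the exact maximum of the same non-negative values, so the
// results are expected to agree exactly.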
compareResults("amax_consistency", amax_atomic, amax_ws, /*atol=*/0.0f, /*rtol=*/0.0f);
}



INSTANTIATE_TEST_SUITE_P(
OperatorTest,
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -84,6 +86,8 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
*/
void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);

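/*! \brief Compute an FP8 tensor's amax, optionally using a workspace for the reduction.
 *
 * Behaves like nvte_compute_amax. If a non-null FP32 workspace with room for one
 * partial result per thread block is supplied, per-block maxima are written to the
 * workspace and reduced by a second kernel; otherwise the atomicMax path is used.
 * Setting the environment variable NVTE_USE_ATOMIC_AMAX=1 forces the atomicMax path.
 *
 * \param[in]     input      Input tensor.
 * \param[in,out] output     Output tensor whose amax is computed.
 * \param[in]     workspace  Optional FP32 tensor for per-block partial maxima (may be NULL).
 * \param[in]     stream     CUDA stream used for the operation.
 */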
void nvte_compute_amax_with_workspace(const NVTETensor input, NVTETensor output,
                                      const NVTETensor workspace, cudaStream_t stream);

/*! \brief Update an FP8 tensor's scale based on its amax.
*
* This is only supported for FP8 tensors with per-tensor scaling.
93 changes: 84 additions & 9 deletions transformer_engine/common/recipe/current_scaling.cu
@@ -11,6 +11,7 @@
#include <algorithm>
#include <limits>
#include <type_traits>
#include <cstdlib>

#include "../common.h"
#include "../util/logging.h"
@@ -28,20 +29,51 @@ using bf16__ = __hip_bfloat16;

constexpr int amax_kernel_threads = 512;

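// Returns true when NVTE_USE_ATOMIC_AMAX=1 is set in the environment, forcing the
// original single-pass atomicMax reduction; the lookup is cached after the first call.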
inline bool nvte_use_atomic_amax() {
static int cached = -1;
if (cached == -1) {
cached = 0;
const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX");
if (env_p && std::string(env_p) == "1") {
cached = 1;
}
}
return cached == 1;
}

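// Second stage of the two-stage reduction: a single block scans the per-block maxima
// written by amax_kernel and stores the overall amax.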
template <int BLOCK_THREADS>
__global__ void amax_final_reduce(const float* __restrict__ block_amax,
float* __restrict__ global_amax,
int num_blocks) {
float val = 0.f;

for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
val = fmaxf(val, block_amax[i]);
}

const int warp_id = threadIdx.x / THREADS_PER_WARP;
const float block_max =
reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);

if (threadIdx.x == 0) {
*global_amax = block_max;
}
}

template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
void amax_kernel(const InputType *input, float *amax, const size_t N,
const size_t num_aligned_elements) {
void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N,
const size_t num_aligned_elements, bool use_block_amax) {
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
InputType max{0.f};
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const size_t M = num_aligned_elements;

for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
loader.load(tid, N);
auto v = loader.separate();
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const InputType val = static_cast<InputType>(loader.separate()[i]);
const InputType val = static_cast<InputType>(v[i]);
__builtin_assume(max >= InputType{0.f});
if constexpr (std::is_same_v<InputType, bf16__>) {
#ifndef __HIP_PLATFORM_AMD__
@@ -65,12 +97,17 @@ __launch_bounds__(amax_kernel_threads) __global__
// Reduce amax over block
max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
atomicMaxFloat(amax, max);
if (use_block_amax) {
block_amax[blockIdx.x] = max;
} else {
atomicMaxFloat(amax, max);
}
}
}

template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
size_t block_capacity, cudaStream_t stream) {
// Zero out amax so we can update with atomic max
(void)cudaMemsetAsync(amax, 0, sizeof(float), stream);

@@ -89,24 +126,42 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
constexpr size_t max_blocks = 65535;
num_blocks = std::min(num_blocks, max_blocks);

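// Use the two-stage path only when a workspace was provided, it can hold one partial
// result per block, and the atomic path was not forced via NVTE_USE_ATOMIC_AMAX=1.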
const bool UseBlockAmax =
(block_amax != nullptr) &&
(block_capacity >= num_blocks) &&
!nvte_use_atomic_amax();

// Launch kernel
switch (align) {
case Alignment::SAME_ALIGNED:
amax_kernel<nvec, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
<<<num_blocks, threads, 0, stream>>>(
input, amax, block_amax, N, num_aligned_elements, UseBlockAmax);
break;
case Alignment::SAME_UNALIGNED:
amax_kernel<nvec, false, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
<<<num_blocks, threads, 0, stream>>>(
input, amax, block_amax, N, num_aligned_elements, UseBlockAmax);
break;
case Alignment::DIFFERENT: {
// This case is a logic error, since there is only one pointer (input)
// in the alignment check. Still safe to process without vectorization.
amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
amax_kernel<1, true, InputType>
<<<num_blocks, threads, 0, stream>>>(
input, amax, block_amax, N, N, UseBlockAmax);
break;
}
}

if (UseBlockAmax) {
constexpr int FINAL_REDUCE_THREADS = 256;
dim3 fr_block(FINAL_REDUCE_THREADS);
dim3 fr_grid(1);

amax_final_reduce<FINAL_REDUCE_THREADS>
<<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
}

// Check results
NVTE_CHECK_CUDA(cudaGetLastError());
}
@@ -115,6 +170,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
} // namespace transformer_engine

void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
}

void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_,
                                      const NVTETensor workspace_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_amax);
using namespace transformer_engine;

@@ -150,11 +209,27 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
to_string(output.amax.dtype), ")");
CheckOutputTensor(output, "output_compute_amax", true);

// Optional workspace
float* block_amax = nullptr;
size_t block_capacity = 0;

if (workspace_ != nullptr) {
auto &workspace = *reinterpret_cast<Tensor *>(workspace_);
NVTE_CHECK(workspace.data.dptr != nullptr,
"Workspace tensor for amax computation has no data");
NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
"Workspace tensor for amax computation must be FP32, got dtype=",
to_string(workspace.data.dtype));
block_amax = reinterpret_cast<float*>(workspace.data.dptr);
block_capacity = workspace.data.numel();
}

// Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(), block_amax,
block_capacity,
stream);); // NOLINT(*)
}
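
Both PyTorch callers below size the workspace with the worst-case (nvec = 1) block-count bound. A minimal sketch of that bound as a standalone helper (amax_workspace_upper_bound is a hypothetical name; the patch inlines the arithmetic at each call site):

#include <algorithm>
#include <cstddef>

// Upper bound on the number of blocks launch_amax_kernel may use for N input
// elements, and hence on the workspace capacity (in floats) a caller must provide.
// 512 mirrors amax_kernel_threads and 65535 is the grid-size cap used above.
inline size_t amax_workspace_upper_bound(size_t N) {
  constexpr size_t threads = 512;
  constexpr size_t max_blocks_hw = 65535;
  return std::min((N + threads - 1) / threads, max_blocks_hw);
}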

21 changes: 20 additions & 1 deletion transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -52,8 +54,25 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());

// workspace for nvte_compute_amax_with_workspace
const auto N = static_cast<size_t>(input_tensor.numel());
constexpr size_t threads = 512; // FIXME: should match amax_kernel_threads
constexpr size_t max_blocks_hw = 65535;

// Worst-case (nvec = 1) upper bound on number of blocks.
size_t max_blocks = std::min(DIVUP(N, threads), max_blocks_hw);

// Allocate FP32 workspace for block-wise amax
auto ws = at::empty({static_cast<long>(max_blocks)},
tensor.options().dtype(at::kFloat));
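// e.g. a 256x1024 FP32 input has N = 262144, so max_blocks = DIVUP(262144, 512) = 512
// and the workspace holds 512 floats (2 KiB), well below the 65535-block cap.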

TensorWrapper te_workspace = makeTransformerEngineTensor(ws);

NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
nvte_compute_amax_with_workspace(te_input.data(), te_output.data(),
te_workspace.data(),
at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
26 changes: 25 additions & 1 deletion transformer_engine/pytorch/csrc/extensions/recipe.cpp
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -20,12 +22,34 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {

TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");

// Compute an upper bound on the number of blocks for this input.
const auto N = input_tensor.numel();
constexpr size_t threads = 512; // FIXME: should grab amax_kernel_threads here
constexpr size_t max_blocks_hw = 65535;

// Assume worst-case vectorization (nvec = 1) as an upper bound.
size_t max_blocks = std::min(DIVUP(static_cast<size_t>(N), threads),
max_blocks_hw);

// Allocate workspace for the block_amax buffer.
auto ws = at::empty({static_cast<long>(max_blocks)},
tensor.options().dtype(at::kFloat));

std::vector<size_t> ws_shape{static_cast<size_t>(max_blocks)};

TensorWrapper fake_te_output(
nullptr, te_input.shape(),
DType::kFloat8E4M3, // It doesn't matter because we only compute amax.
amax.data_ptr<float>());

nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
TensorWrapper te_workspace(ws.data_ptr(), ws_shape, DType::kFloat32, nullptr);

nvte_compute_amax_with_workspace(te_input.data(), fake_te_output.data(), te_workspace.data(), at::cuda::getCurrentCUDAStream());
}

void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer,
Expand Down