Merged
37 commits
c15d93b
Current scaling: two-stage amax kernel
matthiasdiener Nov 12, 2025
51fab36
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 13, 2025
ae35e4c
bugfix graph capture
matthiasdiener Nov 13, 2025
77a68a7
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 17, 2025
c0d8e73
outline workspace allocation
matthiasdiener Nov 17, 2025
6c3507d
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 18, 2025
3c9de07
Proper allocation of workspace
matthiasdiener Nov 18, 2025
91249cc
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 19, 2025
be0e0c8
add a test to compare the accuracy of both amax implementations
matthiasdiener Nov 19, 2025
bce34da
add possibility to force using previous (atomic) kernel
matthiasdiener Nov 19, 2025
8c388cc
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 19, 2025
6388604
add copyrights
matthiasdiener Nov 20, 2025
9e6586f
don't add extra template to kernel
matthiasdiener Nov 20, 2025
18292bf
make amax_kernel_threads usable in pytorch
matthiasdiener Nov 21, 2025
a389455
update remaining calls to nvte_compute_amax
matthiasdiener Nov 21, 2025
d87ab8a
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 24, 2025
fd5dead
additional copyrights
matthiasdiener Nov 24, 2025
16d3bf9
avoid workspace allocations if NVTE_USE_ATOMIC_AMAX is set
matthiasdiener Nov 24, 2025
50b34aa
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 25, 2025
ef532b1
remove use_block_amax parameter, more cleanups
matthiasdiener Nov 25, 2025
f933ef3
Factor workspace allocation into function
matthiasdiener Nov 25, 2025
7d4054e
expand test slightly
matthiasdiener Nov 25, 2025
63cff98
Revert "expand test slightly"
Nov 25, 2025
c7d44a7
guard by HIP macro, address review comments
matthiasdiener Nov 26, 2025
f92b926
bugfix workspace.data.dptr
matthiasdiener Nov 26, 2025
eba552e
various cleanups
matthiasdiener Nov 26, 2025
0d6a177
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 26, 2025
8eda427
simplify types in allocate_amax_workspace
matthiasdiener Nov 26, 2025
6990928
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Dec 1, 2025
9ee618f
fix indentation
matthiasdiener Dec 1, 2025
77b1bc3
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Dec 1, 2025
1357d4b
Use private implementation of DIVUP
matthiasdiener Dec 2, 2025
01b61b5
define amax_kernel_threads on non-AMD
matthiasdiener Dec 2, 2025
ed16f8f
Revert "Use private implementation of DIVUP"
matthiasdiener Dec 2, 2025
95dcbdf
Factor out workspace size calculation
matthiasdiener Dec 2, 2025
b07edf6
change name
matthiasdiener Dec 2, 2025
233eb0a
add copyright
matthiasdiener Dec 2, 2025
39 changes: 39 additions & 0 deletions tests/cpp/operator/test_cast_current_scaling.cu
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -195,6 +197,43 @@ TEST_P(CastCSTestSuite, TestCastCS) {
);
}

#ifdef __HIP_PLATFORM_AMD__

TEST(AmaxConsistencyTest, AtomicVsWorkspace) {
using namespace transformer_engine;
using namespace test;

std::vector<size_t> shape{256, 1024};
const size_t N = product(shape);

// Input: FP32, Output: FP8 (E4M3) with per-tensor scaling
Tensor input("input", shape, DType::kFloat32);
Tensor out_atomic("out_atomic", shape, DType::kFloat8E4M3, true, false);
Tensor out_ws("out_ws", shape, DType::kFloat8E4M3, true, false);

fillUniform(&input);

// Path 1: atomic-based amax (no workspace)
nvte_compute_amax(input.data(), out_atomic.data(), 0);

// Path 2: two-stage amax using workspace
std::vector<size_t> ws_shape{N};
Tensor workspace("workspace", ws_shape, DType::kFloat32);
nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0);

cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

// Compare the resulting amax values
float amax_atomic = out_atomic.amax();
float amax_ws = out_ws.amax();

compareResults("amax_consistency", amax_atomic, amax_ws, /*atol=*/0.0f, /*rtol=*/0.0f);
}

#endif



INSTANTIATE_TEST_SUITE_P(
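The test above sizes its workspace at one float per input element; the two-stage kernel only writes a single partial maximum per launched block, so a tighter allocation also works. A sketch of a replacement for the three workspace lines in the test, assuming nvte_amax_workspace_num_blocks from recipe.h is visible to the test:

  // Sketch only (not the PR's test code): size the workspace from the helper
  // instead of allocating one float per input element.
  std::vector<size_t> ws_shape{nvte_amax_workspace_num_blocks(N)};
  Tensor workspace("workspace", ws_shape, DType::kFloat32);
  nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0);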
26 changes: 26 additions & 0 deletions transformer_engine/common/include/transformer_engine/recipe.h
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -73,6 +75,12 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
float margin, cudaStream_t stream);

#ifdef __HIP_PLATFORM_AMD__

constexpr int amax_kernel_threads = 512;

#endif

/*! \brief Compute an FP8 tensor's amax.
*
* The amax (maximum absolute value) of the input tensor is computed
@@ -84,6 +92,24 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
*/
void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);

#ifdef __HIP_PLATFORM_AMD__

/*! \brief Number of FP32 workspace elements (one per reduction block) required
 *         to compute the amax of an input with N elements.
 */
size_t nvte_amax_workspace_num_blocks(size_t N);

/*! \brief Compute an FP8 tensor's amax with a two-stage reduction.
 *
 * The amax (maximum absolute value) of the input tensor is computed
 * and written to the amax buffer of the output tensor. Per-block partial
 * maxima are staged in the workspace and combined by a final reduction kernel.
*
* \param[in] input Input tensor. Must be unquantized.
* \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling.
 * \param[out] workspace  Workspace tensor. Must be FP32 with at least
 *                        nvte_amax_workspace_num_blocks(input size) elements.
 *                        If it is null or smaller than that, the single-stage
 *                        atomic kernel is used instead.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_compute_amax_with_workspace(const NVTETensor input, NVTETensor output, NVTETensor workspace, cudaStream_t stream);

#endif

/*! \brief Update an FP8 tensor's scale based on its amax.
*
* This is only supported for FP8 tensors with per-tensor scaling.
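For reference, a sketch of how the two HIP-only entry points declared above relate; the wrapper function is illustrative only, while the API names are the ones from this header:

#include "transformer_engine/recipe.h"

#ifdef __HIP_PLATFORM_AMD__
// On the AMD path, nvte_compute_amax() forwards internally with a null workspace,
// so the first two calls below both run the single-stage atomic kernel; only the
// third, given a large-enough FP32 workspace, runs the two-stage reduction.
void amax_dispatch_sketch(const NVTETensor input, NVTETensor output,
                          NVTETensor workspace, cudaStream_t stream) {
  nvte_compute_amax(input, output, stream);
  nvte_compute_amax_with_workspace(input, output, nullptr, stream);
  nvte_compute_amax_with_workspace(input, output, workspace, stream);
}
#endif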
113 changes: 112 additions & 1 deletion transformer_engine/common/recipe/current_scaling.cu
Expand Up @@ -28,10 +28,38 @@ using bf16__ = __hip_bfloat16;

constexpr int amax_kernel_threads = 512;

#ifdef __HIP_PLATFORM_AMD__

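// Second reduction stage: folds the per-block partial maxima written by amax_kernel
// into the single global amax value.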
template <int BLOCK_THREADS>
__global__ void amax_final_reduce(const float* __restrict__ block_amax,
float* __restrict__ global_amax,
int num_blocks) {
float val = 0.f;

for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
val = fmaxf(val, block_amax[i]);
}

const int warp_id = threadIdx.x / THREADS_PER_WARP;
const float block_max =
reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);

if (threadIdx.x == 0) {
*global_amax = block_max;
}
}

#endif

template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
#ifdef __HIP_PLATFORM_AMD__
void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N,
const size_t num_aligned_elements) {
#else
void amax_kernel(const InputType *input, float *amax, const size_t N,
const size_t num_aligned_elements) {
#endif
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
InputType max{0.f};
const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -65,12 +93,23 @@ __launch_bounds__(amax_kernel_threads) __global__
// Reduce amax over block
max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
#ifdef __HIP_PLATFORM_AMD__
if (block_amax != nullptr) {
// 2-stage: write per-block result
block_amax[blockIdx.x] = max;
} else {
// Atomic path: directly update global amax
atomicMaxFloat(amax, max);
}
#else
atomicMaxFloat(amax, max);
#endif
}
}

template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
size_t block_capacity, cudaStream_t stream) {
// Zero out amax so we can update with atomic max
(void)cudaMemsetAsync(amax, 0, sizeof(float), stream);

@@ -83,38 +122,90 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
auto align = CheckAlignment(N, nvec, input);
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));

#ifndef __HIP_PLATFORM_AMD__
// Figure out CUDA blocks
constexpr size_t threads = amax_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads);
constexpr size_t max_blocks = 65535;
num_blocks = std::min(num_blocks, max_blocks);

#else
constexpr size_t threads = amax_kernel_threads;
size_t num_blocks = nvte_amax_workspace_num_blocks(num_aligned_elements);
if (block_capacity < num_blocks)
block_amax = nullptr;
#endif

// Launch kernel
switch (align) {
case Alignment::SAME_ALIGNED:
#ifdef __HIP_PLATFORM_AMD__
amax_kernel<nvec, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
#else
amax_kernel<nvec, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
#endif
break;
case Alignment::SAME_UNALIGNED:
#ifdef __HIP_PLATFORM_AMD__
amax_kernel<nvec, false, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
#else
amax_kernel<nvec, false, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
#endif
break;
case Alignment::DIFFERENT: {
// This case is a logic error, since there is only one pointer (input)
// in the alignment check. Still safe to process without vectorization.
#ifdef __HIP_PLATFORM_AMD__
amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
#else
amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
#endif
break;
}
}

#ifdef __HIP_PLATFORM_AMD__
if (block_amax != nullptr) {
constexpr int FINAL_REDUCE_THREADS = 256;
dim3 fr_block(FINAL_REDUCE_THREADS);
dim3 fr_grid(1);

amax_final_reduce<FINAL_REDUCE_THREADS>
<<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
}
#endif

// Check results
NVTE_CHECK_CUDA(cudaGetLastError());
}

} // namespace
} // namespace transformer_engine


#ifdef __HIP_PLATFORM_AMD__

size_t nvte_amax_workspace_num_blocks(size_t N) {
constexpr size_t max_blocks_hw = 65535;

size_t max_blocks = transformer_engine::DIVUP(N, static_cast<size_t>(amax_kernel_threads));
size_t workspace_blocks = std::min(max_blocks, max_blocks_hw);
return workspace_blocks;
}

#endif

void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
}

void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) {
#endif
NVTE_API_CALL(nvte_compute_amax);
using namespace transformer_engine;

@@ -150,11 +241,31 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
to_string(output.amax.dtype), ")");
CheckOutputTensor(output, "output_compute_amax", true);

#ifdef __HIP_PLATFORM_AMD__
// Optional workspace
float* block_amax = nullptr;
size_t block_capacity = 0;

if (workspace_ != nullptr) {
auto &workspace = *reinterpret_cast<Tensor *>(workspace_);
if (workspace.data.dptr != nullptr) {
NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
"Workspace tensor for amax computation must be FP32, got dtype=",
to_string(workspace.data.dtype));
block_amax = reinterpret_cast<float*>(workspace.data.dptr);
block_capacity = workspace.data.numel();
}
}
#endif

// Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
#ifdef __HIP_PLATFORM_AMD__
block_amax, block_capacity,
#endif
stream);); // NOLINT(*)
}

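As a sanity check on the sizes involved, a standalone sketch that mirrors the workspace math in nvte_amax_workspace_num_blocks above (the input size is an arbitrary example):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  constexpr std::size_t threads = 512;          // amax_kernel_threads
  constexpr std::size_t max_blocks_hw = 65535;  // cap used above
  const std::size_t N = 64UL * 1024 * 1024;     // 64M elements (256 MiB of FP32 input)
  const std::size_t blocks = std::min((N + threads - 1) / threads, max_blocks_hw);
  // 64M / 512 = 131072 blocks, capped to 65535 -> roughly 256 KiB of FP32 workspace
  std::printf("stage-1 blocks: %zu, workspace bytes: %zu\n", blocks, blocks * sizeof(float));
  return 0;
}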
31 changes: 31 additions & 0 deletions transformer_engine/pytorch/csrc/common.cpp
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -10,6 +12,10 @@
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"

#ifdef __HIP_PLATFORM_AMD__
#include "common/common.h"
#endif

namespace transformer_engine::pytorch {

std::vector<size_t> getTensorShape(at::Tensor t) {
@@ -277,4 +283,29 @@ int roundup(const int value, const int multiple) {
return ((value + multiple - 1) / multiple) * multiple;
}

#ifdef __HIP_PLATFORM_AMD__

inline bool nvte_use_atomic_amax() {
const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX");
if (env_p && std::string(env_p) == "1")
return true;
return false;
}

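// Allocates the FP32 per-block workspace used by the two-stage amax kernel, or returns
// an empty wrapper when the atomic fallback is requested or the input is empty.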
TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor) {
if (nvte_use_atomic_amax() || input_tensor.numel() == 0) {
// User chose atomic path, or empty tensor -> no need for workspace
return TensorWrapper{};
}

const auto N = input_tensor.numel();
size_t workspace_blocks = nvte_amax_workspace_num_blocks(N);

at::Tensor ws = at::empty(workspace_blocks, at::CUDA(at::kFloat));

return makeTransformerEngineTensor(ws);
}

#endif

} // namespace transformer_engine::pytorch
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
@@ -374,6 +374,9 @@ std::vector<size_t> convertShape(const NVTEShape& shape);

int roundup(const int value, const int multiple);

#ifdef __HIP_PLATFORM_AMD__
TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor);
#endif
} // namespace transformer_engine::pytorch

namespace std {
10 changes: 10 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -36,10 +38,18 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
auto [te_output_act, out_act] =
my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));

#ifdef __HIP_PLATFORM_AMD__
auto workspace = allocate_amax_workspace(te_input);
#endif
NVTE_SCOPED_GIL_RELEASE({
act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
// use te_output_act as the input to the amax computation to find the amax of the activated tensor
#ifdef __HIP_PLATFORM_AMD__
nvte_compute_amax_with_workspace(te_output_act.data(), te_output.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
#else
nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
#endif
});

// my_quantizer here has to be a Float8CurrentScalingQuantizer
8 changes: 8 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/bias.cpp
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -49,7 +51,13 @@ std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_qu
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
#ifdef __HIP_PLATFORM_AMD__
nvte_compute_amax_with_workspace(input_tensor.data(), out_tensor.data(),
allocate_amax_workspace(input_tensor).data(),
at::cuda::getCurrentCUDAStream());
#else
nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
#endif
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {