diff --git a/tests/cpp/operator/test_cast_current_scaling.cu b/tests/cpp/operator/test_cast_current_scaling.cu
index f7425f0f3..1f518c6c0 100644
--- a/tests/cpp/operator/test_cast_current_scaling.cu
+++ b/tests/cpp/operator/test_cast_current_scaling.cu
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -196,6 +198,41 @@ TEST_P(CastCSTestSuite, TestCastCS) {
 }
 
+TEST(AmaxConsistencyTest, AtomicVsWorkspace) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  std::vector<size_t> shape{256, 1024};
+  const size_t N = product(shape);
+
+  // Input: FP32, Output: FP8 (E4M3) with per-tensor scaling
+  Tensor input("input", shape, DType::kFloat32);
+  Tensor out_atomic("out_atomic", shape, DType::kFloat8E4M3, true, false);
+  Tensor out_ws("out_ws", shape, DType::kFloat8E4M3, true, false);
+
+  fillUniform(&input);
+
+  // Path 1: atomic-based amax (no workspace)
+  nvte_compute_amax(input.data(), out_atomic.data(), 0);
+
+  // Path 2: two-stage amax using workspace
+  // Use a workspace capacity >= number of blocks
+  std::vector<size_t> ws_shape{N};
+  Tensor workspace("workspace", ws_shape, DType::kFloat32);
+  nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0);
+
+  cudaDeviceSynchronize();
+  auto err = cudaGetLastError();
+  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+  // Compare the resulting amax values
+  float amax_atomic = out_atomic.amax();
+  float amax_ws = out_ws.amax();
+
+  compareResults("amax_consistency", amax_atomic, amax_ws, /*atol=*/0.0f, /*rtol=*/0.0f);
+}
+
+
 INSTANTIATE_TEST_SUITE_P(
     OperatorTest,
diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h
index 50fb696ea..34b4b0116 100644
--- a/transformer_engine/common/include/transformer_engine/recipe.h
+++ b/transformer_engine/common/include/transformer_engine/recipe.h
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -84,6 +86,8 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
  */
 void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
+void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream);
+
 /*! \brief Update an FP8 tensor's scale based on its amax.
  *
  * This is only supported for FP8 tensors with per-tensor scaling.
diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu
index 709ab200f..4b8422749 100644
--- a/transformer_engine/common/recipe/current_scaling.cu
+++ b/transformer_engine/common/recipe/current_scaling.cu
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include <cstdlib>
 
 #include "../common.h"
 #include "../util/logging.h"
@@ -28,10 +29,40 @@ using bf16__ = __hip_bfloat16;
 
 constexpr int amax_kernel_threads = 512;
 
+inline bool nvte_use_atomic_amax() {
+  static int cached = -1;
+  if (cached == -1) {
+    cached = 0;
+    const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX");
+    if (env_p && std::string(env_p) == "1") {
+      cached = 1;
+    }
+  }
+  return cached == 1;
+}
+
+template <int BLOCK_THREADS>
+__global__ void amax_final_reduce(const float* __restrict__ block_amax,
+                                  float* __restrict__ global_amax,
+                                  int num_blocks) {
+  float val = 0.f;
+
+  for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
+    val = fmaxf(val, block_amax[i]);
+  }
+
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+  const float block_max =
+      reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);
+
+  if (threadIdx.x == 0) {
+    *global_amax = block_max;
+  }
+}
 
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
-    void amax_kernel(const InputType *input, float *amax, const size_t N,
-                     const size_t num_aligned_elements) {
+    void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N,
+                     const size_t num_aligned_elements, bool use_block_amax) {
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -39,9 +70,10 @@ __launch_bounds__(amax_kernel_threads) __global__
   for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
     loader.load(tid, N);
+    auto v = loader.separate();
 #pragma unroll
     for (int i = 0; i < nvec; ++i) {
-      const InputType val = static_cast<InputType>(loader.separate()[i]);
+      const InputType val = static_cast<InputType>(v[i]);
       __builtin_assume(max >= InputType{0.f});
       if constexpr (std::is_same_v) {
 #ifndef __HIP_PLATFORM_AMD__
@@ -65,12 +97,17 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
-    atomicMaxFloat(amax, max);
+    if (use_block_amax) {
+      block_amax[blockIdx.x] = max;
+    } else {
+      atomicMaxFloat(amax, max);
+    }
   }
 }
 
 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
+                        size_t block_capacity, cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
   (void)cudaMemsetAsync(amax, 0, sizeof(float), stream);
@@ -89,24 +126,42 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);
 
+  const bool UseBlockAmax =
+      (block_amax != nullptr) &&
+      (block_capacity >= num_blocks) &&
+      !nvte_use_atomic_amax();
+
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
       amax_kernel<nvec, true, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+          <<<num_blocks, threads, 0, stream>>>(
+              input, amax, block_amax, N, num_aligned_elements, UseBlockAmax);
       break;
     case Alignment::SAME_UNALIGNED:
       amax_kernel<nvec, false, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+          <<<num_blocks, threads, 0, stream>>>(
+              input, amax, block_amax, N, num_aligned_elements, UseBlockAmax);
      break;
    case Alignment::DIFFERENT: {
      // This case is a logic error, since there is only one pointer (input)
      // in the alignment check. Still safe to process without vectorization.
-      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
+      amax_kernel<1, true, InputType>
+          <<<num_blocks, threads, 0, stream>>>(
+              input, amax, block_amax, N, N, UseBlockAmax);
       break;
     }
   }
 
+  if (UseBlockAmax) {
+    constexpr int FINAL_REDUCE_THREADS = 256;
+    dim3 fr_block(FINAL_REDUCE_THREADS);
+    dim3 fr_grid(1);
+
+    amax_final_reduce<FINAL_REDUCE_THREADS>
+        <<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
+  }
+
   // Check results
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
@@ -115,6 +170,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
 
 }  // namespace transformer_engine
 
 void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
+  nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
+}
+
+void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) {
   NVTE_API_CALL(nvte_compute_amax);
   using namespace transformer_engine;
@@ -150,11 +209,27 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
                to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);
 
+  // Optional workspace
+  float* block_amax = nullptr;
+  size_t block_capacity = 0;
+
+  if (workspace_ != nullptr) {
+    auto &workspace = *reinterpret_cast<const Tensor *>(workspace_);
+    NVTE_CHECK(workspace.data.dptr != nullptr,
+               "Workspace tensor for amax computation has no data");
+    NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
+               "Workspace tensor for amax computation must be FP32, got dtype=",
+               to_string(workspace.data.dtype));
+    block_amax = reinterpret_cast<float *>(workspace.data.dptr);
+    block_capacity = workspace.data.numel();
+  }
+
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
       launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
-                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
+                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(), block_amax,
+                               block_capacity,
                                stream););  // NOLINT(*)
 }
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 1edbef8cd..bc21ff821 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -52,8 +54,25 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
   if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
     // my_quantizer here has to be a Float8CurrentScalingQuantizer
     auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
+
+    // workspace for nvte_compute_amax_with_workspace
+    const auto N = static_cast<size_t>(input_tensor.numel());
+    constexpr size_t threads = 512;  // FIXME: should match amax_kernel_threads
+    constexpr size_t max_blocks_hw = 65535;
+
+    // Worst-case (nvec = 1) upper bound on number of blocks.
+    size_t max_blocks = std::min(DIVUP(N, threads), max_blocks_hw);
+
+    // Allocate FP32 workspace for block-wise amax
+    auto ws = at::empty({static_cast<int64_t>(max_blocks)},
+                        tensor.options().dtype(at::kFloat));
+
+    TensorWrapper te_workspace = makeTransformerEngineTensor(ws);
+
     NVTE_SCOPED_GIL_RELEASE({
-      nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+      nvte_compute_amax_with_workspace(te_input.data(), te_output.data(),
+                                       te_workspace.data(),
+                                       at::cuda::getCurrentCUDAStream());
     });
     // check if we need to do amax reudction (depending on model parallel configs)
     if (my_quantizer_cs->with_amax_reduction) {
diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp
index eb4d60bd0..d0dab1f28 100644
--- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -20,12 +22,34 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
   TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
   TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
 
+  // Compute an upper bound on the number of blocks for this input.
+  const auto N = input_tensor.numel();
+  constexpr size_t threads = 512;  // FIXME: should grab amax_kernel_threads here
+  constexpr size_t max_blocks_hw = 65535;
+
+  // Assume worst-case vectorization (nvec = 1) as an upper bound.
+  size_t max_blocks = std::min(DIVUP(static_cast<size_t>(N), threads),
+                               max_blocks_hw);
+
+  // Allocate workspace for the block_amax buffer.
+  auto ws = at::empty({static_cast<int64_t>(max_blocks)},
+                      tensor.options().dtype(at::kFloat));
+
+  std::vector<size_t> ws_shape{static_cast<size_t>(max_blocks)};
+
   TensorWrapper fake_te_output(
       nullptr, te_input.shape(),
       DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
       amax.data_ptr<float>());
 
-  nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
+  TensorWrapper te_workspace(
+      ws.data_ptr(), ws_shape,
+      DType::kFloat32,
+      nullptr
+  );
+
+  nvte_compute_amax_with_workspace(te_input.data(), fake_te_output.data(), te_workspace.data(), at::cuda::getCurrentCUDAStream());
 }
 
 void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer,
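
Note (not part of the patch): the sketch below condenses the calling convention this patch introduces on the C++ side, mirroring the recipe.cpp hunk above: size an FP32 workspace to the worst-case block count (DIVUP(numel, 512), capped at the 65535-block grid limit), wrap it in a TensorWrapper, and pass it as the third argument of nvte_compute_amax_with_workspace. The helper name compute_amax_two_stage and the free-standing signature are hypothetical; the TensorWrapper/ATen usage is copied from the hunks above rather than from any published API contract.

// Illustrative sketch only; assumes the TensorWrapper C++ API used elsewhere in this patch.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>

// Hypothetical helper: compute amax of `input` into `te_output` via the two-stage path.
void compute_amax_two_stage(const at::Tensor& input,
                            transformer_engine::TensorWrapper& te_input,
                            transformer_engine::TensorWrapper& te_output) {
  // Worst-case block count assumed by this patch: nvec = 1, 512 threads per block,
  // capped at the 65535-block limit applied in launch_amax_kernel.
  const size_t N = static_cast<size_t>(input.numel());
  const size_t max_blocks = std::min((N + 511) / 512, static_cast<size_t>(65535));

  // FP32 workspace holding one partial amax per block.
  auto ws = at::empty({static_cast<int64_t>(max_blocks)},
                      input.options().dtype(at::kFloat));
  std::vector<size_t> ws_shape{max_blocks};
  transformer_engine::TensorWrapper te_workspace(ws.data_ptr(), ws_shape,
                                                 transformer_engine::DType::kFloat32, nullptr);

  // Passing a null workspace (or setting NVTE_USE_ATOMIC_AMAX=1) keeps the original atomic path.
  nvte_compute_amax_with_workspace(te_input.data(), te_output.data(), te_workspace.data(),
                                   at::cuda::getCurrentCUDAStream());
}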