
Commit d9b4003

Current scaling: two-stage HIP amax kernel (#369)

1 parent 0018b1b

10 files changed (+259 -1 lines)
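In a nutshell: a single global atomicMaxFloat per block serializes poorly on ROCm, so on AMD builds the amax reduction becomes two stages. Stage one has every block write its partial maximum into its own workspace slot; stage two folds those partials with one small kernel launch. Below is a minimal standalone sketch of the pattern, not the committed code: the names and launch shapes are illustrative, and a plain shared-memory tree reduction stands in for the library's reduce_max helper.

    #include <hip/hip_runtime.h>
    #include <cmath>

    // Stage 1: each block reduces a grid-stride slice of the input to one
    // partial amax in its own workspace slot. Launch with 256 threads/block.
    __global__ void block_amax_kernel(const float* in, float* block_amax, size_t n) {
      __shared__ float smem[256];
      float v = 0.f;
      for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
           i < n; i += static_cast<size_t>(gridDim.x) * blockDim.x) {
        v = fmaxf(v, fabsf(in[i]));
      }
      smem[threadIdx.x] = v;
      __syncthreads();
      for (unsigned s = blockDim.x / 2; s > 0; s >>= 1) {  // shared-memory tree max
        if (threadIdx.x < s) smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
        __syncthreads();
      }
      if (threadIdx.x == 0) block_amax[blockIdx.x] = smem[0];  // no global atomics
    }

    // Stage 2: a single block folds the per-block partials into the final amax.
    __global__ void final_amax_kernel(const float* block_amax, float* amax, int num_blocks) {
      __shared__ float smem[256];
      float v = 0.f;
      for (int i = threadIdx.x; i < num_blocks; i += blockDim.x) v = fmaxf(v, block_amax[i]);
      smem[threadIdx.x] = v;
      __syncthreads();
      for (unsigned s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
        __syncthreads();
      }
      if (threadIdx.x == 0) *amax = smem[0];
    }

    // Usage: block_amax_kernel<<<num_blocks, 256, 0, stream>>>(in, ws, n);
    //        final_amax_kernel<<<1, 256, 0, stream>>>(ws, amax, num_blocks);

The committed kernels below follow the same shape but keep the atomic update as a runtime fallback, selected by passing a null workspace pointer.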

tests/cpp/operator/test_cast_current_scaling.cu

Lines changed: 39 additions & 0 deletions
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -195,6 +197,43 @@ TEST_P(CastCSTestSuite, TestCastCS) {
   );
 }
 
+#ifdef __HIP_PLATFORM_AMD__
+
+TEST(AmaxConsistencyTest, AtomicVsWorkspace) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  std::vector<size_t> shape{256, 1024};
+  const size_t N = product(shape);
+
+  // Input: FP32, Output: FP8 (E4M3) with per-tensor scaling
+  Tensor input("input", shape, DType::kFloat32);
+  Tensor out_atomic("out_atomic", shape, DType::kFloat8E4M3, true, false);
+  Tensor out_ws("out_ws", shape, DType::kFloat8E4M3, true, false);
+
+  fillUniform(&input);
+
+  // Path 1: atomic-based amax (no workspace)
+  nvte_compute_amax(input.data(), out_atomic.data(), 0);
+
+  // Path 2: two-stage amax using workspace
+  std::vector<size_t> ws_shape{N};
+  Tensor workspace("workspace", ws_shape, DType::kFloat32);
+  nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0);
+
+  cudaDeviceSynchronize();
+  auto err = cudaGetLastError();
+  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+  // Compare the resulting amax values
+  float amax_atomic = out_atomic.amax();
+  float amax_ws = out_ws.amax();
+
+  compareResults("amax_consistency", amax_atomic, amax_ws, /*atol=*/0.0f, /*rtol=*/0.0f);
+}
+
+#endif
+
 
 
 INSTANTIATE_TEST_SUITE_P(

transformer_engine/common/include/transformer_engine/recipe.h

Lines changed: 26 additions & 0 deletions
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -73,6 +75,12 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
     std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
     float margin, cudaStream_t stream);
 
+#ifdef __HIP_PLATFORM_AMD__
+
+constexpr int amax_kernel_threads = 512;
+
+#endif
+
 /*! \brief Compute an FP8 tensor's amax.
  *
  * The amax (maximum absolute value) of the input tensor is computed
@@ -84,6 +92,24 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
  */
 void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
+#ifdef __HIP_PLATFORM_AMD__
+
+size_t nvte_amax_workspace_num_blocks(size_t N);
+
+/*! \brief Compute an FP8 tensor's amax.
+ *
+ * The amax (maximum absolute value) of the input tensor is computed
+ * and written to the amax buffer of the output tensor.
+ *
+ * \param[in] input Input tensor. Must be unquantized.
+ * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling.
+ * \param[out] workspace Workspace tensor. Must be FP32.
+ * \param[in] stream CUDA stream used for the operation.
+ */
+void nvte_compute_amax_with_workspace(const NVTETensor input, NVTETensor output, NVTETensor workspace, cudaStream_t stream);
+
+#endif
+
 /*! \brief Update an FP8 tensor's scale based on its amax.
  *
  * This is only supported for FP8 tensors with per-tensor scaling.
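Taken together, the ROCm additions are used as follows. This is a sketch only: it reuses the Tensor and product test helpers from the test diff above rather than raw NVTETensor construction, and runs on the default stream for brevity.

    // Size the workspace: one float per block of the first-stage kernel.
    std::vector<size_t> shape{4096, 1024};
    Tensor input("input", shape, DType::kFloat32);
    Tensor output("output", shape, DType::kFloat8E4M3, true, false);

    std::vector<size_t> ws_shape{nvte_amax_workspace_num_blocks(product(shape))};
    Tensor workspace("workspace", ws_shape, DType::kFloat32);

    // Two-stage path; a null or undersized workspace silently falls back to
    // the single-kernel atomic path.
    nvte_compute_amax_with_workspace(input.data(), output.data(), workspace.data(), 0);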

transformer_engine/common/recipe/current_scaling.cu

Lines changed: 112 additions & 1 deletion
@@ -28,10 +28,38 @@ using bf16__ = __hip_bfloat16;
 
 constexpr int amax_kernel_threads = 512;
 
+#ifdef __HIP_PLATFORM_AMD__
+
+template <int BLOCK_THREADS>
+__global__ void amax_final_reduce(const float* __restrict__ block_amax,
+                                  float* __restrict__ global_amax,
+                                  int num_blocks) {
+  float val = 0.f;
+
+  for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
+    val = fmaxf(val, block_amax[i]);
+  }
+
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+  const float block_max =
+      reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);
+
+  if (threadIdx.x == 0) {
+    *global_amax = block_max;
+  }
+}
+
+#endif
+
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
+#ifdef __HIP_PLATFORM_AMD__
+    void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N,
+                     const size_t num_aligned_elements) {
+#else
     void amax_kernel(const InputType *input, float *amax, const size_t N,
                      const size_t num_aligned_elements) {
+#endif
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -65,12 +93,23 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
+#ifdef __HIP_PLATFORM_AMD__
+    if (block_amax != nullptr) {
+      // 2-stage: write per-block result
+      block_amax[blockIdx.x] = max;
+    } else {
+      // Atomic path: directly update global amax
+      atomicMaxFloat(amax, max);
+    }
+#else
     atomicMaxFloat(amax, max);
+#endif
   }
 }
 
 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
+                        size_t block_capacity, cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
   (void)cudaMemsetAsync(amax, 0, sizeof(float), stream);
 
@@ -83,38 +122,90 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
   auto align = CheckAlignment(N, nvec, input);
   size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));
 
+#ifndef __HIP_PLATFORM_AMD__
   // Figure out CUDA blocks
   constexpr size_t threads = amax_kernel_threads;
   size_t num_blocks = DIVUP(num_aligned_elements, threads);
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);
 
+#else
+  constexpr size_t threads = amax_kernel_threads;
+  size_t num_blocks = nvte_amax_workspace_num_blocks(num_aligned_elements);
+  if (block_capacity < num_blocks)
+    block_amax = nullptr;
+#endif
+
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
+#ifdef __HIP_PLATFORM_AMD__
+      amax_kernel<nvec, true, InputType>
+          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+#else
       amax_kernel<nvec, true, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+#endif
      break;
    case Alignment::SAME_UNALIGNED:
+#ifdef __HIP_PLATFORM_AMD__
+      amax_kernel<nvec, false, InputType>
+          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+#else
      amax_kernel<nvec, false, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+#endif
      break;
    case Alignment::DIFFERENT: {
      // This case is a logic error, since there is only one pointer (input)
      // in the alignment check. Still safe to process without vectorization.
+#ifdef __HIP_PLATFORM_AMD__
+      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
+#else
      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
+#endif
      break;
    }
  }
 
+#ifdef __HIP_PLATFORM_AMD__
+  if (block_amax != nullptr) {
+    constexpr int FINAL_REDUCE_THREADS = 256;
+    dim3 fr_block(FINAL_REDUCE_THREADS);
+    dim3 fr_grid(1);
+
+    amax_final_reduce<FINAL_REDUCE_THREADS>
+        <<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
+  }
+#endif
+
   // Check results
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
 }  // namespace
 }  // namespace transformer_engine
 
+
+#ifdef __HIP_PLATFORM_AMD__
+
+size_t nvte_amax_workspace_num_blocks(size_t N) {
+  constexpr size_t max_blocks_hw = 65535;
+
+  size_t max_blocks = transformer_engine::DIVUP(N, static_cast<size_t>(amax_kernel_threads));
+  size_t workspace_blocks = std::min(max_blocks, max_blocks_hw);
+  return workspace_blocks;
+}
+
+#endif
+
 void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
+}
+
+void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) {
+#endif
  NVTE_API_CALL(nvte_compute_amax);
  using namespace transformer_engine;
 
@@ -150,11 +241,31 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
              to_string(output.amax.dtype), ")");
  CheckOutputTensor(output, "output_compute_amax", true);
 
+#ifdef __HIP_PLATFORM_AMD__
+  // Optional workspace
+  float* block_amax = nullptr;
+  size_t block_capacity = 0;
+
+  if (workspace_ != nullptr) {
+    auto &workspace = *reinterpret_cast<Tensor *>(workspace_);
+    if (workspace.data.dptr != nullptr) {
+      NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
+                 "Workspace tensor for amax computation must be FP32, got dtype=",
+                 to_string(workspace.data.dtype));
+      block_amax = reinterpret_cast<float*>(workspace.data.dptr);
+      block_capacity = workspace.data.numel();
+    }
+  }
+#endif
+
  // Compute amax
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
      launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
+#ifdef __HIP_PLATFORM_AMD__
+                               block_amax, block_capacity,
+#endif
                               stream););  // NOLINT(*)
 }
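To make the ROCm grid sizing concrete, here is a worked example for the fp32 path, assuming the 32-byte vectorized loads above (so nvec = 8) and the 256x1024 test shape:

    // Sizing logic of launch_amax_kernel, evaluated by hand for fp32 input.
    constexpr size_t N = 256 * 1024;                   // 262,144 elements
    constexpr size_t nvec = 32 / sizeof(float);        // 8 elements per vector load
    constexpr size_t aligned = N / nvec;               // 32,768 vectorized elements
    constexpr size_t threads = 512;                    // amax_kernel_threads
    constexpr size_t blocks = (aligned + threads - 1) / threads;  // DIVUP -> 64
    static_assert(blocks == 64, "64 workspace floats (256 bytes) suffice here");

Note that allocate_amax_workspace below sizes the buffer from the raw element count rather than the vectorized one (512 slots for this shape), so the buffer always covers whichever alignment case the launcher selects, including the nvec = 1 fallback.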

transformer_engine/pytorch/csrc/common.cpp

Lines changed: 31 additions & 0 deletions
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -10,6 +12,10 @@
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"
 
+#ifdef __HIP_PLATFORM_AMD__
+#include "common/common.h"
+#endif
+
 namespace transformer_engine::pytorch {
 
 std::vector<size_t> getTensorShape(at::Tensor t) {
@@ -277,4 +283,29 @@ int roundup(const int value, const int multiple) {
   return ((value + multiple - 1) / multiple) * multiple;
 }
 
+#ifdef __HIP_PLATFORM_AMD__
+
+inline bool nvte_use_atomic_amax() {
+  const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX");
+  if (env_p && std::string(env_p) == "1")
+    return true;
+  return false;
+}
+
+TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor) {
+  if (nvte_use_atomic_amax() || input_tensor.numel() == 0) {
+    // User chose atomic path, or empty tensor -> no need for workspace
+    return TensorWrapper{};
+  }
+
+  const auto N = input_tensor.numel();
+  size_t workspace_blocks = nvte_amax_workspace_num_blocks(N);
+
+  at::Tensor ws = at::empty(workspace_blocks, at::CUDA(at::kFloat));
+
+  return makeTransformerEngineTensor(ws);
+}
+
+#endif
+
 }  // namespace transformer_engine::pytorch
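The workspace path is the default on ROCm; setting NVTE_USE_ATOMIC_AMAX=1 in the environment restores the single-atomic behavior. Since nvte_use_atomic_amax() re-reads the variable on every allocation, it can be toggled at runtime, e.g. for an A/B timing comparison. The snippet below is illustrative only (setenv/unsetenv are POSIX; the helper name is hypothetical):

    #include <cstdlib>

    void use_atomic_amax(bool enable) {
      // Any value other than "1" (or unset) keeps the two-stage workspace path.
      if (enable) {
        setenv("NVTE_USE_ATOMIC_AMAX", "1", /*overwrite=*/1);
      } else {
        unsetenv("NVTE_USE_ATOMIC_AMAX");
      }
    }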

transformer_engine/pytorch/csrc/common.h

Lines changed: 3 additions & 0 deletions
@@ -374,6 +374,9 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
 
 int roundup(const int value, const int multiple);
 
+#ifdef __HIP_PLATFORM_AMD__
+TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor);
+#endif
 }  // namespace transformer_engine::pytorch
 
 namespace std {

transformer_engine/pytorch/csrc/extensions/activation.cpp

Lines changed: 10 additions & 0 deletions
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -36,10 +38,18 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
   auto [te_output_act, out_act] =
       my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
 
+#ifdef __HIP_PLATFORM_AMD__
+  auto workspace = allocate_amax_workspace(te_input);
+#endif
   NVTE_SCOPED_GIL_RELEASE({
     act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
     // use te_output_act as input to the compute amax and find the amax of activated tensor
+#ifdef __HIP_PLATFORM_AMD__
+    nvte_compute_amax_with_workspace(te_output_act.data(), te_output.data(),
+                                     workspace.data(), at::cuda::getCurrentCUDAStream());
+#else
     nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+#endif
   });
 
   // my_quantizer here has to be a Float8CurrentScalingQuantizer

transformer_engine/pytorch/csrc/extensions/bias.cpp

Lines changed: 8 additions & 0 deletions
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -49,7 +51,13 @@ std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_quantizer) {
   // my_quantizer here has to be a Float8CurrentScalingQuantizer
   auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(quantizer.get());
   NVTE_SCOPED_GIL_RELEASE({
+#ifdef __HIP_PLATFORM_AMD__
+    nvte_compute_amax_with_workspace(input_tensor.data(), out_tensor.data(),
+                                     allocate_amax_workspace(input_tensor).data(),
+                                     at::cuda::getCurrentCUDAStream());
+#else
     nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
+#endif
   });
   // check if we need to do amax reduction (depending on model parallel configs)
   if (my_quantizer_cs->with_amax_reduction) {
