Commit c15d93b (parent 7887013)

Current scaling: two-stage amax kernel


transformer_engine/common/recipe/current_scaling.cu

Lines changed: 38 additions & 6 deletions
@@ -28,9 +28,28 @@ using bf16__ = __hip_bfloat16;
 
 constexpr int amax_kernel_threads = 512;
 
+template <int BLOCK_THREADS>
+__global__ void amax_final_reduce(const float* __restrict__ block_amax,
+                                  float* __restrict__ global_amax,
+                                  int num_blocks) {
+  float val = 0.f;
+
+  for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
+    val = fmaxf(val, block_amax[i]);
+  }
+
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+  const float block_max =
+      reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);
+
+  if (threadIdx.x == 0) {
+    *global_amax = block_max;
+  }
+}
+
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
-    void amax_kernel(const InputType *input, float *amax, const size_t N,
+    void amax_kernel(const InputType *input, float* __restrict__ block_amax, const size_t N,
                      const size_t num_aligned_elements) {
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
@@ -39,9 +58,10 @@ __launch_bounds__(amax_kernel_threads) __global__
 
   for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
     loader.load(tid, N);
+    auto v = loader.separate();
 #pragma unroll
     for (int i = 0; i < nvec; ++i) {
-      const InputType val = static_cast<InputType>(loader.separate()[i]);
+      const InputType val = static_cast<InputType>(v[i]);
       __builtin_assume(max >= InputType{0.f});
       if constexpr (std::is_same_v<InputType, bf16__>) {
 #ifndef __HIP_PLATFORM_AMD__
@@ -65,7 +85,7 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
-    atomicMaxFloat(amax, max);
+    block_amax[blockIdx.x] = max;
   }
 }
 
@@ -89,24 +109,36 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);
 
+  float* block_amax = nullptr;
+  NVTE_CHECK_CUDA(cudaMalloc(&block_amax, num_blocks * sizeof(float)));
+
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
       amax_kernel<nvec, true, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+          <<<num_blocks, threads, 0, stream>>>(input, block_amax, N, num_aligned_elements);
       break;
     case Alignment::SAME_UNALIGNED:
       amax_kernel<nvec, false, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+          <<<num_blocks, threads, 0, stream>>>(input, block_amax, N, num_aligned_elements);
       break;
     case Alignment::DIFFERENT: {
       // This case is a logic error, since there is only one pointer (input)
       // in the alignment check. Still safe to process without vectorization.
-      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
+      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, block_amax, N, N);
       break;
     }
   }
 
+  {
+    constexpr int FINAL_REDUCE_THREADS = 256;
+    dim3 fr_block(FINAL_REDUCE_THREADS);
+    dim3 fr_grid(1);
+
+    amax_final_reduce<FINAL_REDUCE_THREADS>
+        <<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
+  }
+
   // Check results
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
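
For reference, here is a minimal standalone sketch of the two-stage amax pattern this commit introduces: a first kernel writes one partial amax per launched block into a scratch buffer, and a second single-block kernel folds those partials into the final value, replacing the previous atomicMaxFloat update. The repository helpers used in the diff (VectorizedLoader, reduce_max, THREADS_PER_WARP, NVTE_CHECK_CUDA) are not reproduced here; the sketch substitutes plain warp shuffles and standard CUDA runtime calls, and the names partial_amax, final_amax, and block_reduce_max are hypothetical.

two_stage_amax_sketch.cu (illustrative only):

// Standalone sketch of a two-stage amax reduction (not the TransformerEngine code).
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__device__ float block_reduce_max(float val) {
  // Warp-level max via shuffles, then one value per warp through shared memory.
  __shared__ float warp_max[32];
  const int lane = threadIdx.x % warpSize;
  const int warp = threadIdx.x / warpSize;
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    val = fmaxf(val, __shfl_down_sync(0xffffffffu, val, offset));
  }
  if (lane == 0) warp_max[warp] = val;
  __syncthreads();
  const int num_warps = (blockDim.x + warpSize - 1) / warpSize;
  val = (threadIdx.x < num_warps) ? warp_max[threadIdx.x] : 0.f;
  if (warp == 0) {
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
      val = fmaxf(val, __shfl_down_sync(0xffffffffu, val, offset));
    }
  }
  return val;  // Result is valid in thread 0 only.
}

// Stage 1: each block reduces a grid-stride slice and stores its partial amax.
__global__ void partial_amax(const float* __restrict__ x, float* __restrict__ block_amax,
                             size_t n) {
  float val = 0.f;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<size_t>(gridDim.x) * blockDim.x) {
    val = fmaxf(val, fabsf(x[i]));
  }
  val = block_reduce_max(val);
  if (threadIdx.x == 0) block_amax[blockIdx.x] = val;  // Plain store, no atomics.
}

// Stage 2: a single block folds all partial results into the final amax.
__global__ void final_amax(const float* __restrict__ block_amax, float* __restrict__ amax,
                           int num_blocks) {
  float val = 0.f;
  for (int i = threadIdx.x; i < num_blocks; i += blockDim.x) {
    val = fmaxf(val, block_amax[i]);
  }
  val = block_reduce_max(val);
  if (threadIdx.x == 0) *amax = val;
}

int main() {
  const size_t n = 1 << 20;
  std::vector<float> h(n);
  for (size_t i = 0; i < n; ++i) h[i] = (i == 12345) ? -7.5f : 0.25f;  // Expected amax: 7.5

  float *d_x, *d_partial, *d_amax;
  const int threads = 256, blocks = 128;
  cudaMalloc(&d_x, n * sizeof(float));
  cudaMalloc(&d_partial, blocks * sizeof(float));
  cudaMalloc(&d_amax, sizeof(float));
  cudaMemcpy(d_x, h.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  partial_amax<<<blocks, threads>>>(d_x, d_partial, n);   // one partial per block
  final_amax<<<1, threads>>>(d_partial, d_amax, blocks);  // single-block final fold

  float result = 0.f;
  cudaMemcpy(&result, d_amax, sizeof(float), cudaMemcpyDeviceToHost);
  printf("amax = %f\n", result);

  cudaFree(d_x);
  cudaFree(d_partial);
  cudaFree(d_amax);
  return 0;
}

As in the commit, the scratch buffer holds one float per block of the first kernel, and the second launch uses a single block, so no atomics are needed; correctness relies only on the ordering of the two launches on the same stream.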
