@@ -28,9 +28,28 @@ using bf16__ = __hip_bfloat16;

 constexpr int amax_kernel_threads = 512;

+template <int BLOCK_THREADS>
+__global__ void amax_final_reduce(const float *__restrict__ block_amax,
+                                  float *__restrict__ global_amax,
+                                  int num_blocks) {
+  float val = 0.f;
+
+  for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
+    val = fmaxf(val, block_amax[i]);
+  }
+
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+  const float block_max =
+      reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);
+
+  if (threadIdx.x == 0) {
+    *global_amax = block_max;
+  }
+}
+
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
-    void amax_kernel(const InputType *input, float *amax, const size_t N,
+    void amax_kernel(const InputType *input, float *__restrict__ block_amax, const size_t N,
                      const size_t num_aligned_elements) {
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
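Both `amax_kernel` and the new `amax_final_reduce` lean on a `reduce_max<NUM_WARPS>(val, warp_id)` helper defined elsewhere in the tree. For readers of the diff, here is a minimal sketch of what such a block-level max reduction conventionally looks like; the shuffle-based implementation, the 32-lane `kThreadsPerWarp`, and the shared-memory staging are assumptions, not the repository's actual helper.

```cuda
#include <cuda_runtime.h>

constexpr int kThreadsPerWarp = 32;  // assumed; AMD wavefronts would use 64

// Sketch of a block-wide max reduction in the style of
// reduce_max<NUM_WARPS>(val, warp_id). The block max lands on thread 0.
template <int NUM_WARPS>
__device__ float reduce_max_sketch(float val, int warp_id) {
  __shared__ float warp_max[NUM_WARPS];

  // Intra-warp tree reduction in registers via shuffles.
  for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
    val = fmaxf(val, __shfl_down_sync(0xffffffffu, val, offset));
  }

  // Lane 0 of each warp publishes its partial max.
  if (threadIdx.x % kThreadsPerWarp == 0) warp_max[warp_id] = val;
  __syncthreads();

  // The first warp folds the per-warp partials; 0.f is a safe identity
  // because amax values are nonnegative.
  if (warp_id == 0) {
    val = (threadIdx.x < NUM_WARPS) ? warp_max[threadIdx.x] : 0.f;
    for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
      val = fmaxf(val, __shfl_down_sync(0xffffffffu, val, offset));
    }
  }
  return val;
}
```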
@@ -39,9 +58,10 @@ __launch_bounds__(amax_kernel_threads) __global__

   for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
     loader.load(tid, N);
+    auto v = loader.separate();
 #pragma unroll
     for (int i = 0; i < nvec; ++i) {
-      const InputType val = static_cast<InputType>(loader.separate()[i]);
+      const InputType val = static_cast<InputType>(v[i]);
       __builtin_assume(max >= InputType{0.f});
       if constexpr (std::is_same_v<InputType, bf16__>) {
 #ifndef __HIP_PLATFORM_AMD__
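The `auto v = loader.separate();` hoist above moves a loop-invariant accessor out of the `#pragma unroll` body so it is evaluated once per vectorized load rather than `nvec` times. Stripped of `VectorizedLoader` (whose internals are assumed here), the grid-stride amax loop reduces to a pattern like this, with `float4` standing in for the `nvec`-wide vector type:

```cuda
// Illustrative grid-stride amax loop; the real kernel obtains `v` from
// VectorizedLoader::separate() after each load.
__device__ float amax_loop_sketch(const float4 *input, size_t M) {
  float max = 0.f;
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M;
       tid += gridDim.x * blockDim.x) {
    const float4 v = input[tid];  // one vectorized load per iteration
    max = fmaxf(max, fabsf(v.x));
    max = fmaxf(max, fabsf(v.y));
    max = fmaxf(max, fabsf(v.z));
    max = fmaxf(max, fabsf(v.w));
  }
  return max;  // then reduced over the block and stored per block, as above
}
```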
@@ -65,7 +85,7 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
-    atomicMaxFloat(amax, max);
+    block_amax[blockIdx.x] = max;
   }
 }

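For context on the removed line: `atomicMaxFloat` is the standard trick of performing a float max through an integer `atomicMax`, which works because the IEEE-754 bit patterns of nonnegative floats order the same way as the values. A reconstruction from the name (the repository's actual definition may differ):

```cuda
// Reconstruction of the atomicMaxFloat idiom the old code path relied on.
// Valid for nonnegative values such as amax, where the raw bit patterns,
// compared as signed ints, order identically to the floats.
__device__ inline float atomicMaxFloat(float *addr, float value) {
  const int old = atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
  return __int_as_float(old);
}
```

Writing `block_amax[blockIdx.x]` instead trades that single contended global atomic for one plain store per block plus a small second kernel, and it removes the need to zero-initialize `amax` before the launch, since `amax_final_reduce` overwrites it unconditionally.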
@@ -89,24 +109,39 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);

+  float *block_amax = nullptr;
+  NVTE_CHECK_CUDA(cudaMalloc(&block_amax, num_blocks * sizeof(float)));
+
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
       amax_kernel<nvec, true, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+          <<<num_blocks, threads, 0, stream>>>(input, block_amax, N, num_aligned_elements);
       break;
     case Alignment::SAME_UNALIGNED:
       amax_kernel<nvec, false, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+          <<<num_blocks, threads, 0, stream>>>(input, block_amax, N, num_aligned_elements);
       break;
     case Alignment::DIFFERENT: {
       // This case is a logic error, since there is only one pointer (input)
       // in the alignment check. Still safe to process without vectorization.
-      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
+      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, block_amax, N, N);
       break;
     }
   }

+  {
+    constexpr int FINAL_REDUCE_THREADS = 256;
+    dim3 fr_block(FINAL_REDUCE_THREADS);
+    dim3 fr_grid(1);
+
+    amax_final_reduce<FINAL_REDUCE_THREADS>
+        <<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
+  }
+
+  // Free the scratch buffer in stream order so the free waits for both
+  // kernels without blocking the host.
+  NVTE_CHECK_CUDA(cudaFreeAsync(block_amax, stream));
+
   // Check results
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
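One caveat on the scratch buffer, as an observation rather than part of the patch: `cudaMalloc` is a heavyweight, typically synchronizing call, so paying for it on every `launch_amax_kernel` invocation can dominate small reductions. A fully stream-ordered variant (a sketch assuming CUDA 11.2+ and its memory-pool allocator) would look like:

```cuda
// Hypothetical stream-ordered management of the per-block scratch buffer.
// cudaMallocAsync/cudaFreeAsync keep the allocation, both kernels, and the
// free ordered on `stream` without a device-wide stall.
float *block_amax = nullptr;
NVTE_CHECK_CUDA(cudaMallocAsync(&block_amax, num_blocks * sizeof(float), stream));
// ... enqueue amax_kernel and amax_final_reduce on `stream`, as above ...
NVTE_CHECK_CUDA(cudaFreeAsync(block_amax, stream));
```

A persistent workspace sized for `max_blocks` floats and reused across calls would avoid per-call allocation entirely.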