@@ -29,7 +29,6 @@ using bf16 = __hip_bfloat16;
 
 constexpr int amax_kernel_threads = 512;
 
-// FIXME: Should this be covered by __HIP_PLATFORM_AMD__ ?
 inline bool nvte_use_atomic_amax() {
   static int cached = -1;
   if (cached == -1) {
@@ -60,11 +59,10 @@ __global__ void amax_final_reduce(const float* __restrict__ block_amax,
     *global_amax = block_max;
   }
 }
-
-template <int nvec, bool aligned, typename InputType, bool UseBlockAmax>
+template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
     void amax_kernel(const InputType *input, float *amax, float * __restrict__ block_amax, const size_t N,
-                     const size_t num_aligned_elements) {
+                     const size_t num_aligned_elements, bool use_block_amax) {
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -99,7 +97,7 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
-    if constexpr (UseBlockAmax) {
+    if (use_block_amax) {
       block_amax[blockIdx.x] = max;
     } else {
       atomicMaxFloat(amax, max);
@@ -136,29 +134,21 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
-      // FIXME: this code is clumsy. Perhaps don't use the UseBlockAmax extra template argument
-      if (UseBlockAmax)
-        amax_kernel<nvec, true, InputType, true>
-            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
-      else
-        amax_kernel<nvec, true, InputType, false>
-            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+      amax_kernel<nvec, true, InputType>
+          <<<num_blocks, threads, 0, stream>>>(
+              input, amax, block_amax, N, num_aligned_elements, UseBlockAmax);
       break;
     case Alignment::SAME_UNALIGNED:
-      if (UseBlockAmax)
-        amax_kernel<nvec, false, InputType, true>
-            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
-      else
-        amax_kernel<nvec, false, InputType, false>
-            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+      amax_kernel<nvec, false, InputType>
+          <<<num_blocks, threads, 0, stream>>>(
+              input, amax, block_amax, N, num_aligned_elements, UseBlockAmax);
      break;
    case Alignment::DIFFERENT: {
      // This case is a logic error, since there is only one pointer (input)
      // in the alignment check. Still safe to process without vectorization.
-      if (UseBlockAmax)
-        amax_kernel<1, true, InputType, true><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
-      else
-        amax_kernel<1, true, InputType, false><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
+      amax_kernel<1, true, InputType>
+          <<<num_blocks, threads, 0, stream>>>(
+              input, amax, block_amax, N, N, UseBlockAmax);
       break;
    }
  }
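
For context, a minimal host-side sketch of how the two reduction flavors might be driven after this change. This is not part of the patch: the amax_final_reduce launch configuration and argument order below are assumptions. When use_block_amax is set, amax_kernel writes one partial maximum per block into block_amax and a follow-up reduction collapses those partials into the global amax; otherwise every block folds its result into *amax directly with atomicMaxFloat.

  // Sketch only: amax_final_reduce's grid size and signature here are assumed, not taken from the diff.
  if (use_block_amax) {
    // Pass 1: each block records its partial maximum in block_amax[blockIdx.x].
    amax_kernel<nvec, true, InputType>
        <<<num_blocks, threads, 0, stream>>>(
            input, amax, block_amax, N, num_aligned_elements, /*use_block_amax=*/true);
    // Pass 2: a single block reduces the per-block partials into *amax, avoiding global atomics.
    amax_final_reduce<<<1, amax_kernel_threads, 0, stream>>>(block_amax, amax, num_blocks);
  } else {
    // Single pass: blocks update *amax concurrently via atomicMaxFloat.
    amax_kernel<nvec, true, InputType>
        <<<num_blocks, threads, 0, stream>>>(
            input, amax, block_amax, N, num_aligned_elements, /*use_block_amax=*/false);
  }

Turning the compile-time UseBlockAmax template parameter into a runtime bool halves the number of amax_kernel instantiations; the branch it replaces is taken once per block by thread 0 only, so the runtime cost is negligible.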