
Commit 9e6586f

don't add extra template to kernel
1 parent: 6388604

File tree: 1 file changed (+12, -22 lines)

transformer_engine/common/recipe/current_scaling.cu

Lines changed: 12 additions & 22 deletions
@@ -29,7 +29,6 @@ using bf16__ = __hip_bfloat16;
 
 constexpr int amax_kernel_threads = 512;
 
-// FIXME: Should this be covered by __HIP_PLATFORM_AMD__ ?
 inline bool nvte_use_atomic_amax() {
   static int cached = -1;
   if (cached == -1) {
@@ -60,11 +59,10 @@ __global__ void amax_final_reduce(const float* __restrict__ block_amax,
     *global_amax = block_max;
   }
 }
-
-template <int nvec, bool aligned, typename InputType, bool UseBlockAmax>
+template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
     void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N,
-                     const size_t num_aligned_elements) {
+                     const size_t num_aligned_elements, bool use_block_amax) {
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -99,7 +97,7 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
-    if constexpr (UseBlockAmax) {
+    if (use_block_amax) {
       block_amax[blockIdx.x] = max;
     } else {
       atomicMaxFloat(amax, max);
@@ -136,29 +134,21 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
-      // FIXME: this code is clumsy. Perhaps don't use the UseBlockAmax extra template argument
-      if (UseBlockAmax)
-        amax_kernel<nvec, true, InputType, true>
-            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
-      else
-        amax_kernel<nvec, true, InputType, false>
-            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+      amax_kernel<nvec, true, InputType>
+          <<<num_blocks, threads, 0, stream>>>(
+              input, amax, block_amax, N, num_aligned_elements, UseBlockAmax);
       break;
     case Alignment::SAME_UNALIGNED:
-      if (UseBlockAmax)
-        amax_kernel<nvec, false, InputType, true>
-            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
-      else
-        amax_kernel<nvec, false, InputType, false>
-            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+      amax_kernel<nvec, false, InputType>
+          <<<num_blocks, threads, 0, stream>>>(
+              input, amax, block_amax, N, num_aligned_elements, UseBlockAmax);
       break;
     case Alignment::DIFFERENT: {
       // This case is a logic error, since there is only one pointer (input)
      // in the alignment check. Still safe to process without vectorization.
-      if (UseBlockAmax)
-        amax_kernel<1, true, InputType, true><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
-      else
-        amax_kernel<1, true, InputType, false><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
+      amax_kernel<1, true, InputType>
+          <<<num_blocks, threads, 0, stream>>>(
+              input, amax, block_amax, N, N, UseBlockAmax);
       break;
     }
   }
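
For reference, a minimal, self-contained sketch of the pattern this commit lands on: the choice between writing per-block partial amaxes and updating a single global amax atomically travels as a runtime bool rather than an extra template parameter, so the host side needs only one instantiation per <nvec, aligned, InputType> combination and no if/else around each launch. The names (amax_kernel_sketch, atomic_max_float, launch_amax_sketch) and the 256-thread block size below are illustrative assumptions, not the Transformer Engine implementation.

// Sketch only: one kernel instantiation covers both reduction strategies;
// the strategy is selected by a runtime flag at the final per-block write.
#include <cuda_runtime.h>
#include <cstddef>

// Atomic max for non-negative floats (amax is non-negative), via the int bit pattern.
__device__ inline void atomic_max_float(float *addr, float val) {
  atomicMax(reinterpret_cast<int *>(addr), __float_as_int(val));
}

__global__ void amax_kernel_sketch(const float *input, size_t n, float *global_amax,
                                   float *block_amax, bool use_block_amax) {
  __shared__ float smem[256];  // assumes blockDim.x == 256
  float m = 0.f;
  for (size_t i = size_t(blockIdx.x) * blockDim.x + threadIdx.x; i < n;
       i += size_t(gridDim.x) * blockDim.x) {
    m = fmaxf(m, fabsf(input[i]));  // thread-local running max of |x|
  }
  smem[threadIdx.x] = m;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // block-level tree reduction
    if (threadIdx.x < s) smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    if (use_block_amax) {
      block_amax[blockIdx.x] = smem[0];        // per-block partials, reduced by a second kernel
    } else {
      atomic_max_float(global_amax, smem[0]);  // single pass via global atomics
    }
  }
}

// Host side: one launch path, no branch on the flag and no extra template argument.
void launch_amax_sketch(const float *d_in, size_t n, float *d_amax, float *d_block_amax,
                        bool use_block_amax, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = 128;
  amax_kernel_sketch<<<blocks, threads, 0, stream>>>(d_in, n, d_amax, d_block_amax,
                                                     use_block_amax);
}

The trade-off is one uniform, block-wide branch at the final write instead of letting the compiler drop the unused path; for an amax reduction that cost is negligible, and in the diff above the launch code shrinks from six kernel call sites to three.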
