Skip to content

Commit bce34da

Browse files
add possibility to force using previous (atomic) kernel
1 parent be0e0c8 commit bce34da

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

transformer_engine/common/recipe/current_scaling.cu

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <algorithm>
1212
#include <limits>
1313
#include <type_traits>
14+
#include <cstdlib>
1415

1516
#include "../common.h"
1617
#include "../util/logging.h"
@@ -28,6 +29,19 @@ using bf16__ = __hip_bfloat16;
2829

2930
constexpr int amax_kernel_threads = 512;
3031

32+
// FIXME: Should this be covered by __HIP_PLATFORM_AMD__ ?
33+
inline bool nvte_use_atomic_amax() {
34+
static int cached = -1;
35+
if (cached == -1) {
36+
cached = 0;
37+
const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX");
38+
if (env_p && std::string(env_p) == "1") {
39+
cached = 1;
40+
}
41+
}
42+
return cached == 1;
43+
}
44+
3145
template <int BLOCK_THREADS>
3246
__global__ void amax_final_reduce(const float* __restrict__ block_amax,
3347
float* __restrict__ global_amax,
@@ -114,7 +128,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
114128
constexpr size_t max_blocks = 65535;
115129
num_blocks = std::min(num_blocks, max_blocks);
116130

117-
const bool UseBlockAmax = (block_amax != nullptr) && (block_capacity >= num_blocks);
131+
const bool UseBlockAmax =
132+
(block_amax != nullptr) &&
133+
(block_capacity >= num_blocks) &&
134+
!nvte_use_atomic_amax();
118135

119136
// Launch kernel
120137
switch (align) {

0 commit comments

Comments
 (0)