Skip to content

Commit d5cdca9

Browse files
rocm-micimhalkamd-hhashemi
authored andcommitted
[AUTOGENERATED] [release/2.5] [ROCm][layer_norm] Use __builtin_amdgcn_rcpf(x) instead of 1.f/x (#1800)
Cherry-pick of #1688 Co-authored-by: Michael Halkenhäuser <[email protected]> Co-authored-by: Hashem Hashemi <[email protected]> (cherry picked from commit f8544af) (cherry picked from commit ed48754) (cherry picked from commit d62a39e) (cherry picked from commit b26ddb8)
1 parent 57c7fa5 commit d5cdca9

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

aten/src/ATen/native/cuda/layer_norm_kernel.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,11 @@ WelfordDataLN cuWelfordOnlineSum(
141141
if constexpr (!rms_norm){
142142
U delta = val - curr_sum.mean;
143143
U new_count = curr_sum.count + 1.f;
144+
#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
145+
U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
146+
#else
144147
U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
148+
#endif
145149
return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
146150
} else{
147151
return {0.f, curr_sum.sigma2 + val * val, 0};
@@ -159,7 +163,11 @@ WelfordDataLN cuWelfordCombine(
159163
U count = dataA.count + dataB.count;
160164
U mean, sigma2;
161165
if (count > decltype(dataB.count){0}) {
166+
#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
167+
auto coef = __builtin_amdgcn_rcpf(count);
168+
#else
162169
auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
170+
#endif
163171
auto nA = dataA.count * coef;
164172
auto nB = dataB.count * coef;
165173
mean = nA*dataA.mean + nB*dataB.mean;

cmake/Dependencies.cmake

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,22 @@ if(USE_ROCM)
10371037
list(APPEND HIP_HIPCC_FLAGS -fdebug-info-for-profiling)
10381038
endif(CMAKE_BUILD_TYPE MATCHES Debug)
10391039

1040+
# Get EnVar 'PYTORCH_LAYERNORM_FAST_RECIPROCAL' (or default to on).
1041+
if(DEFINED ENV{PYTORCH_LAYERNORM_FAST_RECIPROCAL})
1042+
set(PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE $ENV{PYTORCH_LAYERNORM_FAST_RECIPROCAL})
1043+
else()
1044+
set(PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE ON)
1045+
endif()
1046+
1047+
set(PYTORCH_LAYERNORM_FAST_RECIPROCAL
1048+
${PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE}
1049+
CACHE BOOL "Enable fast reciprocals within layer normalization." FORCE
1050+
)
1051+
1052+
if(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
1053+
add_definitions(-DPYTORCH_LAYERNORM_FAST_RECIPROCAL)
1054+
endif()
1055+
10401056
# needed for compat with newer versions of hip-clang that introduced C++20 mangling rules
10411057
list(APPEND HIP_HIPCC_FLAGS -fclang-abi-compat=17)
10421058

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@
162162
# USE_ROCM_CK_SDPA=1
163163
# Enable building CK SDPA backend in ROCm platform
164164
#
165+
# PYTORCH_LAYERNORM_FAST_RECIPROCAL
166+
# If set, enables the use of builtin functions for fast reciprocals (1/x) w.r.t.
167+
# layer normalization. Default: enabled.
168+
#
165169
# Environment variables we respect (these environment variables are
166170
# conventional and are often understood/set by other software.)
167171
#

0 commit comments

Comments
 (0)