
Commit f9f9fcf

maajidkhann authored and amathewc committed
Enable SVE ACLE implementation for tanH Aten op for FP32 dType. (pytorch#143741)
In deep learning models, the tanh (hyperbolic tangent) function is a widely used activation function, appearing in feedforward networks, recurrent neural networks (RNNs), and many other architectures. It is also common in **Physics-Informed Neural Networks (PINNs)**, a class of machine learning models that solve partial differential equations (PDEs) by incorporating the governing physics directly into the loss function alongside data-driven terms. In PINNs, activations like tanh let the network learn complex mappings between inputs (such as spatial and temporal coordinates) and outputs (such as field variables); a short usage sketch follows this message.

**Operator:** tanh()

**Current implementation in the OSS ATen backend (SVE flow):** uses SVE Sleef when available, otherwise the std implementation.

**With this PR (SVE flow):** uses the SVE ACLE implementation, which is faster.

**Here are the performance improvements (single-core numbers):**

![image](https://github.com/user-attachments/assets/c2f4bcb6-11bc-4af1-b5eb-278a4cc4a69d)

**Metric:** average CPU time per iteration (in ms)

With both gcc and clang, the SVE ACLE implementation shows a significant performance gain over the current OSS implementation (Sleef) and also over NEON.

**Hardware:** m7g.8xlarge (Graviton 3 instance)

**Script used in benchmarking:**

```python
import os
#os.environ["ATEN_CPU_CAPABILITY"] = "default"
os.environ["ATEN_CPU_CAPABILITY"] = "sve256"

import torch
import torch.nn as nn

# Set the random seed for reproducibility
torch.manual_seed(1)

# Create a tensor of shape (8521, 50)
x = torch.randn(8521, 50)

# Warm up
for i in range(10):
    output = x.tanh()

# Perform the tanh operation 1000 times and profile the performance
print("### CPU tanh")
with torch.autograd.profiler.profile(record_shapes=True) as prof:
    for i in range(1000):
        output = x.tanh()

# Print the profiling results sorted by self CPU time
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

# Optionally print the final output (uncomment the following line if needed)
#print(output)
```

Pull Request resolved: pytorch#143741

Approved by: https://github.com/malfet
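For context, here is a minimal sketch of how tanh typically shows up in a PINN-style feedforward network. The module name and layer sizes are illustrative assumptions, not part of this PR; the point is that every hidden layer routes through the tanh kernel this commit accelerates.

```python
import torch
import torch.nn as nn

# Illustrative PINN-style MLP (sizes are arbitrary, not taken from this PR):
# tanh follows every hidden layer, so each forward pass exercises ATen's
# vectorized tanh kernel.
class TinyPINN(nn.Module):
    def __init__(self, in_dim=2, hidden=50, out_dim=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, xt):
        # xt holds (x, t) coordinates; the output is the predicted field value.
        return self.net(xt)

model = TinyPINN()
u = model(torch.randn(8, 2))
print(u.shape)  # torch.Size([8, 1])
```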
1 parent fa3b05d commit f9f9fcf

File tree

2 files changed: +90 −1 lines changed

aten/src/ATen/cpu/vec/sve/vec_float.h

Lines changed: 79 additions & 1 deletion
```diff
@@ -85,6 +85,58 @@ template <> class Vectorized<float> {
     }
     return b;
   }
+  // Implementation is picked from https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
+  inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
+    const auto c1 = svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
+    const auto c2 = svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
+    const auto c3 = svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
+    const auto c4 = svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
+    const auto c5 = svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
+    const auto shift = svreinterpret_f32_u32(svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
+    const auto inv_ln2 = svreinterpret_f32_u32(svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
+    const auto neg_ln2_hi =
+        svreinterpret_f32_u32(svdup_n_u32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
+    const auto neg_ln2_lo =
+        svreinterpret_f32_u32(svdup_n_u32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
+    const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
+    const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
+    const auto zero = svdup_n_f32(0.f);
+    const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
+    // Range reduction:
+    //   e^x = 2^n * e^r
+    // where:
+    //   n = floor(x / ln(2))
+    //   r = x - n * ln(2)
+    //
+    // By adding x / ln(2) with 2^23 + 127 (shift):
+    //   * As the FP32 fraction part only has 23 bits, the addition of 2^23 + 127 forces the decimal part
+    //     of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n) + 127 will occupy
+    //     the whole fraction part of z in FP32 format.
+    //     Subtracting 2^23 + 127 (shift) from z will result in the integer part of x / ln(2)
+    //     (i.e. n) because the decimal part has been pushed out and lost.
+    //   * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent
+    //     in FP32 format. Left shifting z by 23 bits will result in 2^n.
+    const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
+    const auto n = svsub_f32_z(pg, z, shift);
+    const auto scale = svreinterpret_f32_u32(svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
+    // The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32.
+    // This outperforms longer Taylor series (3-4 tabs) both in terms of accuracy and performance.
+    const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
+    const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
+    // Compute the truncated Taylor series of e^r.
+    //   poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
+    const auto r2 = svmul_f32_z(pg, r, r);
+    const auto p1 = svmul_f32_z(pg, c1, r);
+    const auto p23 = svmla_f32_z(pg, c2, c3, r);
+    const auto p45 = svmla_f32_z(pg, c4, c5, r);
+    const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
+    const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
+    auto poly = svmla_f32_z(pg, scale, p12345, scale);
+    // Handle underflow and overflow.
+    poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
+    poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
+    return poly;
+  }
   static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
     if (count == size())
       return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
```
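To make the range-reduction comments above concrete, here is a scalar NumPy re-derivation of the same exp algorithm. This is a hedged sketch: the helper names are mine, the polynomial is evaluated directly rather than with the kernel's fused even/odd scheme (equivalent up to rounding), and the underflow/overflow selects are omitted for brevity.

```python
import numpy as np

def bits_to_f32(u):
    # Reinterpret a 32-bit pattern as an IEEE-754 float32.
    return np.array([u & 0xFFFFFFFF], dtype=np.uint32).view(np.float32)[0]

def f32_to_bits(f):
    # Reinterpret a float32 as its 32-bit pattern.
    return int(np.array([f], dtype=np.float32).view(np.uint32)[0])

def exp_f32_reference(x):
    x = np.float32(x)
    shift = bits_to_f32(0x4B00007F)       # 2^23 + 127
    inv_ln2 = bits_to_f32(0x3FB8AA3B)     # 1 / ln(2)
    neg_ln2_hi = bits_to_f32(0xBF317200)  # high part of -ln(2)
    neg_ln2_lo = bits_to_f32(0xB5BFBE8E)  # low part of -ln(2)
    c = [bits_to_f32(w) for w in
         (0x3F7FFFF6, 0x3EFFFEDB, 0x3E2AAF33, 0x3D2B9F17, 0x3C072010)]

    # Adding shift pushes the fractional part of x / ln(2) out of the FP32
    # mantissa, so z - shift recovers n = round(x / ln(2)).
    z = np.float32(x * inv_ln2 + shift)
    n = np.float32(z - shift)
    # The low mantissa bits of z now hold 127 + n; shifting them into the
    # exponent field yields the float 2^n.
    scale = bits_to_f32(f32_to_bits(z) << 23)

    # r = x - n * ln(2), computed in two steps for accuracy beyond FP32.
    r = np.float32(x + n * neg_ln2_hi)
    r = np.float32(r + n * neg_ln2_lo)

    # Truncated Taylor series of e^r.
    poly = np.float32(1.0) + c[0]*r + c[1]*r**2 + c[2]*r**3 + c[3]*r**4 + c[4]*r**5
    return np.float32(scale * poly)

print(exp_f32_reference(1.0), np.exp(np.float32(1.0)))  # both ≈ 2.7182817
```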
```diff
@@ -333,8 +385,34 @@ template <> class Vectorized<float> {
   Vectorized<float> tan() const {
     return USE_SLEEF(Vectorized<float>(Sleef_tanfx_u10sve(values)),map(std::tan));
   }
+  // Implementation is picked from https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L179
   Vectorized<float> tanh() const {
-    return USE_SLEEF(Vectorized<float>(Sleef_tanhfx_u10sve(values)),map(std::tanh));
+    // Constants used for the tanh calculation.
+    const svfloat32_t CONST_1 = svdup_n_f32(1.f); // Constant 1.0f for the tanh formula.
+    const svfloat32_t CONST_2 = svdup_n_f32(2.f); // Constant 2.0f for the tanh formula (used in exp(2x)).
+    const svfloat32_t CONST_MIN_TANH = svdup_n_f32(-10.f); // Minimum threshold for input values to prevent overflow.
+    const svfloat32_t CONST_MAX_TANH = svdup_n_f32(10.f); // Maximum threshold for input values to prevent overflow.
+
+    // Step 1: Clamp the values within the range [-10, 10] to prevent overflow during exponentiation.
+    // The tanh function approaches ±1 rapidly as the input grows large, so we limit the input range to avoid numerical instability.
+    // svmax_f32_z ensures values are greater than -10, and svmin_f32_z ensures they are less than 10.
+    svfloat32_t x = svmin_f32_z(ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
+
+    // Step 2: Calculate exp(2 * x), where x is the clamped value.
+    // svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of the result.
+    svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));
+
+    // Step 3: Calculate the numerator of the tanh function, which is exp(2x) - 1.
+    svfloat32_t num = svsub_f32_z(ptrue, exp2x, CONST_1);
+
+    // Step 4: Calculate the denominator of the tanh function, which is exp(2x) + 1.
+    svfloat32_t den = svadd_f32_z(ptrue, exp2x, CONST_1);
+
+    // Step 5: Calculate the tanh function as the ratio num / den.
+    svfloat32_t tanh = svdiv_f32_z(ptrue, num, den);
+
+    // Return the calculated tanh values.
+    return tanh;
   }
   Vectorized<float> trunc() const {
     return svrintz_f32_x(ptrue, values);
```
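The tanh hunk above is a direct application of the identity tanh(x) = (e^(2x) − 1) / (e^(2x) + 1). Clamping to [−10, 10] is safe because tanh(±10) already rounds to ±1.0f in FP32. Below is a minimal NumPy mirror of that flow, with np.exp standing in for svexp_f32_z (an illustrative sketch, not the shipped code):

```python
import numpy as np

def tanh_f32_reference(x):
    # Step 1: clamp to [-10, 10] so exp(2x) cannot overflow.
    x = np.float32(np.clip(x, -10.0, 10.0))
    # Step 2: exp(2x); np.exp stands in for the vectorized svexp_f32_z.
    e2x = np.exp(np.float32(2.0) * x)
    # Steps 3-5: (exp(2x) - 1) / (exp(2x) + 1).
    return (e2x - np.float32(1.0)) / (e2x + np.float32(1.0))

for v in (-12.0, -0.5, 0.0, 0.5, 12.0):
    print(v, tanh_f32_reference(v), np.tanh(np.float32(v)))
```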

aten/src/ATen/test/vec_test_all_types.cpp

Lines changed: 11 additions & 0 deletions
```diff
@@ -371,11 +371,22 @@ namespace {
   }
   TYPED_TEST(Hyperbolic, Tanh) {
     using vec = TypeParam;
+    // NOTE: Because SVE uses ACL logic, the precision changes, hence the adjusted tolerance.
+#if defined(CPU_CAPABILITY_SVE)
+    using UVT = UvalueType<vec>;
+    UVT tolerance = getDefaultTolerance<UVT>();
+    test_unary<vec>(
+        NAME_INFO(tanH),
+        RESOLVE_OVERLOAD(std::tanh),
+        [](vec v) { return v.tanh(); },
+        createDefaultUnaryTestCase<vec>(TestSeed(), tolerance));
+#else
     test_unary<vec>(
         NAME_INFO(tanH),
         RESOLVE_OVERLOAD(std::tanh),
         [](vec v) { return v.tanh(); },
         createDefaultUnaryTestCase<vec>(TestSeed()));
+#endif
   }
   TYPED_TEST(Hyperbolic, Sinh) {
     using vec = TypeParam;
```
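Since the ACLE path trades a little precision for speed, the SVE branch of the test passes an explicit tolerance. Here is a rough Python analogue of the property being guarded, where the 1e-6 bound is an illustrative assumption rather than the tolerance the C++ test actually uses:

```python
import torch

# Compare the float32 tanh kernel against a float64 reference and bound the
# worst-case error. The 1e-6 bound is an illustrative assumption; it is not
# the tolerance used by vec_test_all_types.cpp.
torch.manual_seed(0)
x = torch.randn(4096)
err = (x.tanh().double() - x.double().tanh()).abs().max().item()
assert err < 1e-6, f"tanh max abs error too large: {err:.2e}"
print(f"max abs error: {err:.2e}")
```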
