@@ -64,7 +64,7 @@ const std::array<float32x4_t, 8> log_tab =
6464 *
6565 * @return The calculated inverse square root.
6666 */
67- inline float32x4_t vinvsqrt_f32 (float32x4_t x)
67+ inline float32x4_t vinvsqrtq_f32 (float32x4_t x)
6868{
6969 float32x4_t sqrt_reciprocal = vrsqrteq_f32 (x);
7070 sqrt_reciprocal = vmulq_f32 (vrsqrtsq_f32 (vmulq_f32 (x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal);
@@ -79,7 +79,7 @@ inline float32x4_t vinvsqrt_f32(float32x4_t x)
7979 *
8080 * @return The calculated reciprocal.
8181 */
82- inline float32x4_t vinv_f32 (const float32x4_t &x)
82+ inline float32x4_t vinvq_f32 (const float32x4_t &x)
8383{
8484 float32x4_t recip = vrecpeq_f32 (x);
8585 recip = vmulq_f32 (vrecpsq_f32 (x, recip), recip);
@@ -94,7 +94,7 @@ inline float32x4_t vinv_f32(const float32x4_t &x)
9494 *
9595 * @return The calculated approximation.
9696 */
97- inline float32x4_t vtaylor_poly_f32 (const float32x4_t &x, const std::array<float32x4_t , 8 > &coeffs)
97+ inline float32x4_t vtaylor_polyq_f32 (const float32x4_t &x, const std::array<float32x4_t , 8 > &coeffs)
9898{
9999 float32x4_t A = vmlaq_f32 (coeffs[0 ], coeffs[4 ], x);
100100 float32x4_t B = vmlaq_f32 (coeffs[2 ], coeffs[6 ], x);
@@ -112,7 +112,7 @@ inline float32x4_t vtaylor_poly_f32(const float32x4_t &x, const std::array<float
112112 *
113113 * @return The calculated exponent.
114114 */
115- inline float32x4_t vexp_f32 (const float32x4_t &x)
115+ inline float32x4_t vexpq_f32 (const float32x4_t &x)
116116{
117117 static const float32x4_t CONST_LN2 = vdupq_n_f32 (0 .6931471805f ); // ln(2)
118118 static const float32x4_t CONST_INV_LN2 = vdupq_n_f32 (1 .4426950408f ); // 1/ln(2)
@@ -122,7 +122,7 @@ inline float32x4_t vexp_f32(const float32x4_t &x)
122122 float32x4_t val = vmlsq_f32 (x, vcvtq_f32_s32 (m), CONST_LN2);
123123
124124 // Polynomial Approximation
125- float32x4_t poly = vtaylor_poly_f32 (val, exp_tab);
125+ float32x4_t poly = vtaylor_polyq_f32 (val, exp_tab);
126126
127127 // Reconstruct
128128 poly = vreinterpretq_f32_s32 (vaddq_s32 (vreinterpretq_s32_f32 (poly), vshlq_n_s32 (m, 23 )));
@@ -136,7 +136,7 @@ inline float32x4_t vexp_f32(const float32x4_t &x)
136136 *
137137 * @return The calculated logarithm.
138138 */
139- inline float32x4_t vlog_f32 (const float32x4_t &x)
139+ inline float32x4_t vlogq_f32 (const float32x4_t &x)
140140{
141141 static const int32x4_t CONST_127 = vdupq_n_s32 (127 ); // 127
142142 static const float32x4_t CONST_LN2 = vdupq_n_f32 (0 .6931471805f ); // ln(2)
@@ -146,7 +146,7 @@ inline float32x4_t vlog_f32(const float32x4_t &x)
146146 float32x4_t val = vreinterpretq_f32_s32 (vsubq_s32 (vreinterpretq_s32_f32 (x), vshlq_n_s32 (m, 23 )));
147147
148148 // Polynomial Approximation
149- float32x4_t poly = vtaylor_poly_f32 (val, log_tab);
149+ float32x4_t poly = vtaylor_polyq_f32 (val, log_tab);
150150
151151 // Reconstruct
152152 poly = vmlaq_f32 (poly, vcvtq_f32_s32 (m), CONST_LN2);
@@ -158,19 +158,24 @@ inline float32x4_t vlog_f32(const float32x4_t &x)
158158 *
159159 * tanh(x) = (e^2x - 1)/(e^2x + 1)
160160 *
161+ * @note The input x is clamped to [-5,5] to avoid overflow issues.
162+ *
161163 * @param val Input vector value in F32 format.
162164 *
163165 * @return The calculated Hyperbolic Tangent.
164166 */
165- inline float32x4_t vtanh_f32 (const float32x4_t &val)
167+ inline float32x4_t vtanhq_f32 (const float32x4_t &val)
166168{
167- static const float32x4_t CONST_1 = vdupq_n_f32 (1 .f ); // 1.f
168- static const float32x4_t CONST_2 = vdupq_n_f32 (2 .f ); // 2.f
169+ static const float32x4_t CONST_1 = vdupq_n_f32 (1 .f ); // 1.f
170+ static const float32x4_t CONST_2 = vdupq_n_f32 (2 .f ); // 2.f
171+ static const float32x4_t CONST_MIN_TANH = vdupq_n_f32 (-5 .f ); // -5.f
172+ static const float32x4_t CONST_MAX_TANH = vdupq_n_f32 (5 .f ); // 5.f
169173
170- float32x4_t exp2x = vexp_f32 (vmulq_f32 (CONST_2, val));
174+ float32x4_t x = vminq_f32 (vmaxq_f32 (val, CONST_MIN_TANH), CONST_MAX_TANH);
175+ float32x4_t exp2x = vexpq_f32 (vmulq_f32 (CONST_2, x));
171176 float32x4_t num = vsubq_f32 (exp2x, CONST_1);
172177 float32x4_t den = vaddq_f32 (exp2x, CONST_1);
173- float32x4_t tanh = vmulq_f32 (num, vinv_f32 (den));
178+ float32x4_t tanh = vmulq_f32 (num, vinvq_f32 (den));
174179 return tanh;
175180}
176181
@@ -185,7 +190,7 @@ inline float32x4_t vtanh_f32(const float32x4_t &val)
185190 */
186191inline float32x4_t vpowq_f32 (const float32x4_t &val, const float32x4_t &n)
187192{
188- return vexp_f32 (vmulq_f32 (n, vlog_f32 (val)));
193+ return vexpq_f32 (vmulq_f32 (n, vlogq_f32 (val)));
189194}
190195}
191196
0 commit comments