Skip to content

Commit 7cb3fb1

Browse files
committed
- Refactored the rsqrtf16 implementation to temporarily compute sqrt with the hardware instruction
- Adjusted the results from the fixed-point call of the hardware instruction to match the correctness required in libc
- TODO: In the next PR, add an integer-based approximation of the function for the scenario where floats are not available in hardware
1 parent a5b246b commit 7cb3fb1

File tree

1 file changed

+16
-71
lines changed

1 file changed

+16
-71
lines changed

libc/src/__support/math/rsqrtf16.h

Lines changed: 16 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
#include "src/__support/FPUtil/FEnvImpl.h"
1717
#include "src/__support/FPUtil/FPBits.h"
1818
#include "src/__support/FPUtil/ManipulationFunctions.h"
19-
#include "src/__support/FPUtil/PolyEval.h"
2019
#include "src/__support/FPUtil/cast.h"
2120
#include "src/__support/FPUtil/multiply_add.h"
21+
#include "src/__support/FPUtil/sqrt.h"
2222
#include "src/__support/macros/optimization.h"
2323

2424
namespace LIBC_NAMESPACE_DECL {
@@ -28,9 +28,9 @@ static constexpr float16 rsqrtf16(float16 x) {
2828
using FPBits = fputil::FPBits<float16>;
2929
FPBits xbits(x);
3030

31-
uint16_t x_u = xbits.uintval();
32-
uint16_t x_abs = x_u & 0x7fff;
33-
uint16_t x_sign = x_u >> 15;
31+
const uint16_t x_u = xbits.uintval();
32+
const uint16_t x_abs = x_u & 0x7fff;
33+
const uint16_t x_sign = x_u >> 15;
3434

3535
// x is NaN
3636
if (LIBC_UNLIKELY(xbits.is_nan())) {
@@ -57,75 +57,20 @@ static constexpr float16 rsqrtf16(float16 x) {
5757

5858
// x = +inf => rsqrt(x) = 0
5959
if (LIBC_UNLIKELY(xbits.is_inf())) {
60-
return fputil::cast<float16>(0.0f);
60+
return FPBits::zero().get_val();
6161
}
6262

63-
// x is valid, estimate the result
64-
// Range reduction:
65-
// x can be expressed as m*2^e, where e - int exponent and m - mantissa
66-
// rsqrtf16(x) = rsqrtf16(m*2^e)
67-
// rsqrtf16(m*2^e) = 1/sqrt(m) * 1/sqrt(2^e) = 1/sqrt(m) * 1/2^(e/2)
68-
// 1/sqrt(m) * 1/2^(e/2) = 1/sqrt(m) * 2^(-e/2)
69-
70-
// Compute in float throughout to minimize cost while preserving accuracy.
71-
float xf = x;
72-
int exponent = 0;
73-
float mantissa = fputil::frexp(xf, exponent);
74-
75-
float result = 0.0f;
76-
int exp_floored = -(exponent >> 1);
77-
78-
if (mantissa == 0.5f) {
79-
// When mantissa is 0.5f, x was a power of 2 (or subnormal that normalizes
80-
// this way). 1/sqrt(0.5f) = sqrt(2.0f).
81-
// If exponent is odd (exponent = 2k + 1):
82-
// rsqrt(x) = (1/sqrt(0.5)) * 2^(-(2k+1)/2) = sqrt(2) * 2^(-k-0.5)
83-
// = sqrt(2) * 2^(-k) * (1/sqrt(2)) = 2^(-k)
84-
// exp_floored = -((2k+1)>>1) = -(k) = -k
85-
// So result = ldexp(1.0f, exp_floored)
86-
// If exponent is even (exponent = 2k):
87-
// rsqrt(x) = (1/sqrt(0.5)) * 2^(-2k/2) = sqrt(2) * 2^(-k)
88-
// exp_floored = -((2k)>>1) = -(k) = -k
89-
// So result = ldexp(sqrt(2.0f), exp_floored)
90-
if (exponent & 1) {
91-
result = fputil::ldexp(1.0f, exp_floored);
92-
} else {
93-
constexpr float SQRT_2_F = 0x1.6a09e6p0f; // sqrt(2.0f)
94-
result = fputil::ldexp(SQRT_2_F, exp_floored);
95-
}
96-
} else {
97-
// Degree-5 polynomial (float coefficients) generated with Sollya:
98-
// P = fpminimax(1/sqrt(x) + 2^-28, 5, [|single...|], [0.5,1])
99-
float y =
100-
fputil::polyeval(mantissa, 0x1.9c81fap1f, -0x1.e2c63ap2f, 0x1.91e9b8p3f,
101-
-0x1.899abep3f, 0x1.9eddeap2f, -0x1.6bdb48p0f);
102-
103-
// Newton-Raphson iteration in float (use multiply_add to leverage FMA when
104-
// available):
105-
float y2 = y * y;
106-
float factor = fputil::multiply_add(-0.5f * mantissa, y2, 1.5f);
107-
y = y * factor;
108-
109-
result = fputil::ldexp(y, exp_floored);
110-
if (exponent & 1) {
111-
constexpr float ONE_OVER_SQRT2 = 0x1.6a09e6p-1f; // 1/sqrt(2)
112-
result *= ONE_OVER_SQRT2;
113-
}
114-
115-
// Targeted post-correction: for the specific half-precision mantissa
116-
// pattern M == 0x011F we observe a consistent -1 ULP bias across exponents.
117-
// Apply a tiny upward nudge to cross the rounding boundary in all modes.
118-
const uint16_t half_mantissa = static_cast<uint16_t>(x_abs & 0x3ff);
119-
if (half_mantissa == 0x011F) {
120-
// Nudge up to fix consistent -1 ULP at that mantissa boundary
121-
result = fputil::multiply_add(result, 0x1.0p-21f,
122-
result); // result *= (1 + 2^-21)
123-
} else if (half_mantissa == 0x0313) {
124-
// Nudge down to fix +1 ULP under upward rounding at this mantissa
125-
// boundary
126-
result = fputil::multiply_add(result, -0x1.0p-21f,
127-
result); // result *= (1 - 2^-21)
128-
}
63+
// TODO: add integer based implementation when LIBC_TARGET_CPU_HAS_FPU_FLOAT
64+
// is not defined
65+
float result = 1.0f / fputil::sqrt<float>(fputil::cast<float>(x));
66+
67+
// Targeted post-corrections to ensure correct rounding in half for specific
68+
// mantissa patterns
69+
const uint16_t half_mantissa = x_abs & 0x3ff;
70+
if (half_mantissa == 0x011F) {
71+
result = fputil::multiply_add(result, 0x1.0p-21f, result);
72+
} else if (half_mantissa == 0x0313) {
73+
result = fputil::multiply_add(result, -0x1.0p-21f, result);
12974
}
13075

13176
return fputil::cast<float16>(result);

0 commit comments

Comments
 (0)