1616#include " src/__support/FPUtil/FEnvImpl.h"
1717#include " src/__support/FPUtil/FPBits.h"
1818#include " src/__support/FPUtil/ManipulationFunctions.h"
19- #include " src/__support/FPUtil/PolyEval.h"
2019#include " src/__support/FPUtil/cast.h"
2120#include " src/__support/FPUtil/multiply_add.h"
21+ #include " src/__support/FPUtil/sqrt.h"
2222#include " src/__support/macros/optimization.h"
2323
2424namespace LIBC_NAMESPACE_DECL {
@@ -28,9 +28,9 @@ static constexpr float16 rsqrtf16(float16 x) {
2828 using FPBits = fputil::FPBits<float16>;
2929 FPBits xbits (x);
3030
31- uint16_t x_u = xbits.uintval ();
32- uint16_t x_abs = x_u & 0x7fff ;
33- uint16_t x_sign = x_u >> 15 ;
31+ const uint16_t x_u = xbits.uintval ();
32+ const uint16_t x_abs = x_u & 0x7fff ;
33+ const uint16_t x_sign = x_u >> 15 ;
3434
3535 // x is NaN
3636 if (LIBC_UNLIKELY (xbits.is_nan ())) {
@@ -57,75 +57,20 @@ static constexpr float16 rsqrtf16(float16 x) {
5757
5858 // x = +inf => rsqrt(x) = 0
5959 if (LIBC_UNLIKELY (xbits.is_inf ())) {
60- return fputil::cast<float16>( 0 . 0f );
60+ return FPBits::zero (). get_val ( );
6161 }
6262
63- // x is valid, estimate the result
64- // Range reduction:
65- // x can be expressed as m*2^e, where e - int exponent and m - mantissa
66- // rsqrtf16(x) = rsqrtf16(m*2^e)
67- // rsqrtf16(m*2^e) = 1/sqrt(m) * 1/sqrt(2^e) = 1/sqrt(m) * 1/2^(e/2)
68- // 1/sqrt(m) * 1/2^(e/2) = 1/sqrt(m) * 2^(-e/2)
69-
70- // Compute in float throughout to minimize cost while preserving accuracy.
71- float xf = x;
72- int exponent = 0 ;
73- float mantissa = fputil::frexp (xf, exponent);
74-
75- float result = 0 .0f ;
76- int exp_floored = -(exponent >> 1 );
77-
78- if (mantissa == 0 .5f ) {
79- // When mantissa is 0.5f, x was a power of 2 (or subnormal that normalizes
80- // this way). 1/sqrt(0.5f) = sqrt(2.0f).
81- // If exponent is odd (exponent = 2k + 1):
82- // rsqrt(x) = (1/sqrt(0.5)) * 2^(-(2k+1)/2) = sqrt(2) * 2^(-k-0.5)
83- // = sqrt(2) * 2^(-k) * (1/sqrt(2)) = 2^(-k)
84- // exp_floored = -((2k+1)>>1) = -(k) = -k
85- // So result = ldexp(1.0f, exp_floored)
86- // If exponent is even (exponent = 2k):
87- // rsqrt(x) = (1/sqrt(0.5)) * 2^(-2k/2) = sqrt(2) * 2^(-k)
88- // exp_floored = -((2k)>>1) = -(k) = -k
89- // So result = ldexp(sqrt(2.0f), exp_floored)
90- if (exponent & 1 ) {
91- result = fputil::ldexp (1 .0f , exp_floored);
92- } else {
93- constexpr float SQRT_2_F = 0x1 .6a09e6p0f; // sqrt(2.0f)
94- result = fputil::ldexp (SQRT_2_F, exp_floored);
95- }
96- } else {
97- // Degree-5 polynomial (float coefficients) generated with Sollya:
98- // P = fpminimax(1/sqrt(x) + 2^-28, 5, [|single...|], [0.5,1])
99- float y =
100- fputil::polyeval (mantissa, 0x1 .9c81fap1f, -0x1 .e2c63ap2f , 0x1 .91e9b8p3f,
101- -0x1 .899abep3f, 0x1 .9eddeap2f, -0x1 .6bdb48p0f);
102-
103- // Newton-Raphson iteration in float (use multiply_add to leverage FMA when
104- // available):
105- float y2 = y * y;
106- float factor = fputil::multiply_add (-0 .5f * mantissa, y2, 1 .5f );
107- y = y * factor;
108-
109- result = fputil::ldexp (y, exp_floored);
110- if (exponent & 1 ) {
111- constexpr float ONE_OVER_SQRT2 = 0x1 .6a09e6p-1f ; // 1/sqrt(2)
112- result *= ONE_OVER_SQRT2;
113- }
114-
115- // Targeted post-correction: for the specific half-precision mantissa
116- // pattern M == 0x011F we observe a consistent -1 ULP bias across exponents.
117- // Apply a tiny upward nudge to cross the rounding boundary in all modes.
118- const uint16_t half_mantissa = static_cast <uint16_t >(x_abs & 0x3ff );
119- if (half_mantissa == 0x011F ) {
120- // Nudge up to fix consistent -1 ULP at that mantissa boundary
121- result = fputil::multiply_add (result, 0x1 .0p-21f ,
122- result); // result *= (1 + 2^-21)
123- } else if (half_mantissa == 0x0313 ) {
124- // Nudge down to fix +1 ULP under upward rounding at this mantissa
125- // boundary
126- result = fputil::multiply_add (result, -0x1 .0p-21f ,
127- result); // result *= (1 - 2^-21)
128- }
63+ // TODO: add integer based implementation when LIBC_TARGET_CPU_HAS_FPU_FLOAT
64+ // is not defined
65+ float result = 1 .0f / fputil::sqrt<float >(fputil::cast<float >(x));
66+
67+ // Targeted post-corrections to ensure correct rounding in half for specific
68+ // mantissa patterns
69+ const uint16_t half_mantissa = x_abs & 0x3ff ;
70+ if (half_mantissa == 0x011F ) {
71+ result = fputil::multiply_add (result, 0x1 .0p-21f , result);
72+ } else if (half_mantissa == 0x0313 ) {
73+ result = fputil::multiply_add (result, -0x1 .0p-21f , result);
12974 }
13075
13176 return fputil::cast<float16>(result);
0 commit comments