1010#include " hdr/errno_macros.h"
1111#include " hdr/fenv_macros.h"
1212#include " src/__support/FPUtil/FEnvImpl.h"
13+ #include " src/__support/FPUtil/FMA.h"
1314#include " src/__support/FPUtil/FPBits.h"
1415#include " src/__support/FPUtil/ManipulationFunctions.h"
1516#include " src/__support/FPUtil/PolyEval.h"
1617#include " src/__support/FPUtil/cast.h"
17- #include " src/__support/FPUtil/multiply_add.h" // to remove
1818#include " src/__support/macros/optimization.h"
1919
2020namespace LIBC_NAMESPACE_DECL {
@@ -55,11 +55,6 @@ LLVM_LIBC_FUNCTION(float16, rsqrtf16, (float16 x)) {
5555 return fputil::cast<float16>(0 .0f );
5656 }
5757
58- // x = 1 => rsqrt(x) = 1
59- if (LIBC_UNLIKELY (x_u == 0x1 )) {
60- return fputil::cast<float16>(1 .0f );
61- }
62-
6358 // x is valid, estimate the result
6459 // Range reduction:
6560 // x can be expressed as m*2^e, where e - int exponent and m - mantissa
@@ -72,23 +67,41 @@ LLVM_LIBC_FUNCTION(float16, rsqrtf16, (float16 x)) {
7267 float mantissa = fputil::frexp (xf, exponent);
7368
7469 // Degree-5 polynomial (6 coefficients) generated using Sollya
75- // P = fpminimax(1/sqrt(x), [|0,1,2,3,4,5|], [|SG...|], [0.5, 1]);
70+ // A higher-degree polynomial does not give better results: the current one
71+ // produces the fewest errors, although some errors are still present.
72+ // P = fpminimax(1/(sqrt(x)), [|0,1,2,3,4,5|], [|SG...|], [0.5, 1]);
7673 float interm =
7774 fputil::polyeval (mantissa, 0x1 .9c81c4p1f, -0x1 .e2c57cp2f , 0x1 .91e8bp3f,
7875 -0x1 .899954p3f, 0x1 .9edcp2f, -0x1 .6bd93cp0f);
7976
77+ // Apply one Newton-Raphson iteration to refine the approximation of
78+ // 1/sqrt(mantissa): y_new = y_old * (1.5 - 0.5 * mantissa * y_old^2).
79+ // fputil::fma is used for potential precision benefits in the factor calculation.
80+ float interm_sq = interm * interm;
81+ float factor = fputil::fma<float >(-0 .5f * mantissa, interm_sq, 1 .5f );
82+ float interm_refined = interm * factor; // Final multiplication
83+
84+ // Apply a second Newton-Raphson iteration
85+ // y_new = y_old * (1.5 - 0.5 * mantissa * y_old^2)
86+ // y_old is now interm_refined
87+ float interm_refined_sq = interm_refined * interm_refined;
88+ float factor2 = fputil::fma<float >(-0 .5f * mantissa, interm_refined_sq, 1 .5f );
89+ float interm_refined2 = interm_refined * factor2;
90+
8091 // Round (-e/2)
8192 int exp_floored = -(exponent >> 1 );
8293
8394 // rsqrt(x) = 1/sqrt(mantissa) * 2^(-e/2)
8495 // rsqrt(x) = P(mantissa) * 2^(exp_floored)
85- float result = fputil::ldexp (interm, exp_floored);
96+ // float result = fputil::ldexp(interm, exp_floored);
97+ float result = fputil::ldexp (interm_refined2, exp_floored);
8698
8799 // Handle the case where exponent is odd
88100 if (exponent & 1 ) {
89- const float ONE_OVER_SQRT2 =
90- 0x1 .6a09e667f3bcc908b2fb1366ea957d3e3adec1751p-1f ;
91- result *= ONE_OVER_SQRT2;
101+ const float ONE_OVER_SQRT2 = 0x1 .6a09e6p-1f ;
102+ // result *= ONE_OVER_SQRT2;
103+ result = fputil::fma<float >(result, ONE_OVER_SQRT2,
104+ 0 .0f ); // Use FMA for multiplication
92105 }
93106
94107 return fputil::cast<float16>(result);
0 commit comments