11// ===-- Implementation of sqrtf128 function -------------------------------===//
22//
3- // Copyright (c) 2024 Alexei Sibidanov <[email protected] >
4- //
53// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
64// See https://llvm.org/LICENSE.txt for license information.
75// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
1715#include " src/__support/macros/optimization.h"
1816#include " src/__support/uint128.h"
1917
18+ // Compute sqrtf128 with correct rounding for all rounding modes using integer
19+ // arithmetic by Alexei Sibidanov ([email protected] ):
20+ // Let the input be expressed as x = 2^e * m_x,
21+ // - Step 1: Range reduction
22+ // Let x_reduced = 2^(e % 2) * m_x,
23+ // Then sqrt(x) = 2^(e / 2) * sqrt(x_reduced), with
24+ // 1 <= x_reduced < 4.
25+ // - Step 2: Polynomial approximation
26+ // Approximate 1/sqrt(x_reduced) using polynomial approximation with the
27+ // result errors bounded by:
28+ // |r0 - 1/sqrt(x_reduced)| < 2^-32.
29+ // The computations are done in uint64_t.
30+ // - Step 3: First Newton iteration
31+ // Let the scaled error defined by:
32+ // h0 = r0^2 * x_reduced - 1.
33+ // Then we compute the first Newton iteration:
34+ // r1 = r0 - r0 * h0 / 2.
35+ // The result is then bounded by:
36+ // |r1 - 1 / sqrt(x_reduced)| < 2^-62.
37+ // - Step 4: Second Newton iteration
38+ // We calculate the scaled error from Step 3:
39+ // h1 = r1^2 * x_reduced - 1.
40+ // Then the second Newton iteration is computed by:
41+ // r2 = x_reduced * (r1 - r1 * h1 / 2)
42+ // ~ x_reduced * (1/sqrt(x_reduced)) = sqrt(x_reduced)
43+ // - Step 5: Perform rounding test and correction if needed.
44+ // Rounding correction is done by computing the exact rounding errors:
45+ // x_reduced - r2^2.
46+
2047namespace LIBC_NAMESPACE_DECL {
2148
2249using FPBits = fputil::FPBits<float128>;
@@ -35,11 +62,11 @@ inline constexpr uint64_t prod_hi<uint64_t>(uint64_t x, uint64_t y) {
3562
3663// Get high part of unsigned 128x64 bit multiplication.
3764template <>
38- inline constexpr UInt128 prod_hi<UInt128, uint64_t >(UInt128 y , uint64_t x ) {
39- uint64_t y_lo = static_cast <uint64_t >(y );
40- uint64_t y_hi = static_cast <uint64_t >(y >> 64 );
41- UInt128 xyl = static_cast <UInt128>(x ) * static_cast <UInt128>(y_lo );
42- UInt128 xyh = static_cast <UInt128>(x ) * static_cast <UInt128>(y_hi );
65+ inline constexpr UInt128 prod_hi<UInt128, uint64_t >(UInt128 x , uint64_t y ) {
66+ uint64_t x_lo = static_cast <uint64_t >(x );
67+ uint64_t x_hi = static_cast <uint64_t >(x >> 64 );
68+ UInt128 xyl = static_cast <UInt128>(x_lo ) * static_cast <UInt128>(y );
69+ UInt128 xyh = static_cast <UInt128>(x_hi ) * static_cast <UInt128>(y );
4370 return xyh + (xyl >> 64 );
4471}
4572
@@ -178,11 +205,11 @@ LIBC_INLINE uint64_t rsqrt_approx(uint64_t m) {
178205 // r1 = r0 - r0 * h / 2
179206 // which has error bounded by:
180207 // |r1 - 1/sqrt(x)| < h^2 / 2.
181- uint64_t r2 = prod_hi< uint64_t > (r, r);
208+ uint64_t r2 = prod_hi (r, r);
182209 // h = r0^2*x - 1.
183- int64_t h = static_cast <int64_t >(prod_hi< uint64_t > (m, r2) + r2);
210+ int64_t h = static_cast <int64_t >(prod_hi (m, r2) + r2);
184211 // hr = r * h / 2
185- int64_t hr = prod_hi< int64_t > (h, static_cast <int64_t >(r >> 1 ));
212+ int64_t hr = prod_hi (h, static_cast <int64_t >(r >> 1 ));
186213 r -= hr;
187214 // Adjust in the unlucky case x~1;
188215 if (LIBC_UNLIKELY (!r))
@@ -224,8 +251,10 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
224251 fputil::raise_except_if_required (FE_INVALID);
225252 return xbits.quiet_nan ().get_val ();
226253 }
227- // x is subnormal or x=+0
228- if (x == 0 )
254+ // Now x is subnormal or x = +0.
255+
256+ // x is +0.
257+ if (x_u == 0 )
229258 return x;
230259
231260 // Normalize subnormal inputs.
@@ -253,7 +282,7 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
253282 0xb504f333f9de6484 /* 2^64/sqrt(2) */ };
254283
255284 // Approximate 1/sqrt(1 + x_frac)
256- // Error: |r_1 - 1/sqrt(x)| < 2^-63 .
285+ // Error: |r_1 - 1/sqrt(x)| < 2^-62 .
257286 uint64_t r1 = rsqrt_approx (static_cast <uint64_t >(x_frac >> 64 ));
258287 // Adjust for the even/odd exponent.
259288 uint64_t r2 = prod_hi (r1, RSQRT_2[i]);
@@ -279,8 +308,9 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
279308 uint32_t nrst = rm == FE_TONEAREST;
280309 // The result lies within (-2,5) of true square root so we now
281310 // test that we can correctly round the result taking into account
282- // the rounding mode
283- // check the lowest 14 bits.
311+ // the rounding mode.
312+ // Check the lowest 14 bits (by clearing and sign-extending the top
313+ // 32 - 14 = 18 bits).
284314 int dd = (static_cast <int >(v) << 18 ) >> 18 ;
285315
286316 if (LIBC_UNLIKELY (dd < 4 && dd >= -8 )) { // can round correctly?
@@ -289,17 +319,16 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
289319 // compare with the initial argument.
290320 UInt128 m = v >> 15 ;
291321 UInt128 m2 = m * m;
292- Int128 t0, t1;
293322 // The difference of the squared result and the argument
294- t0 = static_cast <Int128>(m2 - (x_reduced << 98 ));
323+ Int128 t0 = static_cast <Int128>(m2 - (x_reduced << 98 ));
295324 if (t0 == 0 ) {
296325 // the square root is exact
297326 v = m << 15 ;
298327 } else {
299328 // Add +-1 ulp to m depending on the sign of the difference. Here
300329 // we do not need to square again since (m+1)^2 = m^2 + 2*m +
301330 // 1 so just need to add shifted m and 1.
302- t1 = t0;
331+ Int128 t1 = t0;
303332 Int128 sgn = t0 >> 127 ; // sign of the difference
304333 t1 -= (m << 1 ) ^ sgn;
305334 t1 += 1 + sgn;
@@ -332,20 +361,20 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
332361 rnd = frac >> 14 ; // round to nearest tie to even
333362 } else if (rm == FE_UPWARD) {
334363 rnd = !!frac; // round up
335- } else if (rm == FE_DOWNWARD) {
336- rnd = 0 ; // round down
337364 } else {
338- rnd = 0 ; // round to zero
365+ rnd = 0 ; // round down or round to zero
339366 }
340367
341368 v >>= 15 ; // position mantissa
342369 v += rnd; // round
343370
344- // // Set inexact flag only if square root is inexact
345- // // TODO: We will have to raise FE_INEXACT most of the time, but this
346- // // operation is very costly, especially in x86-64, since technically, it
347- // // needs to synchronize both SSE and x87 flags. Need to investigate
348- // // further to see how we can make this performant.
371+ // Set inexact flag only if square root is inexact
372+ // TODO: We will have to raise FE_INEXACT most of the time, but this
373+ // operation is very costly, especially in x86-64, since technically, it
374+ // needs to synchronize both SSE and x87 flags. Need to investigate
375+ // further to see how we can make this performant.
376+ // https://github.com/llvm/llvm-project/issues/126753
377+
349378 // if(frac) fputil::raise_except_if_required(FE_INEXACT);
350379
351380 v += static_cast <UInt128>(e2 ) << FPBits::FRACTION_LEN; // place exponent
0 commit comments