1818#include " src/__support/common.h"
1919#include " src/__support/uint128.h"
2020
21+ #include " hdr/fenv_macros.h"
22+
2123namespace LIBC_NAMESPACE {
2224namespace fputil {
2325
@@ -64,40 +66,50 @@ LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
6466
6567// Correctly rounded IEEE 754 SQRT for all rounding modes.
6668// Shift-and-add algorithm.
// NOTE(review): this listing interleaves removed (-) and added (+) diff lines,
// and part of the body is elided at the hunk boundary further down. The notes
// added here describe the added (+) version only.
//
// The added version takes an InType argument and returns an OutType result.
// The overload is enabled only when both types are floating point and
// sizeof(OutType) <= sizeof(InType).
67- template <typename T>
68- LIBC_INLINE cpp::enable_if_t <cpp::is_floating_point_v<T>, T> sqrt (T x) {
69-
70- if constexpr (internal::SpecialLongDouble<T>::VALUE) {
69+ template <typename OutType, typename InType>
70+ LIBC_INLINE cpp::enable_if_t <cpp::is_floating_point_v<OutType> &&
71+ cpp::is_floating_point_v<InType> &&
72+ sizeof (OutType) <= sizeof (InType),
73+ OutType>
74+ sqrt (InType x) {
// The x86::sqrt path is selected only when SpecialLongDouble reports VALUE for
// both OutType and InType; every other type pair uses the generic branch.
75+ if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
76+ internal::SpecialLongDouble<InType>::VALUE) {
7177 // Special 80-bit long double.
7278 return x86::sqrt (x);
7379 } else {
7480 // IEEE floating points formats.
75- using FPBits_t = typename fputil::FPBits<T>;
76- using StorageType = typename FPBits_t::StorageType;
77- constexpr StorageType ONE = StorageType (1 ) << FPBits_t::FRACTION_LEN;
78- constexpr auto FLT_NAN = FPBits_t::quiet_nan ().get_val ();
79-
80- FPBits_t bits (x);
81-
82- if (bits == FPBits_t::inf (Sign::POS) || bits.is_zero () || bits.is_nan ()) {
81+ using OutFPBits = typename fputil::FPBits<OutType>;
82+ using OutStorageType = typename OutFPBits::StorageType;
83+ using InFPBits = typename fputil::FPBits<InType>;
84+ using InStorageType = typename InFPBits::StorageType;
85+ constexpr InStorageType ONE = InStorageType (1 ) << InFPBits::FRACTION_LEN;
86+ constexpr auto FLT_NAN = OutFPBits::quiet_nan ().get_val ();
// EXTRA_FRACTION_LEN is the number of fraction bits InType has beyond OutType.
// EXTRA_FRACTION_MASK selects exactly those low bits of an InStorageType value
// (bits 0 through EXTRA_FRACTION_LEN - 1).
87+ constexpr int EXTRA_FRACTION_LEN =
88+ InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
89+ constexpr InStorageType EXTRA_FRACTION_MASK =
90+ (InStorageType (1 ) << EXTRA_FRACTION_LEN) - 1 ;
91+
92+ InFPBits bits (x);
93+
94+ if (bits == InFPBits::inf (Sign::POS) || bits.is_zero () || bits.is_nan ()) {
8395 // sqrt(+Inf) = +Inf
8496 // sqrt(+0) = +0
8597 // sqrt(-0) = -0
8698 // sqrt(NaN) = NaN
8799 // sqrt(-NaN) = -NaN
// The added version passes the input through, converted to OutType with
// static_cast.
88- return x ;
100+ return static_cast <OutType>(x) ;
89101 } else if (bits.is_neg ()) {
90102 // sqrt(-Inf) = NaN
91103 // sqrt(-x) = NaN
// FLT_NAN is the quiet NaN of OutType defined above.
92104 return FLT_NAN;
93105 } else {
94106 int x_exp = bits.get_exponent ();
95- StorageType x_mant = bits.get_mantissa ();
107+ InStorageType x_mant = bits.get_mantissa ();
96108
97109 // Step 1a: Normalize denormal input and append hidden bit to the mantissa
98110 if (bits.is_subnormal ()) {
99111 ++x_exp; // let x_exp be the correct exponent of ONE bit.
100- internal::normalize<T >(x_exp, x_mant);
112+ internal::normalize<InType >(x_exp, x_mant);
101113 } else {
102114 x_mant |= ONE;
103115 }
@@ -120,47 +132,105 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
120132 // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
121133 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
122134 // 0 otherwise.
// y starts at ONE (the hidden bit) and r at x_mant - ONE. Each loop iteration
// doubles r and conditionally sets the bit current_bit in y.
123- StorageType y = ONE;
124- StorageType r = x_mant - ONE;
135+ InStorageType y = ONE;
136+ InStorageType r = x_mant - ONE;
125137
126- for (StorageType current_bit = ONE >> 1 ; current_bit; current_bit >>= 1 ) {
138+ for (InStorageType current_bit = ONE >> 1 ; current_bit;
139+ current_bit >>= 1 ) {
127140 r <<= 1 ;
128- StorageType tmp = (y << 1 ) + current_bit; // 2*y(n - 1) + 2^(-n-1)
141+ InStorageType tmp = (y << 1 ) + current_bit; // 2*y(n - 1) + 2^(-n-1)
129142 if (r >= tmp) {
130143 r -= tmp;
131144 y += current_bit;
132145 }
133146 }
134147
135148 // We compute one more iteration in order to round correctly.
136- bool lsb = static_cast <bool >(y & 1 ); // Least significant bit
137- bool rb = false ; // Round bit
// In the added version lsb is read from bit EXTRA_FRACTION_LEN of y. That bit
// becomes bit 0 of y_out once y is shifted right by EXTRA_FRACTION_LEN below.
149+ bool lsb = (y & (InStorageType (1 ) << EXTRA_FRACTION_LEN)) !=
150+ 0 ; // Least significant bit
151+ bool rb = false ; // Round bit
138152 r <<= 2 ;
139- StorageType tmp = (y << 2 ) + 1 ;
153+ InStorageType tmp = (y << 2 ) + 1 ;
140154 if (r >= tmp) {
141155 r -= tmp;
142156 rb = true ;
143157 }
144158
159+ bool sticky = false ;
160+
// When OutType has fewer fraction bits than InType, the round bit from the
// extra iteration is folded into sticky together with the low
// EXTRA_FRACTION_LEN bits of y, and rb is re-read from bit
// EXTRA_FRACTION_LEN - 1 of y.
// NOTE(review): EXTRA_FRACTION_MASK also covers bit EXTRA_FRACTION_LEN - 1,
// so sticky is set whenever the new rb is set - confirm this is intended.
161+ if constexpr (EXTRA_FRACTION_LEN > 0 ) {
162+ sticky = rb || (y & EXTRA_FRACTION_MASK) != 0 ;
163+ rb = (y & (InStorageType (1 ) << (EXTRA_FRACTION_LEN - 1 ))) != 0 ;
164+ }
165+
145166 // Remove hidden bit and append the exponent field.
146- x_exp = ((x_exp >> 1 ) + FPBits_t::EXP_BIAS);
167+ x_exp = ((x_exp >> 1 ) + OutFPBits::EXP_BIAS);
168+
// y_out packs the truncated fraction (hidden bit removed, extra low bits
// shifted out) with x_exp placed above OutFPBits::FRACTION_LEN.
169+ OutStorageType y_out = static_cast <OutStorageType>(
170+ ((y - ONE) >> EXTRA_FRACTION_LEN) |
171+ (static_cast <OutStorageType>(x_exp) << OutFPBits::FRACTION_LEN));
172+
173+ if constexpr (EXTRA_FRACTION_LEN > 0 ) {
// x_exp already includes OutFPBits::EXP_BIAS here. A result whose exponent
// reaches MAX_BIASED_EXPONENT is returned as infinity for FE_TONEAREST and
// FE_UPWARD, and as the largest normal value for every other rounding mode.
174+ if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) {
175+ switch (quick_get_round ()) {
176+ case FE_TONEAREST:
177+ case FE_UPWARD:
178+ return OutFPBits::inf ().get_val ();
179+ default :
180+ return OutFPBits::max_normal ().get_val ();
181+ }
182+ }
183+
// Below this bound the result is returned as the smallest subnormal for
// FE_UPWARD and as zero for every other rounding mode.
// NOTE(review): x_exp is already biased at this point and the bound rises with
// EXTRA_FRACTION_LEN - confirm the threshold for every InType/OutType pair.
184+ if (x_exp <
185+ -OutFPBits::EXP_BIAS - OutFPBits::SIG_LEN + EXTRA_FRACTION_LEN) {
186+ switch (quick_get_round ()) {
187+ case FE_UPWARD:
188+ return OutFPBits::min_subnormal ().get_val ();
189+ default :
190+ return OutType (0.0 );
191+ }
192+ }
147193
148- y = (y - ONE) |
149- (static_cast <StorageType>(x_exp) << FPBits_t::FRACTION_LEN);
// Subnormal result: the number of bits dropped from y grows as x_exp falls
// below 1. ONE is not subtracted from y here, so the hidden bit is shifted
// into the stored mantissa. rb, lsb and sticky are recomputed for this cut.
194+ if (x_exp <= 0 ) {
195+ int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1 ;
196+ InStorageType underflow_extra_fraction_mask =
197+ (InStorageType (1 ) << underflow_extra_fraction_len) - 1 ;
198+
199+ rb = (y & (InStorageType (1 ) << (underflow_extra_fraction_len - 1 ))) !=
200+ 0 ;
201+ OutStorageType subnormal_mant =
202+ static_cast <OutStorageType>(y >> underflow_extra_fraction_len);
203+ lsb = (subnormal_mant & 1 ) != 0 ;
204+ sticky = sticky || (y & underflow_extra_fraction_mask) != 0 ;
205+
// Only FE_TONEAREST and FE_UPWARD increment the mantissa. The switch has no
// default, so every other rounding mode returns the truncated value.
206+ switch (quick_get_round ()) {
207+ case FE_TONEAREST:
208+ if (rb && (lsb || sticky))
209+ ++subnormal_mant;
210+ break ;
211+ case FE_UPWARD:
212+ if (rb || sticky)
213+ ++subnormal_mant;
214+ break ;
215+ }
216+
217+ return cpp::bit_cast<OutType>(subnormal_mant);
218+ }
219+ }
150220
// Final rounding of the normal-range result. As above, rounding modes other
// than FE_TONEAREST and FE_UPWARD keep the truncated value.
// NOTE(review): the FE_TONEAREST branch tests r != 0 but not sticky, while the
// added FE_UPWARD branch also tests sticky - confirm that a narrowed result
// with rb set, lsb clear, r == 0 and lower extra bits set rounds as intended.
151221 switch (quick_get_round ()) {
152222 case FE_TONEAREST:
153223 // Round to nearest, ties to even
154224 if (rb && (lsb || (r != 0 )))
155- ++y ;
225+ ++y_out ;
156226 break ;
157227 case FE_UPWARD:
158- if (rb || (r != 0 ))
159- ++y ;
228+ if (rb || (r != 0 ) || sticky )
229+ ++y_out ;
160230 break ;
161231 }
162232
163- return cpp::bit_cast<T>(y );
233+ return cpp::bit_cast<OutType>(y_out );
164234 }
165235 }
166236}
0 commit comments