2121#include " src/__support/macros/config.h"
2222#include " src/__support/macros/optimization.h"
2323#include " src/__support/macros/properties/types.h"
24- #include " src/__support/ math/expxf16_utils .h"
24+ #include " src/math/generic/common_constants .h"
2525
2626namespace LIBC_NAMESPACE_DECL {
2727
@@ -66,6 +66,7 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
6666 uint16_t x_u = xbits.uintval ();
6767 uint16_t x_a = x_abs.uintval ();
6868 uint16_t y_a = y_abs.uintval ();
69+ uint16_t y_u = ybits.uintval ();
6970 bool result_sign = false ;
7071
7172 // /////// BEGIN - Check exceptional cases ////////////////////////////////////
@@ -75,10 +76,12 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
7576 return FPBits::quiet_nan ().get_val ();
7677 }
7778
78- //
79- if (LIBC_UNLIKELY (ybits.is_zero () || x_u == FPBits::one ().uintval () ||
80- x_u >= FPBits::inf ().uintval () ||
81- x_u < FPBits::min_normal ().uintval ())) {
79+ if (LIBC_UNLIKELY (
80+ ybits.is_zero () || x_u == FPBits::one ().uintval () ||
81+ x_u == FPBits::one ().uintval () || x_u == FPBits::zero ().uintval () ||
82+ x_u >= FPBits::inf ().uintval () || y_u >= FPBits::inf ().uintval () ||
83+ x_u < FPBits::min_normal ().uintval () || y_a == 0x3800U ||
84+ y_a == 0x3c00U || y_a == 0x4000U || is_integer (y))) {
8285 // pow(x, 0) = 1
8386 if (ybits.is_zero ()) {
8487 return fputil::cast<float16>(1 .0f );
@@ -96,8 +99,13 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
9699 (x == 0.0 || x_u == FPBits::inf (Sign::NEG).uintval ()))) {
97100 // pow(-0, 1/2) = +0
98101 // pow(-inf, 1/2) = +inf
99- // Make sure it works correctly for FTZ/DAZ.
100- return fputil::cast<float16>(y_sign ? (1.0 / (x * x)) : (x * x));
102+ // For pow(x, 0.5), sqrt(x) is used. pow(0, -0.5) is handled below.
103+ break ;
104+ }
105+ // If x is not negative or special, use sqrt(x)
106+ if (x_sign && !xbits.is_zero ()) {
107+ // pow(negative, non-integer) = NaN, handled below.
108+ break ;
101109 }
102110 return fputil::cast<float16>(y_sign ? (1.0 / fputil::sqrt<float16>(x))
103111 : fputil::sqrt<float16>(x));
@@ -163,9 +171,13 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
163171 bool x_abs_less_than_one = x_a < FPBits::one ().uintval ();
164172 if ((x_abs_less_than_one && !y_sign) ||
165173 (!x_abs_less_than_one && y_sign)) {
174+ // |x| < 1 and y = +inf => 0.0
175+ // |x| > 1 and y = -inf => 0.0
166176 return fputil::cast<float16>(0 .0f );
167177 } else {
168- return FPBits::inf ().get_val ();
178+ // |x| > 1 and y = +inf => +inf
179+ // |x| < 1 and y = -inf => +inf
180+ return FPBits::inf (Sign::POS).get_val ();
169181 }
170182 }
171183
@@ -175,7 +187,31 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
175187 fputil::raise_except_if_required (FE_INVALID);
176188 return FPBits::quiet_nan ().get_val ();
177189 }
190+ if (is_integer (y)) {
191+ double base = x_abs.get_val ();
192+ double res = 1.0 ;
193+ int yi = static_cast <int >(y_abs.get_val ());
194+
195+ while (yi > 0 ) {
196+ if (yi % 2 == 1 )
197+ res *= base;
198+ base *= base;
199+ yi /= 2 ;
200+ }
201+
202+ if (y_sign) {
203+ res = 1.0 / res;
204+ }
205+
206+ float16 final_res = fputil::cast<float16>(res);
178207
208+ if (x_sign && is_odd_integer (y)) {
209+ FPBits res_bits (final_res);
210+ res_bits.set_sign (Sign::NEG);
211+ return res_bits.get_val ();
212+ }
213+ return final_res;
214+ }
179215 // For negative x with integer y, compute pow(|x|, y) and adjust sign
180216 if (x_sign) {
181217 x = -x;
@@ -184,72 +220,109 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
184220 }
185221 }
186222 }
223+
187224 // /////// END - Check exceptional cases //////////////////////////////////////
188225
189- // x^y = 2^( y * log2(x) )
190- // = 2^( y * ( e_x + log2(m_x) ) )
191- // First we compute log2(x) = e_x + log2(m_x)
226+ // Core computation: x^y = 2^( y * log2(x) )
227+ // We compute log2(x) = log(x) / log(2) using a polynomial approximation.
192228
193- using namespace math ::expxf16_internal ;
229+ // The exponent part (m) is added later to get the final log(x).
194230 FPBits x_bits (x);
195-
196231 uint16_t x_u_log = x_bits.uintval ();
197232
198233 // Extract exponent field of x.
199- int m = -FPBits::EXP_BIAS;
200-
201- // When x is subnormal, normalize it by multiplying by 2^FRACTION_LEN.
202- if ((x_u_log & FPBits::EXP_MASK) == 0U ) { // Subnormal x
203- constexpr double NORMALIZE_EXP = 1.0 * (1U << FPBits::FRACTION_LEN);
204- x_bits = FPBits (fputil::cast<float16>(
205- fputil::cast<double >(x_bits.get_val ()) * NORMALIZE_EXP));
206- x_u_log = x_bits.uintval ();
207- m -= FPBits::FRACTION_LEN;
234+ int m = x_bits.get_exponent ();
235+
236+ // When x is subnormal, normalize it by adjusting m.
237+ if ((x_u_log & FPBits::EXP_MASK) == 0U ) {
238+ unsigned leading_zeros =
239+ cpp::countl_zero (static_cast <uint32_t >(x_u_log)) - (32 - 16 );
240+
241+ constexpr unsigned SUBNORMAL_SHIFT_CORRECTION = 5 ;
242+ unsigned shift = leading_zeros - SUBNORMAL_SHIFT_CORRECTION;
243+
244+ x_bits.set_mantissa (static_cast <uint16_t >(x_u_log << shift));
245+
246+ m = 1 - FPBits::EXP_BIAS - static_cast <int >(shift);
208247 }
248+
209249 // Extract the mantissa and index into small lookup tables.
210250 uint16_t mant = x_bits.get_mantissa ();
211- // Use the highest 5 fractional bits of the mantissa as the index f.
212- int f = mant >> 5 ;
213-
214- m += (x_u_log >> FPBits::FRACTION_LEN);
251+ // Use the highest 7 fractional bits of the mantissa as the index f.
252+ int f = mant >> (FPBits::FRACTION_LEN - 7 );
215253
216- // Add the hidden bit to the mantissa.
217- // 1 <= m_x < 2
254+ // Reconstruct the mantissa value m_x so it's in the range [1.0, 2.0).
218255 x_bits.set_biased_exponent (FPBits::EXP_BIAS);
219256 double mant_d = x_bits.get_val ();
220257
221- // Range reduction for log2(m_x):
222- // v = r * m_x - 1, where r is a power of 2 from a lookup table.
223- // The computation is exact for half-precision, and -2^-5 <= v < 2^-4.
224- // Then m_x = (1 + v) / r, and log2(m_x) = log2(1 + v) - log2(r).
258+ // Range reduction for log(m_x):
259+ // v = r * m_x - 1
225260
226- double v =
227- fputil::multiply_add (mant_d, fputil::cast< double >(ONE_OVER_F_F[f]), - 1.0 );
228- // For half-precision accuracy, we use a degree-2 polynomial approximation:
229- // P(v) ~ log2(1 + v) / v
230- // Generated by Sollya with:
231- // > P = fpminimax(log2(1+x)/x, 2, [|D...|], [-2^-5, 2^-4]);
232- // The coefficients are rounded from the Sollya output.
261+ // log(1+v) = v + v^2 * P(v)
262+
263+ // log(m_x) = log(1+v) - log(r).
264+
265+ double log_m_x;
266+
267+ double v = fputil::multiply_add< double >(mant_d, R[f], - 1.0 );
233268
234- double log2p1_d_over_f =
235- v * fputil::polyeval (v, 0x1 .715476p+ 0 , - 0x1 .71771ap- 1 , 0x1 . ecb38ep - 2 );
269+ double p_v = fputil::polyeval (v, LOG_COEFFS[ 0 ], LOG_COEFFS[ 1 ], LOG_COEFFS[ 2 ],
270+ LOG_COEFFS[ 3 ], LOG_COEFFS[ 4 ], LOG_COEFFS[ 5 ] );
236271
237- // log2(1.mant ) = log2(f) + log2(1 + v)
238- double log2_1_mant = LOG2F_F[f] + log2p1_d_over_f ;
272+ // log(1+v) = v + v^2 * P(v)
273+ double logp1_v = fputil::multiply_add< double >(v * v, p_v, v) ;
239274
240- // Complete log2(x) = e_x + log2(m_x)
241- double log2_x = static_cast <double >(m) + log2_1_mant;
275+ // log(m_x) = log(1+v) - log(r). NOTE(review): the code adds LOG_R[f], so the table presumably stores -log(r) — confirm.
276+ log_m_x = logp1_v + LOG_R[f];
277+
278+ // Complete log(x) = m * log(2) + log(m_x).
279+ double log_x =
280+ fputil::multiply_add<double >(static_cast <double >(m), LOG_2_HI, log_m_x);
281+ // Convert to log2(x): log2(x) = log(x) / log(2)
282+ double log2_x = log_x / LOG_2_HI;
242283
243284 // z = y * log2(x)
244285 // Now compute 2^z = 2^(n + r), with n integer and r in [-0.5, 0.5].
245286 double z = fputil::cast<double >(y) * log2_x;
246287
247- // Check for overflow/underflow for half-precision.
248- // Half-precision range is approximately 2^-24 to 2^15.
249- //
250- if (z < -24.0 ) {
288+ // Check for underflow. The half-precision minimum normal exponent is -14,
289+ // but the smallest subnormal is 2^-24.
290+ if (LIBC_UNLIKELY (z < -25.0 )) {
251291 fputil::raise_except_if_required (FE_UNDERFLOW);
252- return fputil::cast<float16>(0 .0f );
292+ // Determine sign of underflowed result
293+ return result_sign ? FPBits::zero (Sign::NEG).get_val ()
294+ : FPBits::zero (Sign::POS).get_val ();
295+ }
296+
297+ // Check for overflow. The largest finite float16 value is 65504 (just below
298+ // 2^16), so 2^z with z > 16.0 overflows.
299+ if (LIBC_UNLIKELY (z > 16.0 )) {
300+ fputil::raise_except_if_required (FE_OVERFLOW);
301+
302+ float16 max_finite = FPBits::max_normal ().get_val ();
303+ float16 neg_max_finite = FPBits::max_normal (Sign::NEG).get_val ();
304+
305+ int round_mode = fputil::get_round ();
306+
307+ if (result_sign) {
308+ // Negative result overflows
309+ // For TOWARDZERO or DOWNWARD rounding, the result is the most negative
310+ // finite value.
311+ if (round_mode == FE_TOWARDZERO || round_mode == FE_DOWNWARD) {
312+ return neg_max_finite;
313+ }
314+ // For other rounding modes, the result is -Infinity.
315+ return FPBits::inf (Sign::NEG).get_val ();
316+ } else {
317+ // Positive result overflows
318+ // For TOWARDZERO or DOWNWARD rounding, the result is the max finite
319+ // value.
320+ if (round_mode == FE_TOWARDZERO || round_mode == FE_DOWNWARD) {
321+ return max_finite;
322+ }
323+ // For other rounding modes, the result is +Infinity.
324+ return FPBits::inf (Sign::POS).get_val ();
325+ }
253326 }
254327
255328 double n = fputil::nearest_integer (z);
@@ -278,12 +351,7 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
278351 int n_int = static_cast <int >(n);
279352 uint64_t exp_bits = static_cast <uint64_t >(n_int + 1023 ) << 52 ;
280353 double pow2_n = cpp::bit_cast<double >(exp_bits);
281-
282-
283- double result_d = (pow2_n * exp2_r);
284- float16 result = fputil::cast<float16>(result_d);
285- if (result_d==65504.0 )
286- return (65504 .f16 );
354+ float16 result = fputil::cast<float16>((pow2_n * exp2_r));
287355
288356 if (result_sign) {
289357 FPBits result_bits (result);
0 commit comments