Skip to content

Commit 020855d

Browse files
committed
update powf16 approach
1 parent 13dd3ff commit 020855d

File tree

2 files changed

+128
-57
lines changed

2 files changed

+128
-57
lines changed

libc/src/math/generic/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1600,7 +1600,6 @@ add_entrypoint_object(
16001600
libc.src.__support.FPUtil.nearest_integer
16011601
libc.src.__support.FPUtil.sqrt
16021602
libc.src.__support.macros.optimization
1603-
libc.src.__support.math.expxf16_utils
16041603
)
16051604

16061605
add_entrypoint_object(

libc/src/math/generic/powf16.cpp

Lines changed: 128 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include "src/__support/macros/config.h"
2222
#include "src/__support/macros/optimization.h"
2323
#include "src/__support/macros/properties/types.h"
24-
#include "src/__support/math/expxf16_utils.h"
24+
#include "src/math/generic/common_constants.h"
2525

2626
namespace LIBC_NAMESPACE_DECL {
2727

@@ -66,6 +66,7 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
6666
uint16_t x_u = xbits.uintval();
6767
uint16_t x_a = x_abs.uintval();
6868
uint16_t y_a = y_abs.uintval();
69+
uint16_t y_u = ybits.uintval();
6970
bool result_sign = false;
7071

7172
///////// BEGIN - Check exceptional cases ////////////////////////////////////
@@ -75,10 +76,12 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
7576
return FPBits::quiet_nan().get_val();
7677
}
7778

78-
//
79-
if (LIBC_UNLIKELY(ybits.is_zero() || x_u == FPBits::one().uintval() ||
80-
x_u >= FPBits::inf().uintval() ||
81-
x_u < FPBits::min_normal().uintval())) {
79+
if (LIBC_UNLIKELY(
80+
ybits.is_zero() || x_u == FPBits::one().uintval() ||
81+
x_u == FPBits::one().uintval() || x_u == FPBits::zero().uintval() ||
82+
x_u >= FPBits::inf().uintval() || y_u >= FPBits::inf().uintval() ||
83+
x_u < FPBits::min_normal().uintval() || y_a == 0x3800U ||
84+
y_a == 0x3c00U || y_a == 0x4000U || is_integer(y))) {
8285
// pow(x, 0) = 1
8386
if (ybits.is_zero()) {
8487
return fputil::cast<float16>(1.0f);
@@ -96,8 +99,13 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
9699
(x == 0.0 || x_u == FPBits::inf(Sign::NEG).uintval()))) {
97100
// pow(-0, 1/2) = +0
98101
// pow(-inf, 1/2) = +inf
99-
// Make sure it works correctly for FTZ/DAZ.
100-
return fputil::cast<float16>(y_sign ? (1.0 / (x * x)) : (x * x));
102+
// For pow(x, 0.5), sqrt(x) is used. pow(0, -0.5) is handled below.
103+
break;
104+
}
105+
// If x is not negative or special, use sqrt(x)
106+
if (x_sign && !xbits.is_zero()) {
107+
// pow(negative, non-integer) = NaN, handled below.
108+
break;
101109
}
102110
return fputil::cast<float16>(y_sign ? (1.0 / fputil::sqrt<float16>(x))
103111
: fputil::sqrt<float16>(x));
@@ -163,9 +171,13 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
163171
bool x_abs_less_than_one = x_a < FPBits::one().uintval();
164172
if ((x_abs_less_than_one && !y_sign) ||
165173
(!x_abs_less_than_one && y_sign)) {
174+
// |x| < 1 and y = +inf => 0.0
175+
// |x| > 1 and y = -inf => 0.0
166176
return fputil::cast<float16>(0.0f);
167177
} else {
168-
return FPBits::inf().get_val();
178+
// |x| > 1 and y = +inf => +inf
179+
// |x| < 1 and y = -inf => +inf
180+
return FPBits::inf(Sign::POS).get_val();
169181
}
170182
}
171183

@@ -175,7 +187,31 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
175187
fputil::raise_except_if_required(FE_INVALID);
176188
return FPBits::quiet_nan().get_val();
177189
}
190+
if (is_integer(y)) {
191+
double base = x_abs.get_val();
192+
double res = 1.0;
193+
int yi = static_cast<int>(y_abs.get_val());
194+
195+
while (yi > 0) {
196+
if (yi % 2 == 1)
197+
res *= base;
198+
base *= base;
199+
yi /= 2;
200+
}
201+
202+
if (y_sign) {
203+
res = 1.0 / res;
204+
}
205+
206+
float16 final_res = fputil::cast<float16>(res);
178207

208+
if (x_sign && is_odd_integer(y)) {
209+
FPBits res_bits(final_res);
210+
res_bits.set_sign(Sign::NEG);
211+
return res_bits.get_val();
212+
}
213+
return final_res;
214+
}
179215
// For negative x with integer y, compute pow(|x|, y) and adjust sign
180216
if (x_sign) {
181217
x = -x;
@@ -184,72 +220,113 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
184220
}
185221
}
186222
}
223+
187224
///////// END - Check exceptional cases //////////////////////////////////////
188225

189-
// x^y = 2^( y * log2(x) )
190-
// = 2^( y * ( e_x + log2(m_x) ) )
191-
// First we compute log2(x) = e_x + log2(m_x)
226+
// Core computation: x^y = 2^( y * log2(x) )
227+
// We compute log2(x) = log(x) / log(2) using a polynomial approximation.
192228

193-
using namespace math::expxf16_internal;
229+
// The exponent part (m) is added later to get the final log(x).
194230
FPBits x_bits(x);
195-
196231
uint16_t x_u_log = x_bits.uintval();
197232

198233
// Extract exponent field of x.
199-
int m = -FPBits::EXP_BIAS;
200-
201-
// When x is subnormal, normalize it by multiplying by 2^FRACTION_LEN.
202-
if ((x_u_log & FPBits::EXP_MASK) == 0U) { // Subnormal x
203-
constexpr double NORMALIZE_EXP = 1.0 * (1U << FPBits::FRACTION_LEN);
204-
x_bits = FPBits(fputil::cast<float16>(
205-
fputil::cast<double>(x_bits.get_val()) * NORMALIZE_EXP));
206-
x_u_log = x_bits.uintval();
207-
m -= FPBits::FRACTION_LEN;
234+
int m = x_bits.get_exponent();
235+
236+
// When x is subnormal, normalize it by adjusting m.
237+
if ((x_u_log & FPBits::EXP_MASK) == 0U) {
238+
unsigned leading_zeros =
239+
cpp::countl_zero(static_cast<uint32_t>(x_u_log)) - (32 - 16);
240+
241+
constexpr unsigned SUBNORMAL_SHIFT_CORRECTION = 5;
242+
unsigned shift = leading_zeros - SUBNORMAL_SHIFT_CORRECTION;
243+
244+
x_bits.set_mantissa(static_cast<uint16_t>(x_u_log << shift));
245+
246+
m = 1 - FPBits::EXP_BIAS - static_cast<int>(shift);
208247
}
248+
209249
// Extract the mantissa and index into small lookup tables.
210250
uint16_t mant = x_bits.get_mantissa();
211-
// Use the highest 5 fractional bits of the mantissa as the index f.
212-
int f = mant >> 5;
213-
214-
m += (x_u_log >> FPBits::FRACTION_LEN);
251+
// Use the highest 7 fractional bits of the mantissa as the index f.
252+
int f = mant >> (FPBits::FRACTION_LEN - 7);
215253

216-
// Add the hidden bit to the mantissa.
217-
// 1 <= m_x < 2
254+
// Reconstruct the mantissa value m_x so it's in the range [1.0, 2.0).
218255
x_bits.set_biased_exponent(FPBits::EXP_BIAS);
219256
double mant_d = x_bits.get_val();
220257

221-
// Range reduction for log2(m_x):
222-
// v = r * m_x - 1, where r is a power of 2 from a lookup table.
223-
// The computation is exact for half-precision, and -2^-5 <= v < 2^-4.
224-
// Then m_x = (1 + v) / r, and log2(m_x) = log2(1 + v) - log2(r).
258+
// Range reduction for log(m_x):
259+
// v = r * m_x - 1
225260

226-
double v =
227-
fputil::multiply_add(mant_d, fputil::cast<double>(ONE_OVER_F_F[f]), -1.0);
228-
// For half-precision accuracy, we use a degree-2 polynomial approximation:
229-
// P(v) ~ log2(1 + v) / v
230-
// Generated by Sollya with:
231-
// > P = fpminimax(log2(1+x)/x, 2, [|D...|], [-2^-5, 2^-4]);
232-
// The coefficients are rounded from the Sollya output.
261+
// log(1+v) = v + v^2 * P(v)
262+
263+
// log(m_x) = log(1+v) - log(r).
264+
265+
double log_m_x;
266+
267+
double v = fputil::multiply_add<double>(mant_d, R[f], -1.0);
233268

234-
double log2p1_d_over_f =
235-
v * fputil::polyeval(v, 0x1.715476p+0, -0x1.71771ap-1, 0x1.ecb38ep-2);
269+
double p_v = fputil::polyeval(v, LOG_COEFFS[0], LOG_COEFFS[1], LOG_COEFFS[2],
270+
LOG_COEFFS[3], LOG_COEFFS[4], LOG_COEFFS[5]);
236271

237-
// log2(1.mant) = log2(f) + log2(1 + v)
238-
double log2_1_mant = LOG2F_F[f] + log2p1_d_over_f;
272+
// log(1+v) = v + v^2 * P(v)
273+
double logp1_v = fputil::multiply_add<double>(v * v, p_v, v);
239274

240-
// Complete log2(x) = e_x + log2(m_x)
241-
double log2_x = static_cast<double>(m) + log2_1_mant;
275+
// log(m_x) = log(1+v) - log(r).
276+
const double LOG_R_127 = 0x1.62e42fefa38p-1;
277+
if (f == 127)
278+
log_m_x = logp1_v + LOG_R_127;
279+
else
280+
log_m_x = logp1_v + LOG_R[f];
281+
282+
// Complete log(x) = m * log(2) + log(m_x).
283+
double log_x =
284+
fputil::multiply_add<double>(static_cast<double>(m), LOG_2_HI, log_m_x);
285+
// Convert to log2(x): log2(x) = log(x) / log(2)
286+
double log2_x = log_x / LOG_2_HI;
242287

243288
// z = y * log2(x)
244289
// Now compute 2^z = 2^(n + r), with n integer and r in [-0.5, 0.5].
245290
double z = fputil::cast<double>(y) * log2_x;
246291

247-
// Check for overflow/underflow for half-precision.
248-
// Half-precision range is approximately 2^-24 to 2^15.
249-
//
250-
if (z < -24.0) {
292+
// Check for underflow. Half-precision min normal exponent is -14 ,
293+
// but the smallest subnormal is 2^-24.
294+
if (LIBC_UNLIKELY(z < -25.0)) {
251295
fputil::raise_except_if_required(FE_UNDERFLOW);
252-
return fputil::cast<float16>(0.0f);
296+
// Determine sign of underflowed result
297+
return result_sign ? FPBits::zero(Sign::NEG).get_val()
298+
: FPBits::zero(Sign::POS).get_val();
299+
}
300+
301+
// Check for overflow. The max value of float16 is ~2^16, so z > 16.0 will
302+
// overflow.
303+
if (LIBC_UNLIKELY(z > 16.0)) {
304+
fputil::raise_except_if_required(FE_OVERFLOW);
305+
306+
float16 max_finite = FPBits::max_normal().get_val();
307+
float16 neg_max_finite = FPBits::max_normal(Sign::NEG).get_val();
308+
309+
int round_mode = fputil::get_round();
310+
311+
if (result_sign) {
312+
// Negative result overflows
313+
// For TOWARDZERO or DOWNWARD rounding, the result is the most negative
314+
// finite value.
315+
if (round_mode == FE_TOWARDZERO || round_mode == FE_DOWNWARD) {
316+
return neg_max_finite;
317+
}
318+
// For other rounding modes, the result is -Infinity.
319+
return FPBits::inf(Sign::NEG).get_val();
320+
} else {
321+
// Positive result overflows
322+
// For TOWARDZERO or DOWNWARD rounding, the result is the max finite
323+
// value.
324+
if (round_mode == FE_TOWARDZERO || round_mode == FE_DOWNWARD) {
325+
return max_finite;
326+
}
327+
// For other rounding modes, the result is +Infinity.
328+
return FPBits::inf(Sign::POS).get_val();
329+
}
253330
}
254331

255332
double n = fputil::nearest_integer(z);
@@ -278,12 +355,7 @@ LLVM_LIBC_FUNCTION(float16, powf16, (float16 x, float16 y)) {
278355
int n_int = static_cast<int>(n);
279356
uint64_t exp_bits = static_cast<uint64_t>(n_int + 1023) << 52;
280357
double pow2_n = cpp::bit_cast<double>(exp_bits);
281-
282-
283-
double result_d = (pow2_n * exp2_r);
284-
float16 result = fputil::cast<float16>(result_d);
285-
if(result_d==65504.0)
286-
return (65504.f16);
358+
float16 result = fputil::cast<float16>((pow2_n * exp2_r));
287359

288360
if (result_sign) {
289361
FPBits result_bits(result);

0 commit comments

Comments
 (0)