Skip to content

Commit 1ad60c2

Browse files
bjacob authored and Elias Joseph committed
Fix libdevice iree_f2h_ieee conversion (#20248)
This replaces the implementation with a specialized copy of the code we have in base/internal/math.h. This also adds an extensive test, so that we can feel better about this code, which isn't shared with anything else (being libdevice) and is relatively little used: it is used only with workloads involving the CPU-hostile f16 type, and only on CPU targets lacking native f16-f32 conversion instructions. That is generally only `generic` CPU targets, as contemporary CPUs tend to have native instructions for these f16-f32 conversions (F16C extension on x86), even though they lack native f16 *arithmetic* beyond these conversions. Fixes #20163. --------- Signed-off-by: Benoit Jacob <[email protected]> Signed-off-by: Elias Joseph <[email protected]>
1 parent 3c5a195 commit 1ad60c2

File tree

5 files changed

+137
-57
lines changed

5 files changed

+137
-57
lines changed

runtime/src/iree/builtins/device/device_generic.c

Lines changed: 64 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -65,65 +65,81 @@ IREE_DEVICE_EXPORT float iree_h2f_ieee(short param) {
6565
}
6666

6767
IREE_DEVICE_EXPORT short iree_f2h_ieee(float param) {
68+
// Some constants about the f32 and f16 types.
69+
const int f32_mantissa_bits = 23;
70+
const int f32_exp_bias = 127;
71+
const uint32_t f32_sign_mask = 0x80000000u;
72+
const uint32_t f32_exp_mask = 0x7F800000u;
73+
const uint32_t f32_mantissa_mask = 0x007FFFFFu;
74+
const int f16_mantissa_bits = 10;
75+
const int f16_exp_bits = 5;
76+
const int f16_exp_bias = 15;
77+
const uint16_t f16_exp_mask = 0x7C00u;
78+
const uint16_t f16_mantissa_mask = 0x03FFu;
79+
80+
// Bitcast float param to uint32.
6881
union {
6982
unsigned int u;
7083
float f;
7184
} param_bits = {
7285
.f = param,
7386
};
74-
int sign = param_bits.u >> 31;
75-
int mantissa = param_bits.u & 0x007FFFFF;
76-
int exp = ((param_bits.u & 0x7F800000) >> 23) + 15 - 127;
77-
short res;
78-
if (exp > 0 && exp < 30) {
79-
// use rte rounding mode, round the significand, combine sign, exponent and
80-
// significand into a short.
81-
res = (sign << 15) | (exp << 10) | ((mantissa + 0x00001000) >> 13);
82-
} else if (param_bits.u == 0) {
83-
res = 0;
87+
uint32_t u32_value = param_bits.u;
88+
89+
// Split the f32 sign/exponent/mantissa components.
90+
const uint32_t f32_sign = u32_value & f32_sign_mask;
91+
const uint32_t f32_exp = u32_value & f32_exp_mask;
92+
const uint32_t f32_mantissa = u32_value & f32_mantissa_mask;
93+
// Initialize the f16 sign/exponent/mantissa components.
94+
uint32_t f16_sign = f32_sign >> 16;
95+
uint32_t f16_exp = 0;
96+
uint32_t f16_mantissa = 0;
97+
98+
if (f32_exp >= f32_exp_mask) {
99+
// NaN or Inf case.
100+
f16_exp = f16_exp_mask;
101+
if (f32_mantissa) {
102+
// NaN. Generate a quiet NaN.
103+
return f16_sign | f16_exp_mask | f16_mantissa_mask;
104+
} else {
105+
// Inf. Leave zero mantissa.
106+
}
107+
} else if (f32_exp == 0) {
108+
// Zero or subnormal. Generate zero. Leave zero mantissa.
84109
} else {
85-
if (exp <= 0) {
86-
if (exp < -10) {
87-
// value is less than min half float point
88-
res = 0;
89-
} else {
90-
// normalized single, magnitude is less than min normal half float
91-
// point.
92-
mantissa = (mantissa | 0x00800000) >> (1 - exp);
93-
// round to nearest
94-
if ((mantissa & 0x00001000) > 0) {
95-
mantissa = mantissa + 0x00002000;
96-
}
97-
// combine sign & mantissa (exp is zero to get denormalized number)
98-
res = (sign << 15) | (mantissa >> 13);
99-
}
100-
} else if (exp == (255 - 127 + 15)) {
101-
if (mantissa == 0) {
102-
// input float is infinity, return infinity half
103-
res = (sign << 15) | 0x7C00;
104-
} else {
105-
// input float is NaN, return half NaN
106-
res = (sign << 15) | 0x7C00 | (mantissa >> 13);
107-
}
110+
// Normal finite value.
111+
int arithmetic_exp = (f32_exp >> f32_mantissa_bits) - f32_exp_bias;
112+
// Test if the exponent is too large for the destination type. If
113+
// the destination type does not have infinities, that frees up the
114+
// max exponent value for additional finite values.
115+
if (arithmetic_exp >= 1 << (f16_exp_bits - 1)) {
116+
// Overflow. Generate Inf. Leave zero mantissa.
117+
f16_exp = f16_exp_mask;
118+
} else if (arithmetic_exp + f16_exp_bias <= 0) {
119+
// Underflow. Generate zero. Leave zero mantissa.
120+
f16_exp = 0;
108121
} else {
109-
// exp > 0, normalized single, round to nearest
110-
if ((mantissa & 0x00001000) > 0) {
111-
mantissa = mantissa + 0x00002000;
112-
if ((mantissa & 0x00800000) > 0) {
113-
mantissa = 0;
114-
exp = exp + 1;
115-
}
116-
}
117-
if (exp > 30) {
118-
// exponent overflow - return infinity half
119-
res = (sign << 15) | 0x7C00;
120-
} else {
121-
// combine sign, exp and mantissa into normalized half
122-
res = (sign << 15) | (exp << 10) | (mantissa >> 13);
122+
// Normal case.
123+
// Implement round-to-nearest-even, by adding a bias before truncating.
124+
int even_bit = 1u << (f32_mantissa_bits - f16_mantissa_bits);
125+
int odd_bit = even_bit >> 1;
126+
uint32_t biased_f32_mantissa =
127+
f32_mantissa +
128+
((f32_mantissa & even_bit) ? (odd_bit) : (odd_bit - 1));
129+
// Adding the bias may cause an exponent increment.
130+
if (biased_f32_mantissa > f32_mantissa_mask) {
131+
biased_f32_mantissa = 0;
132+
++arithmetic_exp;
123133
}
134+
// The exponent increment in the above if() branch may cause overflow.
135+
// This is exercised by converting 65520.0f from f32 to f16.
136+
f16_exp = (arithmetic_exp + f16_exp_bias) << f16_mantissa_bits;
137+
f16_mantissa =
138+
biased_f32_mantissa >> (f32_mantissa_bits - f16_mantissa_bits);
124139
}
125140
}
126-
return res;
141+
142+
return f16_sign | f16_exp | f16_mantissa;
127143
}
128144

129145
#if defined(IREE_DEVICE_STANDALONE)

runtime/src/iree/builtins/device/tools/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ iree_runtime_cc_test(
2929
srcs = ["libdevice_test.cc"],
3030
deps = [
3131
"//runtime/src/iree/base",
32+
"//runtime/src/iree/base/internal",
3233
"//runtime/src/iree/base/internal:flags",
3334
"//runtime/src/iree/builtins/device",
3435
"//runtime/src/iree/testing:gtest",

runtime/src/iree/builtins/device/tools/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ iree_cc_test(
3030
"libdevice_test.cc"
3131
DEPS
3232
iree::base
33+
iree::base::internal
3334
iree::base::internal::flags
3435
iree::builtins::device
3536
iree::testing::gtest

runtime/src/iree/builtins/device/tools/libdevice_test.cc

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,80 @@
77
#include <cstring>
88

99
#include "iree/base/api.h"
10+
#include "iree/base/internal/math.h"
1011
#include "iree/builtins/device/device.h"
1112
#include "iree/testing/gtest.h"
1213
#include "iree/testing/status_matchers.h"
1314

15+
static constexpr uint16_t kF16ExponentMask = 0x7C00;
16+
static constexpr uint16_t kMantissaMask = 0x03FF;
17+
18+
static uint16_t f16BitsIsNaN(uint16_t bits) {
19+
return ((bits & kF16ExponentMask) == kF16ExponentMask) &&
20+
(bits & kMantissaMask);
21+
}
22+
23+
static uint16_t f16BitsIsDenormalOrZero(uint16_t bits) {
24+
return !(bits & kF16ExponentMask);
25+
}
26+
1427
TEST(LibDeviceTest, iree_h2f_ieee) {
15-
// Just ensuring that the code links.
16-
EXPECT_EQ(0.25f, iree_h2f_ieee(0x3400));
28+
// Iterate over all f16 values as u16. Needs a wider type for loop condition.
29+
for (uint32_t f16Bits = 0; f16Bits <= 0xffff; ++f16Bits) {
30+
float f32 = iree_h2f_ieee(f16Bits);
31+
if (f16BitsIsNaN(f16Bits)) {
32+
EXPECT_TRUE(std::isnan(f32));
33+
} else if (f16Bits == 0) {
34+
EXPECT_EQ(f32, 0.f);
35+
} else if (f16BitsIsDenormalOrZero(f16Bits)) {
36+
EXPECT_LE(std::abs(f32), 6.1e-5f);
37+
} else {
38+
EXPECT_EQ(f32, iree_math_f16_to_f32(f16Bits));
39+
}
40+
}
1741
}
1842

1943
TEST(LibDeviceTest, iree_f2h_ieee) {
20-
// Just ensuring that the code links.
21-
EXPECT_EQ(0x3400, iree_f2h_ieee(0.25f));
44+
auto testcase = [](uint32_t f32Bits) {
45+
float f32 = 0.f;
46+
memcpy(&f32, &f32Bits, sizeof f32);
47+
uint16_t f16Bits = iree_f2h_ieee(f32);
48+
if (std::isnan(f32)) {
49+
EXPECT_TRUE(f16BitsIsNaN(f16Bits));
50+
} else if (f32 == 0.f) {
51+
EXPECT_EQ(f16Bits, std::signbit(f32) ? 0x8000 : 0);
52+
} else if (std::abs(f32) < 6.1e-5f) {
53+
EXPECT_TRUE(f16BitsIsDenormalOrZero(f16Bits));
54+
} else {
55+
EXPECT_EQ(f16Bits, iree_math_f32_to_f16(f32));
56+
}
57+
};
58+
// Testing all 2^32 float32 values is too much. We test two slices of that
59+
// space.
60+
//
61+
// Test all 2^12 float32 values that have only their top 12 bits potentially
62+
// set. That covers all combination of sign x exponent x the top 3 bits of
63+
// mantissa. The bottom 20 mantissa bits stay zero, so this lacks coverage
64+
// of rounding behavior.
65+
for (uint32_t f32Top12Bits = 0; f32Top12Bits <= 0xfff; ++f32Top12Bits) {
66+
testcase(f32Top12Bits << 20);
67+
}
68+
// For a few select exponent values, test all 2^12 float32 values whose
69+
// *mantissa* bits have only their top 12 bits potentially set.
70+
// Since float16 has only 10 bits of mantissa, that covers all float16
71+
// mantissas plus 2 additional bits of float32 mantissa past the truncation.
72+
// Having 2 extra bits should be exactly what is relevant to testing rounding
73+
// behavior including tie breaks to "nearest even".
74+
for (uint32_t f32MantissaTop12Bits = 0; f32MantissaTop12Bits <= 0xfff;
75+
++f32MantissaTop12Bits) {
76+
// A few select exponent values.
77+
for (uint32_t f32ExponentBits :
78+
{0 /*denormal*/, 1 /*minimum normal*/, 127 /*neutral*/,
79+
254 /*maximum finite*/, 255 /*infinite*/}) {
80+
for (uint32_t f32SignBit : {0, 1}) {
81+
testcase((f32SignBit << 31) | (f32ExponentBits << 23) |
82+
(f32MantissaTop12Bits << 11));
83+
}
84+
}
85+
}
2286
}

tests/e2e/math/math_ops_llvm-cpu.json

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@
155155
{
156156
"op": "exp2",
157157
"type": "f16",
158-
"atol": 0.25,
159-
"comment": "TODO(#20163)",
158+
"atol": 1.0e-03,
160159
"rtol": 1.0e-02
161160
},
162161
{
@@ -339,8 +338,7 @@
339338
{
340339
"op": "powf",
341340
"type": "f16",
342-
"atol": 0.25,
343-
"comment": "TODO(#20163)",
344-
"rtol": 5.0e-03
341+
"atol": 1.0e-03,
342+
"rtol": 1.0e-02
345343
}
346344
]

0 commit comments

Comments
 (0)