Skip to content

Commit 3213723

Browse files
committed
Implement bfloat16 round-to-even
1 parent d02487e commit 3213723

File tree

2 files changed

+99
-3
lines changed

2 files changed

+99
-3
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,29 @@ public final class BFloat16 {
1818
public static final int BYTES = Short.BYTES;
1919

2020
public static short floatToBFloat16(float f) {
21-
// this rounds towards 0
21+
// this rounds towards even
2222
// zero - zero exp, zero fraction
2323
// denormal - zero exp, non-zero fraction
2424
// infinity - all-1 exp, zero fraction
2525
// NaN - all-1 exp, non-zero fraction
2626
// the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into
2727
// infinities
28-
return (short) (Float.floatToIntBits(f) >>> 16);
28+
29+
int bits = Float.floatToIntBits(f);
30+
int bfloat16 = bits >>> 16;
31+
32+
// if highest discarded bit is 1,
33+
// and there's other non-zero discarded bits, or the bfloat16 is odd
34+
// then round up
35+
if ((bits & 0x8000) == 0x8000 && ((bits & 0x7fff) != 0 || (bfloat16 & 1) == 1)) {
36+
bfloat16++;
37+
}
38+
39+
return (short) bfloat16;
2940
}
3041

3142
public static float truncateToBFloat16(float f) {
32-
return Float.intBitsToFloat(Float.floatToIntBits(f) & 0xffff0000);
43+
return Float.intBitsToFloat(floatToBFloat16(f) << 16);
3344
}
3445

3546
public static float bFloat16ToFloat(short bf) {
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.elasticsearch.test.ESTestCase;
13+
14+
import static org.hamcrest.Matchers.closeTo;
15+
import static org.hamcrest.Matchers.equalTo;
16+
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
17+
18+
public class BFloat16Tests extends ESTestCase {
19+
20+
public void testRoundToEven() {
21+
int exp = 0b001111110; // to create floating numbers around 1.0
22+
23+
// exact bfloat16 value
24+
float bfloat16 = construct(exp, 0b1111001_00000000_00000000);
25+
assertRounding(bfloat16, bfloat16);
26+
27+
// round down
28+
assertRounding(construct(exp, 0b0000001_01111111_11111111), construct(exp, 0b0000001_00000000_00000000));
29+
30+
// round up
31+
assertRounding(construct(exp, 0b0000001_10000000_00000001), construct(exp, 0b0000010_00000000_00000000));
32+
33+
// split down to even
34+
assertRounding(construct(exp, 0b000010_10000000_00000000), construct(exp, 0b000010_00000000_00000000));
35+
36+
// split up to even
37+
assertRounding(construct(exp, 0b000001_10000000_00000000), construct(exp, 0b000010_00000000_00000000));
38+
39+
// round up, overflowing into exponent
40+
assertRounding(construct(0b000111111, 0b1111111_10000000_00000000), construct(0b001000000, 0b0000000_00000000_00000000));
41+
42+
// round up, overflowing from denormal to normal number
43+
assertRounding(construct(0b000000000, 0b1111111_10000000_00000000), construct(0b000000001, 0b0000000_00000000_00000000));
44+
45+
// round to positive infinity
46+
assertThat(BFloat16.truncateToBFloat16(construct(0b011111110, 0b1111111_10000000_00000000)), equalTo(Float.POSITIVE_INFINITY));
47+
48+
// round to negative infinity
49+
assertThat(BFloat16.truncateToBFloat16(construct(0b111111110, 0b1111111_10000000_00000000)), equalTo(Float.NEGATIVE_INFINITY));
50+
51+
// round to zero
52+
assertRounding(construct(0b000000000, 0b0000000_10000000_00000000), 0f);
53+
54+
// rounding the standard NaN value should be unchanged
55+
assertThat(Float.floatToIntBits(BFloat16.truncateToBFloat16(Float.NaN)), equalTo(Float.floatToIntBits(Float.NaN)));
56+
}
57+
58+
private static float construct(int exp, int mantissa) {
59+
assert (exp & 0xfffffe00) == 0;
60+
assert (mantissa & 0xf8000000) == 0;
61+
return Float.intBitsToFloat((exp << 23) | mantissa);
62+
}
63+
64+
private static void assertRounding(float value, float expectedRounded) {
65+
assert (Float.floatToIntBits(expectedRounded) & 0xffff) == 0;
66+
67+
// rounded float value to check should be close to input value
68+
// this checks the bit representations in the tests are actually sensible
69+
assertThat((double) expectedRounded, closeTo(value, 0.002));
70+
71+
float rounded = BFloat16.truncateToBFloat16(value);
72+
73+
// System.out.println(value + " rounds to " + rounded);
74+
assertEquals(value + " rounded to " + rounded + ", not " + expectedRounded,
75+
Float.floatToIntBits(expectedRounded), Float.floatToIntBits(rounded));
76+
77+
// there should not be a closer bfloat16 value (comparing using FP math) than the expected rounded value
78+
float delta = Math.abs(value - rounded);
79+
float higherValue = Float.intBitsToFloat(Float.floatToIntBits(rounded) + 0x10000);
80+
assertThat(Math.abs(value - higherValue), greaterThanOrEqualTo(delta));
81+
82+
float lowerValue = Float.intBitsToFloat(Float.floatToIntBits(rounded) - 0x10000);
83+
assertThat(Math.abs(value - lowerValue), greaterThanOrEqualTo(delta));
84+
}
85+
}

0 commit comments

Comments
 (0)