Skip to content

Commit 77e6f0e

Browse files
committed
Use bfloat16
1 parent d33ae0c commit 77e6f0e

File tree

11 files changed

+261
-307
lines changed

11 files changed

+261
-307
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
1818
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
1919
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
20-
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
2120
import org.apache.lucene.index.ByteVectorValues;
2221
import org.apache.lucene.index.FieldInfo;
2322
import org.apache.lucene.index.FloatVectorValues;
@@ -29,6 +28,7 @@
2928
import org.apache.lucene.util.Bits;
3029
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
3130
import org.apache.lucene.util.hnsw.RandomVectorScorer;
31+
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;
3232
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
3333
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
3434

@@ -41,7 +41,7 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat {
4141

4242
static final String NAME = "ES813FlatVectorFormat";
4343

44-
private static final FlatVectorsFormat format = new Lucene99FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE);
44+
private static final FlatVectorsFormat format = new ES91BFloat16FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE);
4545

4646
/**
4747
* Sole constructor

server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
1717
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
1818
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
19-
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
2019
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsReader;
2120
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
2221
import org.apache.lucene.index.ByteVectorValues;
@@ -34,6 +33,7 @@
3433
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
3534
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
3635
import org.apache.lucene.util.quantization.ScalarQuantizer;
36+
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;
3737
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
3838
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
3939
import org.elasticsearch.simdvec.VectorScorerFactory;
@@ -50,7 +50,7 @@ public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
5050
static final String NAME = "ES814ScalarQuantizedVectorsFormat";
5151
private static final int ALLOWED_BITS = (1 << 8) | (1 << 7) | (1 << 4);
5252

53-
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE);
53+
private static final FlatVectorsFormat rawVectorFormat = new ES91BFloat16FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE);
5454

5555
static final FlatVectorsScorer flatVectorScorer = new ESFlatVectorsScorer(
5656
new ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE)

server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
1414
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
1515
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
16-
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
1716
import org.apache.lucene.index.ByteVectorValues;
1817
import org.apache.lucene.index.KnnVectorValues;
1918
import org.apache.lucene.index.SegmentReadState;
@@ -24,14 +23,15 @@
2423
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
2524
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
2625
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
26+
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;
2727

2828
import java.io.IOException;
2929

3030
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
3131

3232
class ES815BitFlatVectorsFormat extends FlatVectorsFormat {
3333

34-
private static final FlatVectorsFormat delegate = new Lucene99FlatVectorsFormat(FlatBitVectorScorer.INSTANCE);
34+
private static final FlatVectorsFormat delegate = new ES91BFloat16FlatVectorsFormat(FlatBitVectorScorer.INSTANCE);
3535

3636
protected ES815BitFlatVectorsFormat() {
3737
super("ES815BitFlatVectorsFormat");

server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
2424
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
2525
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
26-
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
2726
import org.apache.lucene.index.SegmentReadState;
2827
import org.apache.lucene.index.SegmentWriteState;
28+
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;
2929

3030
import java.io.IOException;
3131

@@ -47,7 +47,7 @@ public class ES816BinaryQuantizedVectorsFormat extends FlatVectorsFormat {
4747
static final String VECTOR_DATA_EXTENSION = "veb";
4848
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
4949

50-
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(
50+
private static final FlatVectorsFormat rawVectorFormat = new ES91BFloat16FlatVectorsFormat(
5151
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
5252
);
5353

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
2424
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
2525
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
26-
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
2726
import org.apache.lucene.index.SegmentReadState;
2827
import org.apache.lucene.index.SegmentWriteState;
2928
import org.elasticsearch.core.SuppressForbidden;
3029
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
30+
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;
3131

3232
import java.io.IOException;
3333

@@ -110,7 +110,7 @@ private static boolean getUseDirectIO() {
110110

111111
private static final FlatVectorsFormat rawVectorFormat = USE_DIRECT_IO
112112
? new DirectIOLucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())
113-
: new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
113+
: new ES91BFloat16FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
114114

115115
private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer(
116116
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()

server/src/main/java/org/elasticsearch/index/codec/vectors/es91/BFloat16.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,20 @@
99

1010
package org.elasticsearch.index.codec.vectors.es91;
1111

12+
import org.apache.lucene.util.BitUtil;
13+
1214
class BFloat16 {
1315

1416
public static final int BYTES = Short.BYTES;
1517

1618
public static short floatToBFloat16(float f) {
17-
// TODO: maintain NaN if all NaN set bits are in removed section
18-
return (short)(Float.floatToIntBits(f) >>> 16);
19+
// dropping the low 16 bits of the float truncates the fraction, i.e. this rounds towards 0
20+
// zero - zero exp, zero fraction
21+
// denormal - zero exp, non-zero fraction
22+
// infinity - all-1 exp, zero fraction
23+
// NaN - all-1 exp, non-zero fraction
24+
// the Float.NaN constant is 0x7fc0_0000, so the most common NaN values stay NaN; a NaN whose only set fraction bits are in the dropped low 16 bits still becomes an infinity
25+
return (short) (Float.floatToIntBits(f) >>> 16);
1926
}
2027

2128
public static float bFloat16ToFloat(short bf) {
@@ -24,17 +31,24 @@ public static float bFloat16ToFloat(short bf) {
2431

2532
public static short[] floatToBFloat16(float[] f) {
2633
short[] bf = new short[f.length];
27-
for (int i=0; i<f.length; i++) {
34+
for (int i = 0; i < f.length; i++) {
2835
bf[i] = floatToBFloat16(f[i]);
2936
}
3037
return bf;
3138
}
3239

3340
public static float[] bFloat16ToFloat(short[] bf) {
3441
float[] f = new float[bf.length];
35-
for (int i=0; i<bf.length; i++) {
42+
for (int i = 0; i < bf.length; i++) {
3643
f[i] = bFloat16ToFloat(bf[i]);
3744
}
3845
return f;
3946
}
47+
48+
public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) {
49+
assert floats.length * 2 == bfBytes.length;
50+
for (int i = 0; i < floats.length; i++) {
51+
floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * 2));
52+
}
53+
}
4054
}

server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsFormat.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2525
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
2626
import org.apache.lucene.codecs.lucene99.ES91BFloat16FlatVectorsReader;
27-
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader;
28-
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
2927
import org.apache.lucene.index.SegmentReadState;
3028
import org.apache.lucene.index.SegmentWriteState;
3129

0 commit comments

Comments
 (0)