Skip to content

Commit d33ae0c

Browse files
committed
More changes
1 parent ed83b1b commit d33ae0c

File tree

5 files changed

+390
-38
lines changed

5 files changed

+390
-38
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.es91;
11+
12+
class BFloat16 {
13+
14+
public static final int BYTES = Short.BYTES;
15+
16+
public static short floatToBFloat16(float f) {
17+
// TODO: maintain NaN if all NaN set bits are in removed section
18+
return (short)(Float.floatToIntBits(f) >>> 16);
19+
}
20+
21+
public static float bFloat16ToFloat(short bf) {
22+
return Float.intBitsToFloat(bf << 16);
23+
}
24+
25+
public static short[] floatToBFloat16(float[] f) {
26+
short[] bf = new short[f.length];
27+
for (int i=0; i<f.length; i++) {
28+
bf[i] = floatToBFloat16(f[i]);
29+
}
30+
return bf;
31+
}
32+
33+
public static float[] bFloat16ToFloat(short[] bf) {
34+
float[] f = new float[bf.length];
35+
for (int i=0; i<bf.length; i++) {
36+
f[i] = bFloat16ToFloat(bf[i]);
37+
}
38+
return f;
39+
}
40+
}

server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsFormat.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
2424
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2525
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
26+
import org.apache.lucene.codecs.lucene99.ES91BFloat16FlatVectorsReader;
2627
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader;
2728
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
2829
import org.apache.lucene.index.SegmentReadState;
@@ -33,7 +34,7 @@
3334
public class ES91BFloat16FlatVectorsFormat extends FlatVectorsFormat {
3435

3536
static final String NAME = "ES91BFloat16FlatVectorsFormat";
36-
static final String META_CODEC_NAME = "Lucene99FlatVectorsFormatMeta";
37+
static final String META_CODEC_NAME = "ES91BFloat16FlatVectorsFormatMeta";
3738
static final String VECTOR_DATA_CODEC_NAME = "ES91BFloat16FlatVectorsFormatData";
3839
static final String META_EXTENSION = "vemf";
3940
static final String VECTOR_DATA_EXTENSION = "vec";
@@ -52,12 +53,12 @@ public ES91BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
5253

5354
@Override
5455
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
55-
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
56+
return new ES91BFloat16FlatVectorsWriter(state, vectorsScorer);
5657
}
5758

5859
@Override
5960
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
60-
return new Lucene99FlatVectorsReader(state, vectorsScorer);
61+
return new ES91BFloat16FlatVectorsReader(state, vectorsScorer);
6162
}
6263

6364
@Override

server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsReader.java

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
*
1818
* Modifications copyright (C) 2024 Elasticsearch B.V.
1919
*/
20-
package org.apache.lucene.codecs.lucene99;
20+
package org.elasticsearch.index.codec.vectors.es91;
2121

2222
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction;
2323
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
@@ -51,6 +51,7 @@
5151
import org.apache.lucene.util.IOUtils;
5252
import org.apache.lucene.util.RamUsageEstimator;
5353
import org.apache.lucene.util.hnsw.RandomVectorScorer;
54+
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;
5455

5556
/**
5657
* Reads vectors from the index segments.
@@ -60,13 +61,13 @@
6061
public final class ES91BFloat16FlatVectorsReader extends FlatVectorsReader {
6162

6263
private static final long SHALLOW_SIZE =
63-
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class);
64+
RamUsageEstimator.shallowSizeOfInstance(ES91BFloat16FlatVectorsFormat.class);
6465

6566
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
6667
private final IndexInput vectorData;
6768
private final FieldInfos fieldInfos;
6869

69-
public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
70+
public ES91BFloat16FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
7071
throws IOException {
7172
super(scorer);
7273
int versionMeta = readMetadata(state);
@@ -76,8 +77,8 @@ public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer score
7677
openDataInput(
7778
state,
7879
versionMeta,
79-
Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION,
80-
Lucene99FlatVectorsFormat.VECTOR_DATA_CODEC_NAME,
80+
ES91BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION,
81+
ES91BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME,
8182
// Flat formats are used to randomly access vectors from their node ID that is stored
8283
// in the HNSW graph.
8384
state.context.withHints(
@@ -91,17 +92,17 @@ public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer score
9192
private int readMetadata(SegmentReadState state) throws IOException {
9293
String metaFileName =
9394
IndexFileNames.segmentFileName(
94-
state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION);
95+
state.segmentInfo.name, state.segmentSuffix, ES91BFloat16FlatVectorsFormat.META_EXTENSION);
9596
int versionMeta = -1;
9697
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
9798
Throwable priorE = null;
9899
try {
99100
versionMeta =
100101
CodecUtil.checkIndexHeader(
101102
meta,
102-
Lucene99FlatVectorsFormat.META_CODEC_NAME,
103-
Lucene99FlatVectorsFormat.VERSION_START,
104-
Lucene99FlatVectorsFormat.VERSION_CURRENT,
103+
ES91BFloat16FlatVectorsFormat.META_CODEC_NAME,
104+
ES91BFloat16FlatVectorsFormat.VERSION_START,
105+
ES91BFloat16FlatVectorsFormat.VERSION_CURRENT,
105106
state.segmentInfo.getId(),
106107
state.segmentSuffix);
107108
readFields(meta, state.fieldInfos);
@@ -129,8 +130,8 @@ private static IndexInput openDataInput(
129130
CodecUtil.checkIndexHeader(
130131
in,
131132
codecName,
132-
Lucene99FlatVectorsFormat.VERSION_START,
133-
Lucene99FlatVectorsFormat.VERSION_CURRENT,
133+
ES91BFloat16FlatVectorsFormat.VERSION_START,
134+
ES91BFloat16FlatVectorsFormat.VERSION_CURRENT,
134135
state.segmentInfo.getId(),
135136
state.segmentSuffix);
136137
if (versionMeta != versionVectorData) {
@@ -164,13 +165,13 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
164165

165166
@Override
166167
public long ramBytesUsed() {
167-
return Lucene99FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed();
168+
return ES91BFloat16FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed();
168169
}
169170

170171
@Override
171172
public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
172173
final FieldEntry entry = getFieldEntryOrThrow(fieldInfo.name);
173-
return Map.of(Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION, entry.vectorDataLength());
174+
return Map.of(ES91BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, entry.vectorDataLength());
174175
}
175176

176177
@Override
@@ -215,7 +216,7 @@ private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding)
215216
@Override
216217
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
217218
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
218-
return OffHeapFloatVectorValues.load(
219+
return OffHeapBFloat16VectorValues.load(
219220
fieldEntry.similarityFunction,
220221
vectorScorer,
221222
fieldEntry.ordToDoc,
@@ -229,7 +230,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
229230
@Override
230231
public ByteVectorValues getByteVectorValues(String field) throws IOException {
231232
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
232-
return OffHeapByteVectorValues.load(
233+
return OffHeapBFloat16VectorValues.load(
233234
fieldEntry.similarityFunction,
234235
vectorScorer,
235236
fieldEntry.ordToDoc,
@@ -320,7 +321,7 @@ private record FieldEntry(
320321
int byteSize =
321322
switch (info.getVectorEncoding()) {
322323
case BYTE -> Byte.BYTES;
323-
case FLOAT32 -> Float.BYTES;
324+
case FLOAT32 -> BFloat16.BYTES;
324325
};
325326
long vectorBytes = Math.multiplyExact((long) infoVectorDimension, byteSize);
326327
long numBytes = Math.multiplyExact(vectorBytes, size);

server/src/main/java/org/elasticsearch/index/codec/vectors/es91/ES91BFloat16FlatVectorsWriter.java

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
public final class ES91BFloat16FlatVectorsWriter extends FlatVectorsWriter {
6868

6969
private static final long SHALLOW_RAM_BYTES_USED =
70-
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsWriter.class);
70+
RamUsageEstimator.shallowSizeOfInstance(ES91BFloat16FlatVectorsWriter.class);
7171

7272
private final SegmentWriteState segmentWriteState;
7373
private final IndexOutput meta, vectorData;
@@ -81,28 +81,28 @@ public ES91BFloat16FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer
8181
segmentWriteState = state;
8282
String metaFileName =
8383
IndexFileNames.segmentFileName(
84-
state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION);
84+
state.segmentInfo.name, state.segmentSuffix, ES91BFloat16FlatVectorsFormat.META_EXTENSION);
8585

8686
String vectorDataFileName =
8787
IndexFileNames.segmentFileName(
8888
state.segmentInfo.name,
8989
state.segmentSuffix,
90-
Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION);
90+
ES91BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION);
9191

9292
try {
9393
meta = state.directory.createOutput(metaFileName, state.context);
9494
vectorData = state.directory.createOutput(vectorDataFileName, state.context);
9595

9696
CodecUtil.writeIndexHeader(
9797
meta,
98-
Lucene99FlatVectorsFormat.META_CODEC_NAME,
99-
Lucene99FlatVectorsFormat.VERSION_CURRENT,
98+
ES91BFloat16FlatVectorsFormat.META_CODEC_NAME,
99+
ES91BFloat16FlatVectorsFormat.VERSION_CURRENT,
100100
state.segmentInfo.getId(),
101101
state.segmentSuffix);
102102
CodecUtil.writeIndexHeader(
103103
vectorData,
104-
Lucene99FlatVectorsFormat.VECTOR_DATA_CODEC_NAME,
105-
Lucene99FlatVectorsFormat.VERSION_CURRENT,
104+
ES91BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME,
105+
ES91BFloat16FlatVectorsFormat.VERSION_CURRENT,
106106
state.segmentInfo.getId(),
107107
state.segmentSuffix);
108108
} catch (Throwable t) {
@@ -160,19 +160,20 @@ private void writeField(FieldWriter<?> fieldData, int maxDoc) throws IOException
160160
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
161161
switch (fieldData.fieldInfo.getVectorEncoding()) {
162162
case BYTE -> writeByteVectors(fieldData);
163-
case FLOAT32 -> writeFloat32Vectors(fieldData);
163+
case FLOAT32 -> writeBFloat16Vectors(fieldData);
164164
}
165165
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
166166

167167
writeMeta(
168168
fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldData.docsWithField);
169169
}
170170

171-
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
171+
private void writeBFloat16Vectors(FieldWriter<?> fieldData) throws IOException {
172172
final ByteBuffer buffer =
173-
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
173+
ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN);
174174
for (Object v : fieldData.vectors) {
175-
buffer.asFloatBuffer().put((float[]) v);
175+
short[] data = BFloat16.floatToBFloat16((float[]) v);
176+
buffer.asShortBuffer().put(data);
176177
vectorData.writeBytes(buffer.array(), buffer.array().length);
177178
}
178179
}
@@ -195,21 +196,22 @@ private void writeSortingField(FieldWriter<?> fieldData, int maxDoc, Sorter.DocM
195196
long vectorDataOffset =
196197
switch (fieldData.fieldInfo.getVectorEncoding()) {
197198
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
198-
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
199+
case FLOAT32 -> writeSortedBFloat16Vectors(fieldData, ordMap);
199200
};
200201
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
201202

202203
writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField);
203204
}
204205

205-
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
206+
private long writeSortedBFloat16Vectors(FieldWriter<?> fieldData, int[] ordMap)
206207
throws IOException {
207208
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
208209
final ByteBuffer buffer =
209-
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
210+
ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN);
210211
for (int ordinal : ordMap) {
211212
float[] vector = (float[]) fieldData.vectors.get(ordinal);
212-
buffer.asFloatBuffer().put(vector);
213+
short[] data = BFloat16.floatToBFloat16(vector);
214+
buffer.asShortBuffer().put(data);
213215
vectorData.writeBytes(buffer.array(), buffer.array().length);
214216
}
215217
return vectorDataOffset;
@@ -383,13 +385,14 @@ private static DocsWithFieldSet writeVectorData(
383385
IndexOutput output, FloatVectorValues floatVectorValues) throws IOException {
384386
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
385387
ByteBuffer buffer =
386-
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
388+
ByteBuffer.allocate(floatVectorValues.dimension() * BFloat16.BYTES)
387389
.order(ByteOrder.LITTLE_ENDIAN);
388390
KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator();
389391
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
390392
// write vector
391393
float[] value = floatVectorValues.vectorValue(iter.index());
392-
buffer.asFloatBuffer().put(value);
394+
short[] data = BFloat16.floatToBFloat16(value);
395+
buffer.asShortBuffer().put(data);
393396
output.writeBytes(buffer.array(), buffer.limit());
394397
docsWithField.add(docV);
395398
}
@@ -416,14 +419,14 @@ static FieldWriter<?> create(FieldInfo fieldInfo) {
416419
int dim = fieldInfo.getVectorDimension();
417420
return switch (fieldInfo.getVectorEncoding()) {
418421
case BYTE ->
419-
new Lucene99FlatVectorsWriter.FieldWriter<byte[]>(fieldInfo) {
422+
new FieldWriter<byte[]>(fieldInfo) {
420423
@Override
421424
public byte[] copyValue(byte[] value) {
422425
return ArrayUtil.copyOfSubArray(value, 0, dim);
423426
}
424427
};
425428
case FLOAT32 ->
426-
new Lucene99FlatVectorsWriter.FieldWriter<float[]>(fieldInfo) {
429+
new FieldWriter<float[]>(fieldInfo) {
427430
@Override
428431
public float[] copyValue(float[] value) {
429432
return ArrayUtil.copyOfSubArray(value, 0, dim);

0 commit comments

Comments
 (0)