Skip to content

Commit c72eda6

Browse files
committed
Add BFloat16 raw vector format to bbq_hnsw and bbq_disk
1 parent 89c58cf commit c72eda6

19 files changed

+1802
-106
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ record CmdLineArgs(
5151
float filterSelectivity,
5252
long seed,
5353
VectorSimilarityFunction vectorSpace,
54+
int rawVectorSize,
5455
int quantizeBits,
5556
VectorEncoding vectorEncoding,
5657
int dimensions,
@@ -80,6 +81,7 @@ record CmdLineArgs(
8081
static final ParseField FORCE_MERGE_FIELD = new ParseField("force_merge");
8182
static final ParseField VECTOR_SPACE_FIELD = new ParseField("vector_space");
8283
static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits");
84+
static final ParseField RAW_VECTOR_SIZE_FIELD = new ParseField("raw_vector_size");
8385
static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding");
8486
static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
8587
static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
@@ -123,6 +125,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
123125
PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD);
124126
PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD);
125127
PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD);
128+
PARSER.declareInt(Builder::setRawVectorSize, RAW_VECTOR_SIZE_FIELD);
126129
PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD);
127130
PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD);
128131
PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD);
@@ -161,6 +164,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
161164
builder.field(REINDEX_FIELD.getPreferredName(), reindex);
162165
builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge);
163166
builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT));
167+
builder.field(RAW_VECTOR_SIZE_FIELD.getPreferredName(), rawVectorSize);
164168
builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits);
165169
builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT));
166170
builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions);
@@ -196,6 +200,7 @@ static class Builder {
196200
private boolean reindex = false;
197201
private boolean forceMerge = false;
198202
private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN;
203+
private int rawVectorSize = 32;
199204
private int quantizeBits = 8;
200205
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
201206
private int dimensions;
@@ -305,6 +310,11 @@ public Builder setVectorSpace(String vectorSpace) {
305310
return this;
306311
}
307312

313+
public Builder setRawVectorSize(int rawVectorSize) {
314+
this.rawVectorSize = rawVectorSize;
315+
return this;
316+
}
317+
308318
public Builder setQuantizeBits(int quantizeBits) {
309319
this.quantizeBits = quantizeBits;
310320
return this;
@@ -380,6 +390,7 @@ public CmdLineArgs build() {
380390
filterSelectivity,
381391
seed,
382392
vectorSpace,
393+
rawVectorSize,
383394
quantizeBits,
384395
vectorEncoding,
385396
dimensions,
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.util.BitUtil;
13+
14+
import java.nio.ByteOrder;
15+
import java.nio.ShortBuffer;
16+
17+
public class BFloat16 {
18+
19+
public static final int BYTES = Short.BYTES;
20+
21+
public static short floatToBFloat16(float f) {
22+
// this rounds towards 0
23+
// zero - zero exp, zero fraction
24+
// denormal - zero exp, non-zero fraction
25+
// infinity - all-1 exp, zero fraction
26+
// NaN - all-1 exp, non-zero fraction
27+
// the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into
28+
// infinities
29+
return (short) (Float.floatToIntBits(f) >>> 16);
30+
}
31+
32+
public static float bFloat16ToFloat(short bf) {
33+
return Float.intBitsToFloat(bf << 16);
34+
}
35+
36+
public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) {
37+
assert bFloats.remaining() == floats.length;
38+
assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
39+
for (float v : floats) {
40+
bFloats.put(floatToBFloat16(v));
41+
}
42+
}
43+
44+
public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) {
45+
assert floats.length * 2 == bfBytes.length;
46+
for (int i = 0; i < floats.length; i++) {
47+
floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * 2));
48+
}
49+
}
50+
51+
public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) {
52+
assert floats.length == bFloats.remaining();
53+
assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
54+
for (int i = 0; i < floats.length; i++) {
55+
floats[i] = bFloat16ToFloat(bFloats.get());
56+
}
57+
}
58+
59+
private BFloat16() {}
60+
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,81 @@
1111

1212
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
1313
import org.apache.lucene.index.SegmentReadState;
14+
import org.apache.lucene.store.FlushInfo;
15+
import org.apache.lucene.store.IOContext;
16+
import org.apache.lucene.store.MergeInfo;
17+
import org.elasticsearch.common.util.set.Sets;
18+
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
19+
import org.elasticsearch.index.store.FsDirectoryFactory;
1420

1521
import java.io.IOException;
22+
import java.util.Set;
1623

1724
public abstract class DirectIOCapableFlatVectorsFormat extends AbstractFlatVectorsFormat {
1825
protected DirectIOCapableFlatVectorsFormat(String name) {
1926
super(name);
2027
}
2128

29+
protected abstract FlatVectorsReader createReader(SegmentReadState state) throws IOException;
30+
31+
static boolean canUseDirectIO(SegmentReadState state) {
32+
return FsDirectoryFactory.isHybridFs(state.directory);
33+
}
34+
2235
@Override
2336
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
2437
return fieldsReader(state, false);
2538
}
2639

27-
public abstract FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException;
40+
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
41+
if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) {
42+
// only override the context for the random-access use case
43+
SegmentReadState directIOState = new SegmentReadState(
44+
state.directory,
45+
state.segmentInfo,
46+
state.fieldInfos,
47+
new DirectIOContext(state.context.hints()),
48+
state.segmentSuffix
49+
);
50+
// Use mmap for merges and direct I/O for searches.
51+
return new MergeReaderWrapper(createReader(directIOState), createReader(state));
52+
} else {
53+
return createReader(state);
54+
}
55+
}
56+
57+
static class DirectIOContext implements IOContext {
58+
59+
final Set<FileOpenHint> hints;
60+
61+
DirectIOContext(Set<FileOpenHint> hints) {
62+
// always add DirectIOHint to the hints given
63+
this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE));
64+
}
65+
66+
@Override
67+
public Context context() {
68+
return Context.DEFAULT;
69+
}
70+
71+
@Override
72+
public MergeInfo mergeInfo() {
73+
return null;
74+
}
75+
76+
@Override
77+
public FlushInfo flushInfo() {
78+
return null;
79+
}
80+
81+
@Override
82+
public Set<FileOpenHint> hints() {
83+
return hints;
84+
}
85+
86+
@Override
87+
public IOContext withHints(FileOpenHint... hints) {
88+
return new DirectIOContext(Set.of(hints));
89+
}
90+
}
2891
}

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
1919
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
2020
import org.elasticsearch.index.codec.vectors.es93.DirectIOCapableLucene99FlatVectorsFormat;
21+
import org.elasticsearch.index.codec.vectors.es93.ES93BFloat16FlatVectorsFormat;
2122

2223
import java.io.IOException;
2324
import java.util.Map;
@@ -58,12 +59,17 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
5859
public static final int VERSION_DIRECT_IO = 1;
5960
public static final int VERSION_CURRENT = VERSION_DIRECT_IO;
6061

61-
private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
62+
private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
63+
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
64+
);
65+
private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat(
6266
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
6367
);
6468
private static final Map<String, DirectIOCapableFlatVectorsFormat> supportedFormats = Map.of(
65-
rawVectorFormat.getName(),
66-
rawVectorFormat
69+
float32VectorFormat.getName(),
70+
float32VectorFormat,
71+
bfloat16VectorFormat.getName(),
72+
bfloat16VectorFormat
6773
);
6874

6975
// This dynamically sets the cluster probe based on the `k` requested and the number of clusters.
@@ -79,12 +85,13 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
7985
private final int vectorPerCluster;
8086
private final int centroidsPerParentCluster;
8187
private final boolean useDirectIO;
88+
private final DirectIOCapableFlatVectorsFormat rawVectorFormat;
8289

8390
public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) {
84-
this(vectorPerCluster, centroidsPerParentCluster, false);
91+
this(vectorPerCluster, centroidsPerParentCluster, false, false);
8592
}
8693

87-
public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO) {
94+
public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO, boolean useBFloat16) {
8895
super(NAME);
8996
if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) {
9097
throw new IllegalArgumentException(
@@ -109,6 +116,7 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu
109116
this.vectorPerCluster = vectorPerCluster;
110117
this.centroidsPerParentCluster = centroidsPerParentCluster;
111118
this.useDirectIO = useDirectIO;
119+
this.rawVectorFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat;
112120
}
113121

114122
/** Constructs a format using the given graph construction parameters and scalar quantization. */

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java

Lines changed: 4 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,9 @@
1515
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
1616
import org.apache.lucene.index.SegmentReadState;
1717
import org.apache.lucene.index.SegmentWriteState;
18-
import org.apache.lucene.store.FlushInfo;
19-
import org.apache.lucene.store.IOContext;
20-
import org.apache.lucene.store.MergeInfo;
21-
import org.elasticsearch.common.util.set.Sets;
2218
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
23-
import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;
24-
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
25-
import org.elasticsearch.index.store.FsDirectoryFactory;
2619

2720
import java.io.IOException;
28-
import java.util.Set;
2921

3022
public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {
3123

@@ -45,72 +37,12 @@ protected FlatVectorsScorer flatVectorsScorer() {
4537
}
4638

4739
@Override
48-
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
49-
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
50-
}
51-
52-
static boolean canUseDirectIO(SegmentReadState state) {
53-
return FsDirectoryFactory.isHybridFs(state.directory);
40+
protected FlatVectorsReader createReader(SegmentReadState state) throws IOException {
41+
return new Lucene99FlatVectorsReader(state, vectorsScorer);
5442
}
5543

5644
@Override
57-
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
58-
return fieldsReader(state, false);
59-
}
60-
61-
@Override
62-
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
63-
if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) {
64-
// only override the context for the random-access use case
65-
SegmentReadState directIOState = new SegmentReadState(
66-
state.directory,
67-
state.segmentInfo,
68-
state.fieldInfos,
69-
new DirectIOContext(state.context.hints()),
70-
state.segmentSuffix
71-
);
72-
// Use mmap for merges and direct I/O for searches.
73-
return new MergeReaderWrapper(
74-
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
75-
new Lucene99FlatVectorsReader(state, vectorsScorer)
76-
);
77-
} else {
78-
return new Lucene99FlatVectorsReader(state, vectorsScorer);
79-
}
80-
}
81-
82-
static class DirectIOContext implements IOContext {
83-
84-
final Set<FileOpenHint> hints;
85-
86-
DirectIOContext(Set<FileOpenHint> hints) {
87-
// always add DirectIOHint to the hints given
88-
this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE));
89-
}
90-
91-
@Override
92-
public Context context() {
93-
return Context.DEFAULT;
94-
}
95-
96-
@Override
97-
public MergeInfo mergeInfo() {
98-
return null;
99-
}
100-
101-
@Override
102-
public FlushInfo flushInfo() {
103-
return null;
104-
}
105-
106-
@Override
107-
public Set<FileOpenHint> hints() {
108-
return hints;
109-
}
110-
111-
@Override
112-
public IOContext withHints(FileOpenHint... hints) {
113-
return new DirectIOContext(Set.of(hints));
114-
}
45+
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
46+
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
11547
}
11648
}

0 commit comments

Comments
 (0)