Skip to content

Commit 88b682c

Browse files
committed
Add BFloat16 raw vector format to bbq_hnsw and bbq_disk
1 parent 89c58cf commit 88b682c

23 files changed

+1808
-231
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ record CmdLineArgs(
5151
float filterSelectivity,
5252
long seed,
5353
VectorSimilarityFunction vectorSpace,
54+
int rawVectorSize,
5455
int quantizeBits,
5556
VectorEncoding vectorEncoding,
5657
int dimensions,
@@ -80,6 +81,7 @@ record CmdLineArgs(
8081
static final ParseField FORCE_MERGE_FIELD = new ParseField("force_merge");
8182
static final ParseField VECTOR_SPACE_FIELD = new ParseField("vector_space");
8283
static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits");
84+
static final ParseField RAW_VECTOR_SIZE_FIELD = new ParseField("raw_vector_size");
8385
static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding");
8486
static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
8587
static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
@@ -123,6 +125,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
123125
PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD);
124126
PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD);
125127
PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD);
128+
PARSER.declareInt(Builder::setRawVectorSize, RAW_VECTOR_SIZE_FIELD);
126129
PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD);
127130
PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD);
128131
PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD);
@@ -161,6 +164,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
161164
builder.field(REINDEX_FIELD.getPreferredName(), reindex);
162165
builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge);
163166
builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT));
167+
builder.field(RAW_VECTOR_SIZE_FIELD.getPreferredName(), rawVectorSize);
164168
builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits);
165169
builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT));
166170
builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions);
@@ -196,6 +200,7 @@ static class Builder {
196200
private boolean reindex = false;
197201
private boolean forceMerge = false;
198202
private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN;
203+
private int rawVectorSize = 32;
199204
private int quantizeBits = 8;
200205
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
201206
private int dimensions;
@@ -305,6 +310,11 @@ public Builder setVectorSpace(String vectorSpace) {
305310
return this;
306311
}
307312

313+
/**
 * Records the raw vector size on this builder; {@code build()} passes the stored value on as the
 * {@code rawVectorSize} component of {@code CmdLineArgs}. The parser binds this setter to the
 * {@code raw_vector_size} field, and the builder's default is 32 when the setter is never called.
 * NOTE(review): presumably the number of bits per raw vector component (32 = float32, 16 = bfloat16) -
 * confirm against the code that consumes it. No range check is performed here.
 *
 * @param rawVectorSize the raw vector size to store, taken as-is
 * @return this builder, for call chaining
 */
public Builder setRawVectorSize(int rawVectorSize) {
314+
this.rawVectorSize = rawVectorSize;
315+
return this;
316+
}
317+
308318
public Builder setQuantizeBits(int quantizeBits) {
309319
this.quantizeBits = quantizeBits;
310320
return this;
@@ -380,6 +390,7 @@ public CmdLineArgs build() {
380390
filterSelectivity,
381391
seed,
382392
vectorSpace,
393+
rawVectorSize,
383394
quantizeBits,
384395
vectorEncoding,
385396
dimensions,

rest-api-spec/src/main/resources/rest-api-spec/api/cat.segments.json

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -93,29 +93,6 @@
9393
"micros",
9494
"nanos"
9595
]
96-
},
97-
"ignore_unavailable": {
98-
"type": "boolean",
99-
"description": "Whether specified concrete indices should be ignored when unavailable (missing or closed). Only allowed when providing an index expression."
100-
},
101-
"ignore_throttled": {
102-
"type": "boolean",
103-
"description": "Whether specified concrete, expanded or aliased indices should be ignored when throttled. Only allowed when providing an index expression."
104-
},
105-
"allow_no_indices": {
106-
"type": "boolean",
107-
"description": "Whether to ignore if a wildcard indices expression resolves into no concrete indices. (This includes `_all` string or when no indices have been specified). Only allowed when providing an index expression."
108-
},
109-
"expand_wildcards": {
110-
"type": "enum",
111-
"options": ["open", "closed", "hidden", "none", "all"],
112-
"default": "open",
113-
"description": "Whether to expand wildcard expression to concrete indices that are open, closed or both."
114-
},
115-
"allow_closed": {
116-
"type": "boolean",
117-
"description": "If true, allow closed indices to be returned in the response otherwise if false, keep the legacy behaviour of throwing an exception if index pattern matches closed indices",
118-
"default": false
11996
}
12097
}
12198
}

rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cat.segments/10_basic.yml

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -209,77 +209,3 @@ tsdb:
209209
$body: |
210210
/^(tsdb \s+ 0 \s+ p \s+ \d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3} \s+ _\d (\s\d){3} \s+
211211
(\d+|\d+[.]\d+)(kb|b) \s+ \d+ (\s+ (false|true)){2} \s+ \d+\.\d+(\.\d+)? \s+ (false|true) \s? \n?)$/
212-
213-
---
214-
Wildcard Expansion Settings:
215-
- requires:
216-
capabilities:
217-
- method: GET
218-
path: /_cat/segments
219-
capabilities: [ allow_closed ]
220-
test_runner_features: [ capabilities ]
221-
reason: Capability required to run test
222-
223-
- do:
224-
indices.create:
225-
index: basic-index
226-
body:
227-
settings:
228-
number_of_shards: 1
229-
number_of_replicas: 0
230-
231-
- do:
232-
index:
233-
index: basic-index
234-
id: "1"
235-
body:
236-
field: "basic doc 1"
237-
238-
- do:
239-
indices.create:
240-
index: hidden-index
241-
body:
242-
settings:
243-
number_of_shards: 1
244-
number_of_replicas: 0
245-
index.hidden: true
246-
247-
- do:
248-
index:
249-
index: hidden-index
250-
id: "1"
251-
body:
252-
field: "hidden doc 1"
253-
254-
- do:
255-
indices.create:
256-
index: closed-index
257-
body:
258-
settings:
259-
number_of_shards: 1
260-
number_of_replicas: 0
261-
262-
- do:
263-
index:
264-
index: closed-index
265-
id: "1"
266-
body:
267-
field: "closed doc 1"
268-
269-
- do:
270-
indices.refresh:
271-
index: [ basic-index, closed-index, hidden-index ]
272-
273-
- do:
274-
indices.close:
275-
index: closed-index
276-
277-
- do:
278-
cat.segments:
279-
v: true
280-
s: index
281-
expand_wildcards: all
282-
allow_closed: true
283-
- match:
284-
$body: |
285-
/basic-index(\s)+0(\s)+p.*\nclosed-index(\s)+0(\s)+p.*\nhidden-index(\s)+0(\s)+p.*/

server/src/main/java/org/elasticsearch/cluster/coordination/CoordinationState.java

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ public Join handleStartJoin(StartJoinRequest startJoinRequest) {
211211
}
212212

213213
/**
214-
* May be called on receipt of a {@link Join}, which is effectively a vote for the receiving node to be the elected master.
214+
* May be called on receipt of a Join.
215215
*
216-
* @param join The {@link Join} received.
216+
* @param join The Join received.
217217
* @return true iff this instance does not already have a join vote from the given source node for this term
218218
* @throws CoordinationStateRejectedException if the arguments were incompatible with the current state of this object.
219219
*/
@@ -234,9 +234,6 @@ public boolean handleJoin(Join join) {
234234

235235
final long lastAcceptedTerm = getLastAcceptedTerm();
236236
if (join.lastAcceptedTerm() > lastAcceptedTerm) {
237-
// Note that this is running on the receiving node, so it must reject joins from nodes with fresher state. This is unlike a
238-
// real-world election where candidates will accept every vote they receive and it's the voter's responsibility to be selective
239-
// about the votes they cast.
240237
logger.debug(
241238
"handleJoin: ignored join as joiner has a better last accepted term (expected: <=[{}], actual: [{}])",
242239
lastAcceptedTerm,
@@ -251,9 +248,6 @@ public boolean handleJoin(Join join) {
251248
}
252249

253250
if (join.lastAcceptedTerm() == lastAcceptedTerm && join.lastAcceptedVersion() > getLastAcceptedVersion()) {
254-
// Note that this is running on the receiving node, so it must reject joins from nodes with fresher state. This is unlike a
255-
// real-world election where candidates will accept every vote they receive and it's the voter's responsibility to be selective
256-
// about the votes they cast.
257251
logger.debug(
258252
"handleJoin: ignored join as joiner has a better last accepted version (expected: <=[{}], actual: [{}]) in term {}",
259253
getLastAcceptedVersion(),
server/src/main/java/org/elasticsearch/index/codec/vectors/BFloat16.java

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.util.BitUtil;
13+
14+
import java.nio.ByteOrder;
15+
import java.nio.ShortBuffer;
16+
17+
public class BFloat16 {
18+
19+
public static final int BYTES = Short.BYTES;
20+
21+
public static short floatToBFloat16(float f) {
22+
// this rounds towards 0
23+
// zero - zero exp, zero fraction
24+
// denormal - zero exp, non-zero fraction
25+
// infinity - all-1 exp, zero fraction
26+
// NaN - all-1 exp, non-zero fraction
27+
// the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into
28+
// infinities
29+
return (short) (Float.floatToIntBits(f) >>> 16);
30+
}
31+
32+
public static float bFloat16ToFloat(short bf) {
33+
return Float.intBitsToFloat(bf << 16);
34+
}
35+
36+
public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) {
37+
assert bFloats.remaining() == floats.length;
38+
assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
39+
for (float v : floats) {
40+
bFloats.put(floatToBFloat16(v));
41+
}
42+
}
43+
44+
public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) {
45+
assert floats.length * 2 == bfBytes.length;
46+
for (int i = 0; i < floats.length; i++) {
47+
floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * 2));
48+
}
49+
}
50+
51+
public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) {
52+
assert floats.length == bFloats.remaining();
53+
assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
54+
for (int i = 0; i < floats.length; i++) {
55+
floats[i] = bFloat16ToFloat(bFloats.get());
56+
}
57+
}
58+
59+
private BFloat16() {}
60+
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,81 @@
1111

1212
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
1313
import org.apache.lucene.index.SegmentReadState;
14+
import org.apache.lucene.store.FlushInfo;
15+
import org.apache.lucene.store.IOContext;
16+
import org.apache.lucene.store.MergeInfo;
17+
import org.elasticsearch.common.util.set.Sets;
18+
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
19+
import org.elasticsearch.index.store.FsDirectoryFactory;
1420

1521
import java.io.IOException;
22+
import java.util.Set;
1623

1724
public abstract class DirectIOCapableFlatVectorsFormat extends AbstractFlatVectorsFormat {
1825
protected DirectIOCapableFlatVectorsFormat(String name) {
1926
super(name);
2027
}
2128

29+
protected abstract FlatVectorsReader createReader(SegmentReadState state) throws IOException;
30+
31+
static boolean canUseDirectIO(SegmentReadState state) {
32+
return FsDirectoryFactory.isHybridFs(state.directory);
33+
}
34+
2235
@Override
2336
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
2437
return fieldsReader(state, false);
2538
}
2639

27-
public abstract FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException;
40+
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
41+
if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) {
42+
// only override the context for the random-access use case
43+
SegmentReadState directIOState = new SegmentReadState(
44+
state.directory,
45+
state.segmentInfo,
46+
state.fieldInfos,
47+
new DirectIOContext(state.context.hints()),
48+
state.segmentSuffix
49+
);
50+
// Use mmap for merges and direct I/O for searches.
51+
return new MergeReaderWrapper(createReader(directIOState), createReader(state));
52+
} else {
53+
return createReader(state);
54+
}
55+
}
56+
57+
static class DirectIOContext implements IOContext {
58+
59+
final Set<FileOpenHint> hints;
60+
61+
DirectIOContext(Set<FileOpenHint> hints) {
62+
// always add DirectIOHint to the hints given
63+
this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE));
64+
}
65+
66+
@Override
67+
public Context context() {
68+
return Context.DEFAULT;
69+
}
70+
71+
@Override
72+
public MergeInfo mergeInfo() {
73+
return null;
74+
}
75+
76+
@Override
77+
public FlushInfo flushInfo() {
78+
return null;
79+
}
80+
81+
@Override
82+
public Set<FileOpenHint> hints() {
83+
return hints;
84+
}
85+
86+
@Override
87+
public IOContext withHints(FileOpenHint... hints) {
88+
return new DirectIOContext(Set.of(hints));
89+
}
90+
}
2891
}

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
1919
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
2020
import org.elasticsearch.index.codec.vectors.es93.DirectIOCapableLucene99FlatVectorsFormat;
21+
import org.elasticsearch.index.codec.vectors.es93.ES93BFloat16FlatVectorsFormat;
2122

2223
import java.io.IOException;
2324
import java.util.Map;
@@ -58,12 +59,17 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
5859
public static final int VERSION_DIRECT_IO = 1;
5960
public static final int VERSION_CURRENT = VERSION_DIRECT_IO;
6061

61-
private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
62+
private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
63+
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
64+
);
65+
private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat(
6266
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
6367
);
6468
private static final Map<String, DirectIOCapableFlatVectorsFormat> supportedFormats = Map.of(
65-
rawVectorFormat.getName(),
66-
rawVectorFormat
69+
float32VectorFormat.getName(),
70+
float32VectorFormat,
71+
bfloat16VectorFormat.getName(),
72+
bfloat16VectorFormat
6773
);
6874

6975
// This dynamically sets the cluster probe based on the `k` requested and the number of clusters.
@@ -79,12 +85,13 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
7985
private final int vectorPerCluster;
8086
private final int centroidsPerParentCluster;
8187
private final boolean useDirectIO;
88+
private final DirectIOCapableFlatVectorsFormat rawVectorFormat;
8289

8390
public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) {
84-
this(vectorPerCluster, centroidsPerParentCluster, false);
91+
this(vectorPerCluster, centroidsPerParentCluster, false, false);
8592
}
8693

87-
public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO) {
94+
public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO, boolean useBFloat16) {
8895
super(NAME);
8996
if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) {
9097
throw new IllegalArgumentException(
@@ -109,6 +116,7 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu
109116
this.vectorPerCluster = vectorPerCluster;
110117
this.centroidsPerParentCluster = centroidsPerParentCluster;
111118
this.useDirectIO = useDirectIO;
119+
this.rawVectorFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat;
112120
}
113121

114122
/** Constructs a format using the given graph construction parameters and scalar quantization. */

0 commit comments

Comments
 (0)