Skip to content

Commit 7eaaff8

Browse files
authored
Add bfloat16 to bbq_disk and bbq_hnsw (#136179)
Don't expose them for now, they will be enabled later on
1 parent e717741 commit 7eaaff8

19 files changed

+1570
-147
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.util.BitUtil;
13+
14+
import java.nio.ByteOrder;
15+
import java.nio.ShortBuffer;
16+
17+
public final class BFloat16 {
18+
19+
public static final int BYTES = Short.BYTES;
20+
21+
public static short floatToBFloat16(float f) {
22+
// this rounds towards 0
23+
// zero - zero exp, zero fraction
24+
// denormal - zero exp, non-zero fraction
25+
// infinity - all-1 exp, zero fraction
26+
// NaN - all-1 exp, non-zero fraction
27+
// the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into
28+
// infinities
29+
return (short) (Float.floatToIntBits(f) >>> 16);
30+
}
31+
32+
public static float bFloat16ToFloat(short bf) {
33+
return Float.intBitsToFloat(bf << 16);
34+
}
35+
36+
public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) {
37+
assert bFloats.remaining() == floats.length;
38+
assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
39+
for (float v : floats) {
40+
bFloats.put(floatToBFloat16(v));
41+
}
42+
}
43+
44+
public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) {
45+
assert floats.length * 2 == bfBytes.length;
46+
for (int i = 0; i < floats.length; i++) {
47+
floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * 2));
48+
}
49+
}
50+
51+
public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) {
52+
assert floats.length == bFloats.remaining();
53+
assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
54+
for (int i = 0; i < floats.length; i++) {
55+
floats[i] = bFloat16ToFloat(bFloats.get());
56+
}
57+
}
58+
59+
private BFloat16() {}
60+
}

server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626

2727
/** Utility class for vector quantization calculations */
2828
public class BQVectorUtils {
29-
private static final float EPSILON = 1e-4f;
29+
// NOTE: this tolerance is deliberately looser than the previous 1e-4f to accommodate the reduced precision of bfloat16
30+
private static final float EPSILON = 1e-2f;
3031

3132
public static double sqrtNewtonRaphson(double x, double curr, double prev) {
3233
return (curr == prev) ? curr : sqrtNewtonRaphson(x, 0.5 * (curr + x / curr), curr);

server/src/main/java/org/elasticsearch/index/codec/vectors/DirectIOCapableFlatVectorsFormat.java

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,81 @@
1111

1212
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
1313
import org.apache.lucene.index.SegmentReadState;
14+
import org.apache.lucene.store.FlushInfo;
15+
import org.apache.lucene.store.IOContext;
16+
import org.apache.lucene.store.MergeInfo;
17+
import org.elasticsearch.common.util.set.Sets;
18+
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
19+
import org.elasticsearch.index.store.FsDirectoryFactory;
1420

1521
import java.io.IOException;
22+
import java.util.Set;
1623

1724
public abstract class DirectIOCapableFlatVectorsFormat extends AbstractFlatVectorsFormat {
1825
protected DirectIOCapableFlatVectorsFormat(String name) {
1926
super(name);
2027
}
2128

29+
protected abstract FlatVectorsReader createReader(SegmentReadState state) throws IOException;
30+
31+
protected static boolean canUseDirectIO(SegmentReadState state) {
32+
return FsDirectoryFactory.isHybridFs(state.directory);
33+
}
34+
2235
@Override
2336
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
2437
return fieldsReader(state, false);
2538
}
2639

27-
public abstract FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException;
40+
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
41+
if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) {
42+
// only override the context for the random-access use case
43+
SegmentReadState directIOState = new SegmentReadState(
44+
state.directory,
45+
state.segmentInfo,
46+
state.fieldInfos,
47+
new DirectIOContext(state.context.hints()),
48+
state.segmentSuffix
49+
);
50+
// Use mmap for merges and direct I/O for searches.
51+
return new MergeReaderWrapper(createReader(directIOState), createReader(state));
52+
} else {
53+
return createReader(state);
54+
}
55+
}
56+
57+
protected static class DirectIOContext implements IOContext {
58+
59+
final Set<FileOpenHint> hints;
60+
61+
public DirectIOContext(Set<FileOpenHint> hints) {
62+
// always add DirectIOHint to the hints given
63+
this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE));
64+
}
65+
66+
@Override
67+
public Context context() {
68+
return Context.DEFAULT;
69+
}
70+
71+
@Override
72+
public MergeInfo mergeInfo() {
73+
return null;
74+
}
75+
76+
@Override
77+
public FlushInfo flushInfo() {
78+
return null;
79+
}
80+
81+
@Override
82+
public Set<FileOpenHint> hints() {
83+
return hints;
84+
}
85+
86+
@Override
87+
public IOContext withHints(FileOpenHint... hints) {
88+
return new DirectIOContext(Set.of(hints));
89+
}
90+
}
2891
}

server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ public QuantizationResult[] multiScalarQuantize(
109109
}
110110

111111
public QuantizationResult scalarQuantize(float[] vector, float[] residualDestination, int[] destination, byte bits, float[] centroid) {
112-
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
112+
assert similarityFunction != COSINE || BQVectorUtils.isUnitVector(vector);
113113
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
114114
assert vector.length <= destination.length;
115115
assert bits > 0 && bits <= 8;

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsFormat.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
5858
public static final int VERSION_DIRECT_IO = 1;
5959
public static final int VERSION_CURRENT = VERSION_DIRECT_IO;
6060

61-
private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
61+
private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
6262
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
6363
);
6464
private static final Map<String, DirectIOCapableFlatVectorsFormat> supportedFormats = Map.of(
65-
rawVectorFormat.getName(),
66-
rawVectorFormat
65+
float32VectorFormat.getName(),
66+
float32VectorFormat
6767
);
6868

6969
// This dynamically sets the cluster probe based on the `k` requested and the number of clusters.
@@ -79,6 +79,7 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
7979
private final int vectorPerCluster;
8080
private final int centroidsPerParentCluster;
8181
private final boolean useDirectIO;
82+
private final DirectIOCapableFlatVectorsFormat rawVectorFormat;
8283

8384
public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) {
8485
this(vectorPerCluster, centroidsPerParentCluster, false);
@@ -109,6 +110,7 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu
109110
this.vectorPerCluster = vectorPerCluster;
110111
this.centroidsPerParentCluster = centroidsPerParentCluster;
111112
this.useDirectIO = useDirectIO;
113+
this.rawVectorFormat = float32VectorFormat;
112114
}
113115

114116
/** Constructs a format using the given graph construction parameters and scalar quantization. */

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
2828
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
2929
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
30+
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
3031
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
3132
import org.elasticsearch.simdvec.ESVectorUtil;
3233

@@ -70,7 +71,7 @@ public RandomVectorScorer getRandomVectorScorer(
7071
assert binarizedVectors.size() > 0 : "BinarizedByteVectorValues must have at least one vector for ES816BinaryFlatVectorsScorer";
7172
OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer();
7273
float[] centroid = binarizedVectors.getCentroid();
73-
assert similarityFunction != COSINE || VectorUtil.isUnitVector(target);
74+
assert similarityFunction != COSINE || BQVectorUtils.isUnitVector(target);
7475
float[] scratch = new float[vectorValues.dimension()];
7576
int[] initial = new int[target.length];
7677
byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8];

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java

Lines changed: 4 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,17 @@
2525
import org.apache.lucene.search.DocAndFloatFeatureBuffer;
2626
import org.apache.lucene.search.DocIdSetIterator;
2727
import org.apache.lucene.search.VectorScorer;
28-
import org.apache.lucene.store.FlushInfo;
2928
import org.apache.lucene.store.IOContext;
3029
import org.apache.lucene.store.IndexInput;
31-
import org.apache.lucene.store.MergeInfo;
3230
import org.apache.lucene.util.Bits;
3331
import org.apache.lucene.util.hnsw.RandomVectorScorer;
34-
import org.elasticsearch.common.util.set.Sets;
3532
import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues;
3633
import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
3734
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
3835
import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;
39-
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
40-
import org.elasticsearch.index.store.FsDirectoryFactory;
4136

4237
import java.io.IOException;
4338
import java.util.List;
44-
import java.util.Set;
4539

4640
public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {
4741

@@ -61,17 +55,13 @@ protected FlatVectorsScorer flatVectorsScorer() {
6155
}
6256

6357
@Override
64-
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
65-
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
66-
}
67-
68-
static boolean canUseDirectIO(SegmentReadState state) {
69-
return FsDirectoryFactory.isHybridFs(state.directory);
58+
protected FlatVectorsReader createReader(SegmentReadState state) throws IOException {
59+
return new Lucene99FlatVectorsReader(state, vectorsScorer);
7060
}
7161

7262
@Override
73-
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
74-
return fieldsReader(state, false);
63+
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
64+
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
7565
}
7666

7767
@Override
@@ -99,41 +89,6 @@ public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectI
9989
}
10090
}
10191

102-
static class DirectIOContext implements IOContext {
103-
104-
final Set<FileOpenHint> hints;
105-
106-
DirectIOContext(Set<FileOpenHint> hints) {
107-
// always add DirectIOHint to the hints given
108-
this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE));
109-
}
110-
111-
@Override
112-
public Context context() {
113-
return Context.DEFAULT;
114-
}
115-
116-
@Override
117-
public MergeInfo mergeInfo() {
118-
return null;
119-
}
120-
121-
@Override
122-
public FlushInfo flushInfo() {
123-
return null;
124-
}
125-
126-
@Override
127-
public Set<FileOpenHint> hints() {
128-
return hints;
129-
}
130-
131-
@Override
132-
public IOContext withHints(FileOpenHint... hints) {
133-
return new DirectIOContext(Set.of(hints));
134-
}
135-
}
136-
13792
static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader {
13893
private final Lucene99FlatVectorsReader inner;
13994
private final SegmentReadState state;
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* @notice
3+
* Licensed to the Apache Software Foundation (ASF) under one or more
4+
* contributor license agreements. See the NOTICE file distributed with
5+
* this work for additional information regarding copyright ownership.
6+
* The ASF licenses this file to You under the Apache License, Version 2.0
7+
* (the "License"); you may not use this file except in compliance with
8+
* the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*
18+
* Modifications copyright (C) 2025 Elasticsearch B.V.
19+
*/
20+
package org.elasticsearch.index.codec.vectors.es93;
21+
22+
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
23+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
24+
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
25+
import org.apache.lucene.index.SegmentReadState;
26+
import org.apache.lucene.index.SegmentWriteState;
27+
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
28+
29+
import java.io.IOException;
30+
31+
public final class ES93BFloat16FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {
32+
33+
static final String NAME = "ES93BFloat16FlatVectorsFormat";
34+
static final String META_CODEC_NAME = "ES93BFloat16FlatVectorsFormatMeta";
35+
static final String VECTOR_DATA_CODEC_NAME = "ES93BFloat16FlatVectorsFormatData";
36+
static final String META_EXTENSION = "vemf";
37+
static final String VECTOR_DATA_EXTENSION = "vec";
38+
39+
public static final int VERSION_START = 0;
40+
public static final int VERSION_CURRENT = VERSION_START;
41+
42+
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
43+
private final FlatVectorsScorer vectorsScorer;
44+
45+
public ES93BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
46+
super(NAME);
47+
this.vectorsScorer = vectorsScorer;
48+
}
49+
50+
@Override
51+
protected FlatVectorsScorer flatVectorsScorer() {
52+
return vectorsScorer;
53+
}
54+
55+
@Override
56+
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
57+
return new ES93BFloat16FlatVectorsWriter(state, vectorsScorer);
58+
}
59+
60+
@Override
61+
protected FlatVectorsReader createReader(SegmentReadState state) throws IOException {
62+
return new ES93BFloat16FlatVectorsReader(state, vectorsScorer);
63+
}
64+
}

0 commit comments

Comments
 (0)