Skip to content

Commit f105dc8

Browse files
committed
Add ES92 bfloat16 vector format
1 parent d1090a0 commit f105dc8

24 files changed

+3388
-32
lines changed

server/src/main/java/module-info.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,8 @@
459459
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat,
460460
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat,
461461
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
462+
org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat,
463+
org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat,
462464
org.elasticsearch.index.codec.vectors.IVFVectorsFormat;
463465

464466
provides org.apache.lucene.codecs.Codec

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
/**
3131
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
3232
*/
33-
abstract class BinarizedByteVectorValues extends ByteVectorValues {
33+
public abstract class BinarizedByteVectorValues extends ByteVectorValues {
3434

3535
/**
3636
* Retrieve the corrective terms for the given vector ordinal. For the dot-product family of

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ public RandomVectorScorer getRandomVectorScorer(
108108
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
109109
}
110110

111-
RandomVectorScorerSupplier getRandomVectorScorerSupplier(
111+
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
112112
VectorSimilarityFunction similarityFunction,
113113
ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors,
114114
BinarizedByteVectorValues targetVectors
@@ -122,7 +122,7 @@ public String toString() {
122122
}
123123

124124
/** Vector scorer supplier over binarized vector values */
125-
static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
125+
public static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
126126
private final ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors;
127127
private final BinarizedByteVectorValues targetVectors;
128128
private final VectorSimilarityFunction similarityFunction;

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,11 +388,11 @@ static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, Vector
388388
}
389389

390390
/** Binarized vector values holding row and quantized vector values */
391-
protected static final class BinarizedVectorValues extends FloatVectorValues {
391+
public static final class BinarizedVectorValues extends FloatVectorValues {
392392
private final FloatVectorValues rawVectorValues;
393393
private final BinarizedByteVectorValues quantizedVectorValues;
394394

395-
BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) {
395+
public BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) {
396396
this.rawVectorValues = rawVectorValues;
397397
this.quantizedVectorValues = quantizedVectorValues;
398398
}
@@ -437,7 +437,7 @@ public VectorScorer scorer(float[] query) throws IOException {
437437
return quantizedVectorValues.scorer(query);
438438
}
439439

440-
BinarizedByteVectorValues getQuantizedVectorValues() throws IOException {
440+
public BinarizedByteVectorValues getQuantizedVectorValues() throws IOException {
441441
return quantizedVectorValues;
442442
}
443443
}

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ public long ramBytesUsed() {
723723
}
724724

725725
// When accessing vectorValue method, targerOrd here means a row ordinal.
726-
static class OffHeapBinarizedQueryVectorValues {
726+
public static class OffHeapBinarizedQueryVectorValues {
727727
private final IndexInput slice;
728728
private final int dimension;
729729
private final int size;
@@ -734,7 +734,7 @@ static class OffHeapBinarizedQueryVectorValues {
734734
private int lastOrd = -1;
735735
private int quantizedComponentSum;
736736

737-
OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) {
737+
public OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) {
738738
this.slice = data;
739739
this.dimension = dimension;
740740
this.size = size;
@@ -798,7 +798,7 @@ public byte[] vectorValue(int targetOrd) throws IOException {
798798
}
799799
}
800800

801-
static class BinarizedFloatVectorValues extends BinarizedByteVectorValues {
801+
public static class BinarizedFloatVectorValues extends BinarizedByteVectorValues {
802802
private OptimizedScalarQuantizer.QuantizationResult corrections;
803803
private final byte[] binarized;
804804
private final int[] initQuantized;
@@ -808,7 +808,7 @@ static class BinarizedFloatVectorValues extends BinarizedByteVectorValues {
808808

809809
private int lastOrd = -1;
810810

811-
BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) {
811+
public BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) {
812812
this.values = delegate;
813813
this.quantizer = quantizer;
814814
this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8];
@@ -881,12 +881,16 @@ public int ordToDoc(int ord) {
881881
}
882882
}
883883

884-
static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier {
884+
public static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier {
885885
private final RandomVectorScorerSupplier supplier;
886886
private final KnnVectorValues vectorValues;
887887
private final Closeable onClose;
888888

889-
BinarizedCloseableRandomVectorScorerSupplier(RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) {
889+
public BinarizedCloseableRandomVectorScorerSupplier(
890+
RandomVectorScorerSupplier supplier,
891+
KnnVectorValues vectorValues,
892+
Closeable onClose
893+
) {
890894
this.supplier = supplier;
891895
this.onClose = onClose;
892896
this.vectorValues = vectorValues;
@@ -913,11 +917,11 @@ public int totalVectorCount() {
913917
}
914918
}
915919

916-
static final class NormalizedFloatVectorValues extends FloatVectorValues {
920+
public static final class NormalizedFloatVectorValues extends FloatVectorValues {
917921
private final FloatVectorValues values;
918922
private final float[] normalizedVector;
919923

920-
NormalizedFloatVectorValues(FloatVectorValues values) {
924+
public NormalizedFloatVectorValues(FloatVectorValues values) {
921925
this.values = values;
922926
this.normalizedVector = new float[values.dimension()];
923927
}

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
import java.util.Collection;
2424
import java.util.Map;
2525

26-
class MergeReaderWrapper extends FlatVectorsReader {
26+
public class MergeReaderWrapper extends FlatVectorsReader {
2727

2828
private final FlatVectorsReader mainReader;
2929
private final FlatVectorsReader mergeReader;
3030

31-
protected MergeReaderWrapper(FlatVectorsReader mainReader, FlatVectorsReader mergeReader) {
31+
public MergeReaderWrapper(FlatVectorsReader mainReader, FlatVectorsReader mergeReader) {
3232
super(mainReader.getFlatVectorScorer());
3333
this.mainReader = mainReader;
3434
this.mergeReader = mergeReader;

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import java.nio.ByteBuffer;
3737

3838
/** Binarized vector values loaded from off-heap */
39-
abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues {
39+
public abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues {
4040

4141
final int dimension;
4242
final int size;
@@ -151,7 +151,7 @@ public int getVectorByteLength() {
151151
return numBytes;
152152
}
153153

154-
static OffHeapBinarizedVectorValues load(
154+
public static OffHeapBinarizedVectorValues load(
155155
OrdToDocDISIReaderConfiguration configuration,
156156
int dimension,
157157
int size,
@@ -197,8 +197,8 @@ static OffHeapBinarizedVectorValues load(
197197
}
198198

199199
/** Dense off-heap binarized vector values */
200-
static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues {
201-
DenseOffHeapVectorValues(
200+
public static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues {
201+
public DenseOffHeapVectorValues(
202202
int dimension,
203203
int size,
204204
float[] centroid,
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.es92;
11+
12+
import org.apache.lucene.util.BitUtil;
13+
14+
import java.nio.ByteOrder;
15+
import java.nio.ShortBuffer;
16+
17+
class BFloat16 {
18+
19+
public static final int BYTES = Short.BYTES;
20+
21+
public static short floatToBFloat16(float f) {
22+
// this rounds towards 0
23+
// zero - zero exp, zero fraction
24+
// denormal - zero exp, non-zero fraction
25+
// infinity - all-1 exp, zero fraction
26+
// NaN - all-1 exp, non-zero fraction
27+
// the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into
28+
// infinities
29+
return (short) (Float.floatToIntBits(f) >>> 16);
30+
}
31+
32+
public static float bFloat16ToFloat(short bf) {
33+
return Float.intBitsToFloat(bf << 16);
34+
}
35+
36+
public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) {
37+
assert bFloats.remaining() == floats.length;
38+
assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
39+
for (float v : floats) {
40+
bFloats.put(floatToBFloat16(v));
41+
}
42+
}
43+
44+
public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) {
45+
assert floats.length * 2 == bfBytes.length;
46+
for (int i = 0; i < floats.length; i++) {
47+
floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * 2));
48+
}
49+
}
50+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* @notice
3+
* Licensed to the Apache Software Foundation (ASF) under one or more
4+
* contributor license agreements. See the NOTICE file distributed with
5+
* this work for additional information regarding copyright ownership.
6+
* The ASF licenses this file to You under the Apache License, Version 2.0
7+
* (the "License"); you may not use this file except in compliance with
8+
* the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*
18+
* Modifications copyright (C) 2024 Elasticsearch B.V.
19+
*/
20+
package org.elasticsearch.index.codec.vectors.es92;
21+
22+
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
23+
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
24+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
25+
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
26+
import org.apache.lucene.index.SegmentReadState;
27+
import org.apache.lucene.index.SegmentWriteState;
28+
import org.apache.lucene.store.FlushInfo;
29+
import org.apache.lucene.store.IOContext;
30+
import org.apache.lucene.store.MergeInfo;
31+
import org.elasticsearch.common.util.set.Sets;
32+
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
33+
import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat;
34+
import org.elasticsearch.index.codec.vectors.es818.MergeReaderWrapper;
35+
import org.elasticsearch.index.store.FsDirectoryFactory;
36+
37+
import java.io.IOException;
38+
import java.util.Set;
39+
40+
public final class ES92BFloat16FlatVectorsFormat extends FlatVectorsFormat {
41+
42+
static final String NAME = "ES92BFloat16FlatVectorsFormat";
43+
static final String META_CODEC_NAME = "ES92BFloat16FlatVectorsFormatMeta";
44+
static final String VECTOR_DATA_CODEC_NAME = "ES92BFloat16FlatVectorsFormatData";
45+
static final String META_EXTENSION = "vemf";
46+
static final String VECTOR_DATA_EXTENSION = "vec";
47+
48+
public static final int VERSION_START = 0;
49+
public static final int VERSION_CURRENT = VERSION_START;
50+
51+
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
52+
private final FlatVectorsScorer vectorsScorer;
53+
54+
public ES92BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
55+
super(NAME);
56+
this.vectorsScorer = vectorsScorer;
57+
}
58+
59+
@Override
60+
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
61+
return new ES92BFloat16FlatVectorsWriter(state, vectorsScorer);
62+
}
63+
64+
static boolean shouldUseDirectIO(SegmentReadState state) {
65+
return ES818BinaryQuantizedVectorsFormat.USE_DIRECT_IO && FsDirectoryFactory.isHybridFs(state.directory);
66+
}
67+
68+
@Override
69+
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
70+
if (shouldUseDirectIO(state) && state.context.context() == IOContext.Context.DEFAULT) {
71+
// only override the context for the random-access use case
72+
SegmentReadState directIOState = new SegmentReadState(
73+
state.directory,
74+
state.segmentInfo,
75+
state.fieldInfos,
76+
new DirectIOContext(state.context.hints()),
77+
state.segmentSuffix
78+
);
79+
// Use mmap for merges and direct I/O for searches.
80+
// TODO: Open the mmap file with sequential access instead of random (current behavior).
81+
return new MergeReaderWrapper(
82+
new ES92BFloat16FlatVectorsReader(directIOState, vectorsScorer),
83+
new ES92BFloat16FlatVectorsReader(state, vectorsScorer)
84+
);
85+
} else {
86+
return new ES92BFloat16FlatVectorsReader(state, vectorsScorer);
87+
}
88+
}
89+
90+
@Override
91+
public String toString() {
92+
return "ES92BFloat16FlatVectorsFormat(" + "vectorsScorer=" + vectorsScorer + ')';
93+
}
94+
95+
static class DirectIOContext implements IOContext {
96+
97+
final Set<FileOpenHint> hints;
98+
99+
DirectIOContext(Set<FileOpenHint> hints) {
100+
// always add DirectIOHint to the hints given
101+
this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE));
102+
}
103+
104+
@Override
105+
public Context context() {
106+
return Context.DEFAULT;
107+
}
108+
109+
@Override
110+
public MergeInfo mergeInfo() {
111+
return null;
112+
}
113+
114+
@Override
115+
public FlushInfo flushInfo() {
116+
return null;
117+
}
118+
119+
@Override
120+
public Set<FileOpenHint> hints() {
121+
return hints;
122+
}
123+
124+
@Override
125+
public IOContext withHints(FileOpenHint... hints) {
126+
return new DirectIOContext(Set.of(hints));
127+
}
128+
}
129+
}

0 commit comments

Comments
 (0)