Skip to content

Commit c51d17f

Browse files
authored
Add an HNSW bit vector implementation using ES93GenericFlatVectorsFormat (#136885)
1 parent da4f894 commit c51d17f

18 files changed

+316
-50
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/AbstractFlatVectorsFormat.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ protected AbstractFlatVectorsFormat(String name) {
2020
super(name);
2121
}
2222

23-
protected abstract FlatVectorsScorer flatVectorsScorer();
23+
public abstract FlatVectorsScorer flatVectorsScorer();
2424

2525
@Override
2626
public int getMaxDimensions(String fieldName) {

server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public ES816BinaryQuantizedVectorsFormat() {
6161
}
6262

6363
@Override
64-
protected FlatVectorsScorer flatVectorsScorer() {
64+
public FlatVectorsScorer flatVectorsScorer() {
6565
return scorer;
6666
}
6767

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ public ES818BinaryQuantizedVectorsFormat() {
112112
}
113113

114114
@Override
115-
protected FlatVectorsScorer flatVectorsScorer() {
115+
public FlatVectorsScorer flatVectorsScorer() {
116116
return scorer;
117117
}
118118

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/DirectIOCapableLucene99FlatVectorsFormat.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public DirectIOCapableLucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer)
5050
}
5151

5252
@Override
53-
protected FlatVectorsScorer flatVectorsScorer() {
53+
public FlatVectorsScorer flatVectorsScorer() {
5454
return vectorsScorer;
5555
}
5656

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsFormat.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public ES93BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
4848
}
4949

5050
@Override
51-
protected FlatVectorsScorer flatVectorsScorer() {
51+
public FlatVectorsScorer flatVectorsScorer() {
5252
return vectorsScorer;
5353
}
5454

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,16 @@ public class ES93BinaryQuantizedVectorsFormat extends AbstractFlatVectorsFormat
9797
private final ES93GenericFlatVectorsFormat rawFormat;
9898

9999
public ES93BinaryQuantizedVectorsFormat() {
100-
this(false, false);
100+
this(ES93GenericFlatVectorsFormat.ElementType.STANDARD, false);
101101
}
102102

103-
public ES93BinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) {
103+
public ES93BinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) {
104104
super(NAME);
105-
rawFormat = new ES93GenericFlatVectorsFormat(useBFloat16, useDirectIO);
105+
rawFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO);
106106
}
107107

108108
@Override
109-
protected FlatVectorsScorer flatVectorsScorer() {
109+
public FlatVectorsScorer flatVectorsScorer() {
110110
return scorer;
111111
}
112112

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.es93;
11+
12+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
13+
import org.apache.lucene.index.ByteVectorValues;
14+
import org.apache.lucene.index.KnnVectorValues;
15+
import org.apache.lucene.index.VectorSimilarityFunction;
16+
import org.apache.lucene.util.VectorUtil;
17+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
18+
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
19+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
20+
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
21+
22+
import java.io.IOException;
23+
24+
class ES93FlatBitVectorScorer implements FlatVectorsScorer {
25+
26+
static final ES93FlatBitVectorScorer INSTANCE = new ES93FlatBitVectorScorer();
27+
28+
static void checkDimensions(int queryLen, int fieldLen) {
29+
if (queryLen != fieldLen) {
30+
throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
31+
}
32+
}
33+
34+
@Override
35+
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
36+
VectorSimilarityFunction vectorSimilarityFunction,
37+
KnnVectorValues vectorValues
38+
) throws IOException {
39+
assert vectorValues instanceof ByteVectorValues;
40+
assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
41+
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
42+
assert byteVectorValues instanceof QuantizedByteVectorValues == false;
43+
return new HammingScorerSupplier(byteVectorValues);
44+
}
45+
throw new IllegalArgumentException("Unsupported vector type or similarity function");
46+
}
47+
48+
@Override
49+
public RandomVectorScorer getRandomVectorScorer(
50+
VectorSimilarityFunction vectorSimilarityFunction,
51+
KnnVectorValues vectorValues,
52+
byte[] target
53+
) throws IOException {
54+
assert vectorValues instanceof ByteVectorValues;
55+
assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
56+
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
57+
checkDimensions(target.length, byteVectorValues.dimension());
58+
return new HammingVectorScorer(byteVectorValues, target);
59+
}
60+
throw new IllegalArgumentException("Unsupported vector type or similarity function");
61+
}
62+
63+
@Override
64+
public RandomVectorScorer getRandomVectorScorer(
65+
VectorSimilarityFunction similarityFunction,
66+
KnnVectorValues vectorValues,
67+
float[] target
68+
) throws IOException {
69+
throw new IllegalArgumentException("Unsupported vector type");
70+
}
71+
72+
static float hammingScore(byte[] a, byte[] b) {
73+
return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE);
74+
}
75+
76+
static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
77+
private final byte[] query;
78+
private final ByteVectorValues byteValues;
79+
80+
HammingVectorScorer(ByteVectorValues byteValues, byte[] query) {
81+
super(byteValues);
82+
this.query = query;
83+
this.byteValues = byteValues;
84+
}
85+
86+
@Override
87+
public float score(int i) throws IOException {
88+
return hammingScore(byteValues.vectorValue(i), query);
89+
}
90+
}
91+
92+
static class HammingScorerSupplier implements RandomVectorScorerSupplier {
93+
private final ByteVectorValues byteValues, targetValues;
94+
95+
HammingScorerSupplier(ByteVectorValues byteValues) throws IOException {
96+
this.byteValues = byteValues;
97+
this.targetValues = byteValues.copy();
98+
}
99+
100+
@Override
101+
public UpdateableRandomVectorScorer scorer() throws IOException {
102+
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(targetValues) {
103+
private final byte[] query = new byte[targetValues.dimension()];
104+
private int currentOrd = -1;
105+
106+
@Override
107+
public void setScoringOrdinal(int i) throws IOException {
108+
if (currentOrd == i) {
109+
return;
110+
}
111+
System.arraycopy(targetValues.vectorValue(i), 0, query, 0, query.length);
112+
this.currentOrd = i;
113+
}
114+
115+
@Override
116+
public float score(int i) throws IOException {
117+
return hammingScore(targetValues.vectorValue(i), query);
118+
}
119+
};
120+
}
121+
122+
@Override
123+
public RandomVectorScorerSupplier copy() throws IOException {
124+
return new HammingScorerSupplier(byteValues);
125+
}
126+
}
127+
}

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@
2323

2424
public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat {
2525

26+
// TODO: replace with DenseVectorFieldMapper.ElementType
27+
public enum ElementType {
28+
STANDARD,
29+
BIT, // only supports byte[]
30+
BFLOAT16 // only supports float[]
31+
}
32+
2633
static final String NAME = "ES93GenericFlatVectorsFormat";
2734
static final String VECTOR_FORMAT_INFO_EXTENSION = "vfi";
2835
static final String META_CODEC_NAME = "ES93GenericFlatVectorsFormatMeta";
@@ -37,15 +44,27 @@ public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat {
3744
VERSION_CURRENT
3845
);
3946

40-
private static final FlatVectorsScorer scorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
41-
42-
private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(scorer);
47+
private static final DirectIOCapableFlatVectorsFormat standardVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
48+
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
49+
);
50+
private static final DirectIOCapableFlatVectorsFormat bitVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
51+
ES93FlatBitVectorScorer.INSTANCE
52+
) {
53+
@Override
54+
public String getName() {
55+
return "ES93BitFlatVectorsFormat";
56+
}
57+
};
4358
// TODO: a separate scorer for bfloat16
44-
private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat(scorer);
59+
private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat(
60+
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
61+
);
4562

4663
private static final Map<String, DirectIOCapableFlatVectorsFormat> supportedFormats = Map.of(
47-
float32VectorFormat.getName(),
48-
float32VectorFormat,
64+
bitVectorFormat.getName(),
65+
bitVectorFormat,
66+
standardVectorFormat.getName(),
67+
standardVectorFormat,
4968
bfloat16VectorFormat.getName(),
5069
bfloat16VectorFormat
5170
);
@@ -54,18 +73,22 @@ public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat {
5473
private final boolean useDirectIO;
5574

5675
public ES93GenericFlatVectorsFormat() {
57-
this(false, false);
76+
this(ElementType.STANDARD, false);
5877
}
5978

60-
public ES93GenericFlatVectorsFormat(boolean useBFloat16, boolean useDirectIO) {
79+
public ES93GenericFlatVectorsFormat(ElementType elementType, boolean useDirectIO) {
6180
super(NAME);
62-
writeFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat;
81+
writeFormat = switch (elementType) {
82+
case STANDARD -> standardVectorFormat;
83+
case BIT -> bitVectorFormat;
84+
case BFLOAT16 -> bfloat16VectorFormat;
85+
};
6386
this.useDirectIO = useDirectIO;
6487
}
6588

6689
@Override
67-
protected FlatVectorsScorer flatVectorsScorer() {
68-
return scorer;
90+
public FlatVectorsScorer flatVectorsScorer() {
91+
return writeFormat.flatVectorsScorer();
6992
}
7093

7194
@Override

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ public ES93HnswBinaryQuantizedVectorsFormat() {
4949
*
5050
* @param useDirectIO whether to use direct IO when reading raw vectors
5151
*/
52-
public ES93HnswBinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) {
52+
public ES93HnswBinaryQuantizedVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) {
5353
super(NAME);
54-
flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO);
54+
flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO);
5555
}
5656

5757
/**
@@ -61,9 +61,14 @@ public ES93HnswBinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDire
6161
* @param beamWidth the size of the queue maintained during graph construction.
6262
* @param useDirectIO whether to use direct IO when reading raw vectors
6363
*/
64-
public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useBFloat16, boolean useDirectIO) {
64+
public ES93HnswBinaryQuantizedVectorsFormat(
65+
int maxConn,
66+
int beamWidth,
67+
ES93GenericFlatVectorsFormat.ElementType elementType,
68+
boolean useDirectIO
69+
) {
6570
super(NAME, maxConn, beamWidth);
66-
flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO);
71+
flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO);
6772
}
6873

6974
/**
@@ -80,13 +85,13 @@ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean
8085
public ES93HnswBinaryQuantizedVectorsFormat(
8186
int maxConn,
8287
int beamWidth,
83-
boolean useBFloat16,
88+
ES93GenericFlatVectorsFormat.ElementType elementType,
8489
boolean useDirectIO,
8590
int numMergeWorkers,
8691
ExecutorService mergeExec
8792
) {
8893
super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec);
89-
flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO);
94+
flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(elementType, useDirectIO);
9095
}
9196

9297
@Override

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,26 @@ public ES93HnswVectorsFormat() {
3232
flatVectorsFormat = new ES93GenericFlatVectorsFormat();
3333
}
3434

35-
public ES93HnswVectorsFormat(boolean bfloat16, boolean useDirectIO) {
35+
public ES93HnswVectorsFormat(ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) {
3636
super(NAME);
37-
flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO);
37+
flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO);
3838
}
3939

40-
public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boolean useDirectIO) {
40+
public ES93HnswVectorsFormat(int maxConn, int beamWidth, ES93GenericFlatVectorsFormat.ElementType elementType, boolean useDirectIO) {
4141
super(NAME, maxConn, beamWidth);
42-
flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO);
42+
flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO);
4343
}
4444

4545
public ES93HnswVectorsFormat(
4646
int maxConn,
4747
int beamWidth,
48-
boolean bfloat16,
48+
ES93GenericFlatVectorsFormat.ElementType elementType,
4949
boolean useDirectIO,
5050
int numMergeWorkers,
5151
ExecutorService mergeExec
5252
) {
5353
super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec);
54-
flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO);
54+
flatVectorsFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO);
5555
}
5656

5757
@Override

0 commit comments

Comments
 (0)