Skip to content

Commit 763cd14

Browse files
authored
Improve brute force vector search speed by using Lucene functions (#96617)
Lucene has integrated hardware-accelerated vector calculations, which means that calculations like `dot_product` can be much faster when using the Lucene-defined functions. When a `dense_vector` is indexed, we already support this. However, when `index: false`, we store float vectors as binary fields in Lucene and decode them ourselves, which means that we don't use the underlying Lucene structures or functions. To take advantage of the large performance boost, this PR refactors the binary vector values in the following way:
- Eagerly decode the binary blobs when iterated
- Call the Lucene-defined VectorUtil functions when possible

Related to: #96370
1 parent 85d5a32 commit 763cd14

File tree

10 files changed

+149
-131
lines changed

10 files changed

+149
-131
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -110,19 +110,20 @@ private KnnFloatBenchmarkFunction(int dims, boolean normalize) {
110110
private abstract static class BinaryFloatBenchmarkFunction extends BenchmarkFunction {
111111

112112
final BytesRef docVector;
113+
final float[] docFloatVector;
113114
final float[] queryVector;
114115

115116
private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {
116117
super(dims);
117118

118-
float[] docVector = new float[dims];
119+
docFloatVector = new float[dims];
119120
queryVector = new float[dims];
120121

121122
float docMagnitude = 0f;
122123
float queryMagnitude = 0f;
123124

124125
for (int i = 0; i < dims; ++i) {
125-
docVector[i] = (float) (dims - i);
126+
docFloatVector[i] = (float) (dims - i);
126127
queryVector[i] = (float) i;
127128

128129
docMagnitude += (float) (dims - i);
@@ -136,11 +137,11 @@ private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {
136137

137138
for (int i = 0; i < dims; ++i) {
138139
if (normalize) {
139-
docVector[i] /= docMagnitude;
140+
docFloatVector[i] /= docMagnitude;
140141
queryVector[i] /= queryMagnitude;
141142
}
142143

143-
byteBuffer.putFloat(docVector[i]);
144+
byteBuffer.putFloat(docFloatVector[i]);
144145
}
145146

146147
byteBuffer.putFloat(docMagnitude);
@@ -178,6 +179,7 @@ private KnnByteBenchmarkFunction(int dims) {
178179
private abstract static class BinaryByteBenchmarkFunction extends BenchmarkFunction {
179180

180181
final BytesRef docVector;
182+
final byte[] vectorValue;
181183
final byte[] queryVector;
182184

183185
final float queryMagnitude;
@@ -187,12 +189,14 @@ private BinaryByteBenchmarkFunction(int dims) {
187189

188190
ByteBuffer docVector = ByteBuffer.allocate(dims + 4);
189191
queryVector = new byte[dims];
192+
vectorValue = new byte[dims];
190193

191194
float docMagnitude = 0f;
192195
float queryMagnitude = 0f;
193196

194197
for (int i = 0; i < dims; ++i) {
195198
docVector.put((byte) (dims - i));
199+
vectorValue[i] = (byte) (dims - i);
196200
queryVector[i] = (byte) i;
197201

198202
docMagnitude += (float) (dims - i);
@@ -238,7 +242,7 @@ private DotBinaryFloatBenchmarkFunction(int dims) {
238242

239243
@Override
240244
public void execute(Consumer<Object> consumer) {
241-
new BinaryDenseVector(docVector, dims, Version.CURRENT).dotProduct(queryVector);
245+
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).dotProduct(queryVector);
242246
}
243247
}
244248

@@ -250,7 +254,7 @@ private DotBinaryByteBenchmarkFunction(int dims) {
250254

251255
@Override
252256
public void execute(Consumer<Object> consumer) {
253-
new ByteBinaryDenseVector(docVector, dims).dotProduct(queryVector);
257+
new ByteBinaryDenseVector(vectorValue, docVector, dims).dotProduct(queryVector);
254258
}
255259
}
256260

@@ -286,7 +290,7 @@ private CosineBinaryFloatBenchmarkFunction(int dims) {
286290

287291
@Override
288292
public void execute(Consumer<Object> consumer) {
289-
new BinaryDenseVector(docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false);
293+
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false);
290294
}
291295
}
292296

@@ -298,7 +302,7 @@ private CosineBinaryByteBenchmarkFunction(int dims) {
298302

299303
@Override
300304
public void execute(Consumer<Object> consumer) {
301-
new ByteBinaryDenseVector(docVector, dims).cosineSimilarity(queryVector, queryMagnitude);
305+
new ByteBinaryDenseVector(vectorValue, docVector, dims).cosineSimilarity(queryVector, queryMagnitude);
302306
}
303307
}
304308

@@ -334,7 +338,7 @@ private L1BinaryFloatBenchmarkFunction(int dims) {
334338

335339
@Override
336340
public void execute(Consumer<Object> consumer) {
337-
new BinaryDenseVector(docVector, dims, Version.CURRENT).l1Norm(queryVector);
341+
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).l1Norm(queryVector);
338342
}
339343
}
340344

@@ -346,7 +350,7 @@ private L1BinaryByteBenchmarkFunction(int dims) {
346350

347351
@Override
348352
public void execute(Consumer<Object> consumer) {
349-
new ByteBinaryDenseVector(docVector, dims).l1Norm(queryVector);
353+
new ByteBinaryDenseVector(vectorValue, docVector, dims).l1Norm(queryVector);
350354
}
351355
}
352356

@@ -382,7 +386,7 @@ private L2BinaryFloatBenchmarkFunction(int dims) {
382386

383387
@Override
384388
public void execute(Consumer<Object> consumer) {
385-
new BinaryDenseVector(docVector, dims, Version.CURRENT).l1Norm(queryVector);
389+
new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).l1Norm(queryVector);
386390
}
387391
}
388392

@@ -394,7 +398,7 @@ private L2BinaryByteBenchmarkFunction(int dims) {
394398

395399
@Override
396400
public void execute(Consumer<Object> consumer) {
397-
consumer.accept(new ByteBinaryDenseVector(docVector, dims).l2Norm(queryVector));
401+
consumer.accept(new ByteBinaryDenseVector(vectorValue, docVector, dims).l2Norm(queryVector));
398402
}
399403
}
400404

docs/changelog/96617.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 96617
2+
summary: Improve brute force vector search speed by using Lucene functions
3+
area: Search
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,23 @@ public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) {
3636
/**
3737
* Calculates vector magnitude
3838
*/
39-
private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) {
40-
final int length = denseVectorLength(indexVersion, vectorBR);
41-
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
39+
private static float calculateMagnitude(float[] decodedVector) {
4240
double magnitude = 0.0f;
43-
for (int i = 0; i < length; i++) {
44-
float value = byteBuffer.getFloat();
45-
magnitude += value * value;
41+
for (int i = 0; i < decodedVector.length; i++) {
42+
magnitude += decodedVector[i] * decodedVector[i];
4643
}
4744
magnitude = Math.sqrt(magnitude);
4845
return (float) magnitude;
4946
}
5047

51-
public static float getMagnitude(Version indexVersion, BytesRef vectorBR) {
48+
public static float getMagnitude(Version indexVersion, BytesRef vectorBR, float[] decodedVector) {
5249
if (vectorBR == null) {
5350
throw new IllegalArgumentException(DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE);
5451
}
5552
if (indexVersion.onOrAfter(Version.V_7_5_0)) {
5653
return decodeMagnitude(indexVersion, vectorBR);
5754
} else {
58-
return calculateMagnitude(indexVersion, vectorBR);
55+
return calculateMagnitude(decodedVector);
5956
}
6057
}
6158

@@ -70,7 +67,7 @@ public static void decodeDenseVector(BytesRef vectorBR, float[] vector) {
7067
}
7168
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
7269
for (int dim = 0; dim < vector.length; dim++) {
73-
vector[dim] = byteBuffer.getFloat();
70+
vector[dim] = byteBuffer.getFloat((dim * Float.BYTES) + vectorBR.offset);
7471
}
7572
}
7673

server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVector.java

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,38 +9,36 @@
99
package org.elasticsearch.script.field.vectors;
1010

1111
import org.apache.lucene.util.BytesRef;
12+
import org.apache.lucene.util.VectorUtil;
1213
import org.elasticsearch.Version;
1314
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
1415

15-
import java.nio.ByteBuffer;
1616
import java.util.List;
1717

1818
public class BinaryDenseVector implements DenseVector {
1919

20-
protected final BytesRef docVector;
21-
protected final int dims;
22-
protected final Version indexVersion;
20+
private final BytesRef docVector;
2321

24-
protected float[] decodedDocVector;
22+
private final int dims;
23+
private final Version indexVersion;
2524

26-
public BinaryDenseVector(BytesRef docVector, int dims, Version indexVersion) {
25+
private final float[] decodedDocVector;
26+
27+
public BinaryDenseVector(float[] decodedDocVector, BytesRef docVector, int dims, Version indexVersion) {
28+
this.decodedDocVector = decodedDocVector;
2729
this.docVector = docVector;
2830
this.indexVersion = indexVersion;
2931
this.dims = dims;
3032
}
3133

3234
@Override
3335
public float[] getVector() {
34-
if (decodedDocVector == null) {
35-
decodedDocVector = new float[dims];
36-
VectorEncoderDecoder.decodeDenseVector(docVector, decodedDocVector);
37-
}
3836
return decodedDocVector;
3937
}
4038

4139
@Override
4240
public float getMagnitude() {
43-
return VectorEncoderDecoder.getMagnitude(indexVersion, docVector);
41+
return VectorEncoderDecoder.getMagnitude(indexVersion, docVector, decodedDocVector);
4442
}
4543

4644
@Override
@@ -50,22 +48,14 @@ public int dotProduct(byte[] queryVector) {
5048

5149
@Override
5250
public double dotProduct(float[] queryVector) {
53-
ByteBuffer byteBuffer = wrap(docVector);
54-
55-
double dotProduct = 0;
56-
for (float v : queryVector) {
57-
dotProduct += byteBuffer.getFloat() * v;
58-
}
59-
return dotProduct;
51+
return VectorUtil.dotProduct(decodedDocVector, queryVector);
6052
}
6153

6254
@Override
6355
public double dotProduct(List<Number> queryVector) {
64-
ByteBuffer byteBuffer = wrap(docVector);
65-
6656
double dotProduct = 0;
6757
for (int i = 0; i < queryVector.size(); i++) {
68-
dotProduct += byteBuffer.getFloat() * queryVector.get(i).floatValue();
58+
dotProduct += decodedDocVector[i] * queryVector.get(i).floatValue();
6959
}
7060
return dotProduct;
7161
}
@@ -77,22 +67,18 @@ public int l1Norm(byte[] queryVector) {
7767

7868
@Override
7969
public double l1Norm(float[] queryVector) {
80-
ByteBuffer byteBuffer = wrap(docVector);
81-
8270
double l1norm = 0;
83-
for (float v : queryVector) {
84-
l1norm += Math.abs(v - byteBuffer.getFloat());
71+
for (int i = 0; i < queryVector.length; i++) {
72+
l1norm += Math.abs(queryVector[i] - decodedDocVector[i]);
8573
}
8674
return l1norm;
8775
}
8876

8977
@Override
9078
public double l1Norm(List<Number> queryVector) {
91-
ByteBuffer byteBuffer = wrap(docVector);
92-
9379
double l1norm = 0;
9480
for (int i = 0; i < queryVector.size(); i++) {
95-
l1norm += Math.abs(queryVector.get(i).floatValue() - byteBuffer.getFloat());
81+
l1norm += Math.abs(queryVector.get(i).floatValue() - decodedDocVector[i]);
9682
}
9783
return l1norm;
9884
}
@@ -104,21 +90,14 @@ public double l2Norm(byte[] queryVector) {
10490

10591
@Override
10692
public double l2Norm(float[] queryVector) {
107-
ByteBuffer byteBuffer = wrap(docVector);
108-
double l2norm = 0;
109-
for (float queryValue : queryVector) {
110-
double diff = byteBuffer.getFloat() - queryValue;
111-
l2norm += diff * diff;
112-
}
113-
return Math.sqrt(l2norm);
93+
return Math.sqrt(VectorUtil.squareDistance(queryVector, decodedDocVector));
11494
}
11595

11696
@Override
11797
public double l2Norm(List<Number> queryVector) {
118-
ByteBuffer byteBuffer = wrap(docVector);
11998
double l2norm = 0;
120-
for (Number number : queryVector) {
121-
double diff = byteBuffer.getFloat() - number.floatValue();
99+
for (int i = 0; i < queryVector.size(); i++) {
100+
double diff = decodedDocVector[i] - queryVector.get(i).floatValue();
122101
l2norm += diff * diff;
123102
}
124103
return Math.sqrt(l2norm);
@@ -156,8 +135,4 @@ public boolean isEmpty() {
156135
public int getDims() {
157136
return dims;
158137
}
159-
160-
private static ByteBuffer wrap(BytesRef dv) {
161-
return ByteBuffer.wrap(dv.bytes, dv.offset, dv.length);
162-
}
163138
}

server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVectorDocValuesField.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,30 @@
1313
import org.elasticsearch.Version;
1414
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
1515
import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues;
16+
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
1617

1718
import java.io.IOException;
1819

1920
public class BinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
2021

21-
protected final BinaryDocValues input;
22-
protected final Version indexVersion;
23-
protected final int dims;
24-
protected BytesRef value;
22+
private final BinaryDocValues input;
23+
private final float[] vectorValue;
24+
private final Version indexVersion;
25+
private boolean decoded;
26+
private final int dims;
27+
private BytesRef value;
2528

2629
public BinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims, Version indexVersion) {
2730
super(name, elementType);
2831
this.input = input;
2932
this.indexVersion = indexVersion;
3033
this.dims = dims;
34+
this.vectorValue = new float[dims];
3135
}
3236

3337
@Override
3438
public void setNextDocId(int docId) throws IOException {
39+
decoded = false;
3540
if (input.advanceExact(docId)) {
3641
value = input.binaryValue();
3742
} else {
@@ -54,20 +59,28 @@ public DenseVector get() {
5459
if (isEmpty()) {
5560
return DenseVector.EMPTY;
5661
}
57-
58-
return new BinaryDenseVector(value, dims, indexVersion);
62+
decodeVectorIfNecessary();
63+
return new BinaryDenseVector(vectorValue, value, dims, indexVersion);
5964
}
6065

6166
@Override
6267
public DenseVector get(DenseVector defaultValue) {
6368
if (isEmpty()) {
6469
return defaultValue;
6570
}
66-
return new BinaryDenseVector(value, dims, indexVersion);
71+
decodeVectorIfNecessary();
72+
return new BinaryDenseVector(vectorValue, value, dims, indexVersion);
6773
}
6874

6975
@Override
7076
public DenseVector getInternal() {
7177
return get(null);
7278
}
79+
80+
private void decodeVectorIfNecessary() {
81+
if (decoded == false && value != null) {
82+
VectorEncoderDecoder.decodeDenseVector(value, vectorValue);
83+
decoded = true;
84+
}
85+
}
7386
}

0 commit comments

Comments
 (0)