Skip to content

Commit 288a728

Browse files
committed
kNN vector rescoring for quantized vectors (elastic#116663)
(cherry picked from commit 5996772) # Conflicts: # server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java # x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java
1 parent b398448 commit 288a728

File tree

5 files changed

+34
-79
lines changed

5 files changed

+34
-79
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
import org.apache.lucene.search.DoubleValues;
1818
import org.apache.lucene.search.DoubleValuesSource;
1919
import org.apache.lucene.search.IndexSearcher;
20-
import org.elasticsearch.search.profile.query.QueryProfiler;
21-
import org.elasticsearch.search.vectors.QueryProfilerProvider;
2220

2321
import java.io.IOException;
2422
import java.util.Arrays;
@@ -28,12 +26,11 @@
2826
* DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the
2927
* original vector values stored in the index
3028
*/
31-
public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider {
29+
public class VectorSimilarityFloatValueSource extends DoubleValuesSource {
3230

3331
private final String field;
3432
private final float[] target;
3533
private final VectorSimilarityFunction vectorSimilarityFunction;
36-
private long vectorOpsCount;
3734

3835
public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
3936
this.field = field;
@@ -50,7 +47,6 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws
5047
return new DoubleValues() {
5148
@Override
5249
public double doubleValue() throws IOException {
53-
vectorOpsCount++;
5450
return vectorSimilarityFunction.compare(target, vectorValues.vectorValue());
5551
}
5652

@@ -73,11 +69,6 @@ public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
7369
return this;
7470
}
7571

76-
@Override
77-
public void profile(QueryProfiler queryProfiler) {
78-
queryProfiler.addVectorOpsCount(vectorOpsCount);
79-
}
80-
8172
@Override
8273
public int hashCode() {
8374
return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -487,11 +487,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
487487
return this;
488488
}
489489

490-
@Override
491-
protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException {
492-
return super.doIndexMetadataRewrite(context);
493-
}
494-
495490
@Override
496491
protected Query doToQuery(SearchExecutionContext context) throws IOException {
497492
MappedFieldType fieldType = context.getFieldType(fieldName);
@@ -529,8 +524,8 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
529524
String parentPath = context.nestedLookup().getNestedParent(fieldName);
530525
Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor();
531526

527+
BitSetProducer parentBitSet = null;
532528
if (parentPath != null) {
533-
final BitSetProducer parentBitSet;
534529
final Query parentFilter;
535530
NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
536531
if (originalObjectMapper != null) {
@@ -560,17 +555,17 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
560555
// Now join the filterQuery & parentFilter to provide the matching blocks of children
561556
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
562557
}
563-
return vectorFieldType.createKnnQuery(
564-
queryVector,
565-
k,
566-
adjustedNumCands,
567-
numCandidatesFactor,
568-
filterQuery,
569-
vectorSimilarity,
570-
parentBitSet
571-
);
572558
}
573-
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null);
559+
560+
return vectorFieldType.createKnnQuery(
561+
queryVector,
562+
k,
563+
adjustedNumCands,
564+
numCandidatesFactor,
565+
filterQuery,
566+
vectorSimilarity,
567+
parentBitSet
568+
);
574569
}
575570

576571
@Override

server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,15 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
3232
private final String fieldName;
3333
private final float[] floatTarget;
3434
private final VectorSimilarityFunction vectorSimilarityFunction;
35-
private final Integer k;
35+
private final int k;
3636
private final Query innerQuery;
37-
38-
private QueryProfilerProvider vectorProfiling;
37+
private long vectorOperations = 0;
3938

4039
public RescoreKnnVectorQuery(
4140
String fieldName,
4241
float[] floatTarget,
4342
VectorSimilarityFunction vectorSimilarityFunction,
44-
Integer k,
43+
int k,
4544
Query innerQuery
4645
) {
4746
this.fieldName = fieldName;
@@ -54,19 +53,12 @@ public RescoreKnnVectorQuery(
5453
@Override
5554
public Query rewrite(IndexSearcher searcher) throws IOException {
5655
DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
57-
// Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need
58-
// to calculate top k and return directly the query to understand how many comparisons were done
59-
vectorProfiling = (QueryProfilerProvider) valueSource;
6056
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
6157
Query query = searcher.rewrite(functionScoreQuery);
6258

63-
if (k == null) {
64-
// No need to calculate top k - let the request size limit the results.
65-
return query;
66-
}
67-
6859
// Retrieve top k documents from the rescored query
6960
TopDocs topDocs = searcher.search(query, k);
61+
vectorOperations = topDocs.totalHits.value;
7062
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
7163
int[] docIds = new int[scoreDocs.length];
7264
float[] scores = new float[scoreDocs.length];
@@ -82,7 +74,7 @@ public Query innerQuery() {
8274
return innerQuery;
8375
}
8476

85-
public Integer k() {
77+
public int k() {
8678
return k;
8779
}
8880

@@ -92,10 +84,7 @@ public void profile(QueryProfiler queryProfiler) {
9284
queryProfilerProvider.profile(queryProfiler);
9385
}
9486

95-
if (vectorProfiling == null) {
96-
throw new IllegalStateException("Query should have been rewritten");
97-
}
98-
vectorProfiling.profile(queryProfiler);
87+
queryProfiler.addVectorOpsCount(vectorOperations);
9988
}
10089

10190
@Override

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,18 +456,19 @@ public void testRescoreOversampleModifiesNumCandidates() {
456456
);
457457

458458
// Total results is k, internal k is multiplied by oversample
459-
checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, null, 500, 10);
459+
checkRescoreQueryParameters(fieldType, 10, 200, randomInt(), 2.5F, null, 500, 10);
460460
// If numCands < k, update numCands to k
461-
checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, null, 50, 10);
461+
checkRescoreQueryParameters(fieldType, 10, 20, randomInt(), 2.5F, null, 50, 10);
462462
// Oversampling limits for num candidates
463-
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, null, 10000, 1000);
464-
checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, null, 10000, 5000);
463+
checkRescoreQueryParameters(fieldType, 1000, 1000, randomInt(), 11.0F, null, 10000, 1000);
464+
checkRescoreQueryParameters(fieldType, 5000, 7500, randomInt(), 2.5F, null, 10000, 5000);
465465
}
466466

467467
private static void checkRescoreQueryParameters(
468468
DenseVectorFieldType fieldType,
469-
Integer k,
469+
int k,
470470
int candidates,
471+
int requestSize,
471472
float numCandsFactor,
472473
Integer expectedK,
473474
int expectedCandidates,

server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
package org.elasticsearch.search.vectors;
1111

12-
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
13-
1412
import org.apache.lucene.document.Document;
1513
import org.apache.lucene.document.KnnFloatVectorField;
1614
import org.apache.lucene.index.DirectoryReader;
@@ -32,11 +30,9 @@
3230

3331
import java.io.IOException;
3432
import java.io.UnsupportedEncodingException;
35-
import java.util.ArrayList;
3633
import java.util.Arrays;
3734
import java.util.Collection;
3835
import java.util.HashSet;
39-
import java.util.List;
4036
import java.util.Map;
4137
import java.util.PriorityQueue;
4238
import java.util.stream.Collectors;
@@ -48,21 +44,11 @@
4844
public class RescoreKnnVectorQueryTests extends ESTestCase {
4945

5046
public static final String FIELD_NAME = "float_vector";
51-
private final int numDocs;
52-
private final Integer k;
53-
54-
public RescoreKnnVectorQueryTests(boolean useK) {
55-
this.numDocs = randomIntBetween(10, 100);
56-
this.k = useK ? randomIntBetween(1, numDocs - 1) : null;
57-
}
5847

5948
public void testRescoreDocs() throws Exception {
49+
int numDocs = randomIntBetween(10, 100);
6050
int numDims = randomIntBetween(5, 100);
61-
62-
Integer adjustedK = k;
63-
if (k == null) {
64-
adjustedK = numDocs;
65-
}
51+
int k = randomIntBetween(1, numDocs - 1);
6652

6753
try (Directory d = newDirectory()) {
6854
addRandomDocuments(numDocs, d, numDims);
@@ -76,7 +62,7 @@ public void testRescoreDocs() throws Exception {
7662
FIELD_NAME,
7763
queryVector,
7864
VectorSimilarityFunction.COSINE,
79-
adjustedK,
65+
k,
8066
new MatchAllDocsQuery()
8167
);
8268

@@ -85,7 +71,7 @@ public void testRescoreDocs() throws Exception {
8571
Map<Integer, Float> rescoredDocs = Arrays.stream(docs.scoreDocs)
8672
.collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score));
8773

88-
assertThat(rescoredDocs.size(), equalTo(adjustedK));
74+
assertThat(rescoredDocs.size(), equalTo(k));
8975

9076
Collection<Float> rescoredScores = new HashSet<>(rescoredDocs.values());
9177

@@ -111,7 +97,7 @@ public void testRescoreDocs() throws Exception {
11197
assertThat(rescoredDocs.size(), equalTo(0));
11298

11399
// Check top scoring docs are contained in rescored docs
114-
for (int i = 0; i < adjustedK; i++) {
100+
for (int i = 0; i < k; i++) {
115101
Float topScore = topK.poll();
116102
if (rescoredScores.contains(topScore) == false) {
117103
fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores);
@@ -122,21 +108,23 @@ public void testRescoreDocs() throws Exception {
122108
}
123109

124110
public void testProfiling() throws Exception {
111+
int numDocs = randomIntBetween(10, 100);
125112
int numDims = randomIntBetween(5, 100);
113+
int k = randomIntBetween(1, numDocs - 1);
126114

127115
try (Directory d = newDirectory()) {
128116
addRandomDocuments(numDocs, d, numDims);
129117

130118
try (IndexReader reader = DirectoryReader.open(d)) {
131119
float[] queryVector = randomVector(numDims);
132120

133-
checkProfiling(queryVector, reader, new MatchAllDocsQuery());
134-
checkProfiling(queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100)));
121+
checkProfiling(k, numDocs, queryVector, reader, new MatchAllDocsQuery());
122+
checkProfiling(k, numDocs, queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100)));
135123
}
136124
}
137125
}
138126

139-
private void checkProfiling(float[] queryVector, IndexReader reader, Query innerQuery) throws IOException {
127+
private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader reader, Query innerQuery) throws IOException {
140128
RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
141129
FIELD_NAME,
142130
queryVector,
@@ -227,13 +215,4 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th
227215
w.forceMerge(1);
228216
}
229217
}
230-
231-
@ParametersFactory
232-
public static Iterable<Object[]> parameters() {
233-
List<Object[]> params = new ArrayList<>();
234-
params.add(new Object[] { true });
235-
params.add(new Object[] { false });
236-
237-
return params;
238-
}
239218
}

0 commit comments

Comments
 (0)