Skip to content

Commit 83238a5

Browse files
committed
Vector similarity needs to wrap the new rescoring query and not the other way round
1 parent 497d8e2 commit 83238a5

File tree

6 files changed

+69
-44
lines changed

6 files changed

+69
-44
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2111,29 +2111,33 @@ && isNotUnitVector(squaredMagnitude)) {
21112111
}
21122112
}
21132113

2114+
Integer adjustedK = k;
21142115
int adjustedNumCands = numCands;
21152116
if (needsRescore(numCandsFactor)) {
2116-
// k <= numCands * numCandsFactor <= NUM_CANDS_OVERSAMPLE_LIMIT. Adjust otherwise.
2117+
// Get all candidates, get top k as part of rescoring
2118+
adjustedK = null;
2119+
// Oversample numCands by numCandsFactor, capping the result at NUM_CANDS_OVERSAMPLE_LIMIT.
21172120
adjustedNumCands = Math.min((int) Math.ceil(numCands * numCandsFactor), NUM_CANDS_OVERSAMPLE_LIMIT);
21182121
}
21192122
Query knnQuery = parentFilter != null
2120-
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, adjustedNumCands, parentFilter)
2121-
: new ESKnnFloatVectorQuery(name(), queryVector, k, adjustedNumCands, filter);
2122-
if (similarityThreshold != null) {
2123-
knnQuery = new VectorSimilarityQuery(
2124-
knnQuery,
2125-
similarityThreshold,
2126-
similarity.score(similarityThreshold, elementType, dims)
2127-
);
2128-
}
2123+
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter)
2124+
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter);
21292125
if (needsRescore(numCandsFactor)) {
21302126
knnQuery = new RescoreKnnVectorQuery(
21312127
name(),
21322128
queryVector,
21332129
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),
2130+
k,
21342131
knnQuery
21352132
);
21362133
}
2134+
if (similarityThreshold != null) {
2135+
knnQuery = new VectorSimilarityQuery(
2136+
knnQuery,
2137+
similarityThreshold,
2138+
similarity.score(similarityThreshold, elementType, dims)
2139+
);
2140+
}
21372141
return knnQuery;
21382142
}
21392143

server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import org.apache.lucene.search.IndexSearcher;
1717
import org.apache.lucene.search.Query;
1818
import org.apache.lucene.search.QueryVisitor;
19+
import org.apache.lucene.search.ScoreDoc;
20+
import org.apache.lucene.search.TopDocs;
1921
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
2022
import org.elasticsearch.search.profile.query.QueryProfiler;
2123

@@ -30,6 +32,7 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
3032
private final String fieldName;
3133
private final float[] floatTarget;
3234
private final VectorSimilarityFunction vectorSimilarityFunction;
35+
private final Integer k;
3336
private final Query innerQuery;
3437

3538
private QueryProfilerProvider vectorProfiling;
@@ -38,11 +41,13 @@ public RescoreKnnVectorQuery(
3841
String fieldName,
3942
float[] floatTarget,
4043
VectorSimilarityFunction vectorSimilarityFunction,
44+
Integer k,
4145
Query innerQuery
4246
) {
4347
this.fieldName = fieldName;
4448
this.floatTarget = floatTarget;
4549
this.vectorSimilarityFunction = vectorSimilarityFunction;
50+
this.k = k;
4651
this.innerQuery = innerQuery;
4752
}
4853

@@ -53,13 +58,34 @@ public Query rewrite(IndexSearcher searcher) throws IOException {
5358
// to calculate top k and return directly the query to understand how many comparisons were done
5459
vectorProfiling = (QueryProfilerProvider) valueSource;
5560
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
56-
return searcher.rewrite(functionScoreQuery);
61+
Query query = searcher.rewrite(functionScoreQuery);
62+
63+
if (k == null) {
64+
// No need to calculate top k - let the request size limit the results.
65+
return query;
66+
}
67+
68+
// Retrieve top k documents from the rescored query
69+
TopDocs topDocs = searcher.search(query, k);
70+
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
71+
int[] docIds = new int[scoreDocs.length];
72+
float[] scores = new float[scoreDocs.length];
73+
for (int i = 0; i < scoreDocs.length; i++) {
74+
docIds[i] = scoreDocs[i].doc;
75+
scores[i] = scoreDocs[i].score;
76+
}
77+
78+
return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader());
5779
}
5880

5981
public Query innerQuery() {
6082
return innerQuery;
6183
}
6284

85+
public Integer k() {
86+
return k;
87+
}
88+
6389
@Override
6490
public void profile(QueryProfiler queryProfiler) {
6591
if (innerQuery instanceof QueryProfilerProvider queryProfilerProvider) {
@@ -85,27 +111,22 @@ public boolean equals(Object o) {
85111
return Objects.equals(fieldName, that.fieldName)
86112
&& Objects.deepEquals(floatTarget, that.floatTarget)
87113
&& vectorSimilarityFunction == that.vectorSimilarityFunction
114+
&& Objects.equals(k, that.k)
88115
&& Objects.equals(innerQuery, that.innerQuery);
89116
}
90117

91118
@Override
92119
public int hashCode() {
93-
return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, innerQuery);
120+
return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, innerQuery);
94121
}
95122

96123
@Override
97124
public String toString(String field) {
98-
return "KnnRescoreVectorQuery{"
99-
+ "fieldName='"
100-
+ fieldName
101-
+ '\''
102-
+ ", floatTarget="
103-
+ floatTarget[0]
104-
+ "..."
105-
+ ", vectorSimilarityFunction="
106-
+ vectorSimilarityFunction
107-
+ ", vectorQuery="
108-
+ innerQuery
109-
+ '}';
125+
return "KnnRescoreVectorQuery{" + "fieldName='" + fieldName + '\'' +
126+
", floatTarget=" + floatTarget[0] + "..." +
127+
", vectorSimilarityFunction=" + vectorSimilarityFunction +
128+
", k=" + k +
129+
", vectorQuery=" + innerQuery +
130+
'}';
110131
}
111132
}

server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
import static org.elasticsearch.common.Strings.format;
3030

3131
/**
32-
* This query provides a simple post-filter for the provided Query. The query is assumed to be a Knn(Float|Byte)VectorQuery.
32+
* This query provides a simple post-filter for the provided Query to limit the results of the inner query to those that have a similarity
33+
* above a certain threshold.
3334
*/
3435
public class VectorSimilarityQuery extends Query implements QueryProfilerProvider {
3536
private final float similarity;

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -456,12 +456,12 @@ public void testRescoreOversampleModifiesNumCandidates() {
456456
);
457457

458458
// Total results is k (held by the rescore query); the inner kNN query's k is null and its num candidates is multiplied by oversample
459-
checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 10, 500);
459+
checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, null, 500, 10);
460460
// If numCands < k, update numCands to k
461-
checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 10, 50);
461+
checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, null, 50, 10);
462462
// Oversampling limits for num candidates
463-
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, 1000, 10000);
464-
checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, 5000, 10000);
463+
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, null, 10000, 1000);
464+
checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, null, 10000, 5000);
465465
}
466466

467467
private static void checkRescoreQueryParameters(
@@ -470,7 +470,8 @@ private static void checkRescoreQueryParameters(
470470
int candidates,
471471
float numCandsFactor,
472472
Integer expectedK,
473-
int expectedCandidates
473+
int expectedCandidates,
474+
int expectedResults
474475
) {
475476
Query query = fieldType.createKnnQuery(
476477
VectorData.fromFloats(new float[] { 1, 4, 10 }),
@@ -483,7 +484,8 @@ private static void checkRescoreQueryParameters(
483484
);
484485
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
485486
ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery();
486-
assertThat("Unexpected total results", esKnnQuery.kParam(), equalTo(expectedK));
487+
assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults));
488+
assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK));
487489
assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates));
488490
}
489491
}

server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -171,23 +171,18 @@ protected RescoreVectorBuilder randomRescoreVectorBuilder() {
171171

172172
@Override
173173
protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException {
174+
if (queryBuilder.getVectorSimilarity() != null) {
175+
assertTrue(query instanceof VectorSimilarityQuery);
176+
assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity()));
177+
query = ((VectorSimilarityQuery) query).getInnerKnnQuery();
178+
}
174179
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
175180
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
176181
query = rescoreQuery.innerQuery();
177182
}
178-
if (queryBuilder.getVectorSimilarity() != null) {
179-
assertTrue(query instanceof VectorSimilarityQuery);
180-
Query knnQuery = ((VectorSimilarityQuery) query).getInnerKnnQuery();
181-
assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity()));
182-
switch (elementType()) {
183-
case FLOAT -> assertTrue(knnQuery instanceof ESKnnFloatVectorQuery);
184-
case BYTE -> assertTrue(knnQuery instanceof ESKnnByteVectorQuery);
185-
}
186-
} else {
187-
switch (elementType()) {
188-
case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery);
189-
case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery);
190-
}
183+
switch (elementType()) {
184+
case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery);
185+
case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery);
191186
}
192187

193188
BooleanQuery.Builder builder = new BooleanQuery.Builder();

server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ public void testRescoreDocs() throws Exception {
7777
FIELD_NAME,
7878
queryVector,
7979
VectorSimilarityFunction.COSINE,
80+
adjustedK,
8081
new MatchAllDocsQuery()
8182
);
8283

@@ -142,6 +143,7 @@ private void checkProfiling(float[] queryVector, IndexReader reader, Query inner
142143
FIELD_NAME,
143144
queryVector,
144145
VectorSimilarityFunction.COSINE,
146+
k,
145147
innerQuery
146148
);
147149
IndexSearcher searcher = newSearcher(reader, true, false);

0 commit comments

Comments
 (0)