Skip to content

Commit 9412be0

Browse files
committed
Simplify logic for RescoreKnnVectorQuery now that k is not modifiable
1 parent 978cff3 commit 9412be0

File tree

5 files changed

+31
-55
lines changed

5 files changed

+31
-55
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2134,7 +2134,6 @@ && isNotUnitVector(squaredMagnitude)) {
21342134
name(),
21352135
queryVector,
21362136
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),
2137-
k,
21382137
knnQuery
21392138
);
21402139
}

server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

Lines changed: 14 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import org.apache.lucene.search.IndexSearcher;
1717
import org.apache.lucene.search.Query;
1818
import org.apache.lucene.search.QueryVisitor;
19-
import org.apache.lucene.search.ScoreDoc;
20-
import org.apache.lucene.search.TopDocs;
2119
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
2220
import org.elasticsearch.search.profile.query.QueryProfiler;
2321

@@ -32,7 +30,6 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
3230
private final String fieldName;
3331
private final float[] floatTarget;
3432
private final VectorSimilarityFunction vectorSimilarityFunction;
35-
private final Integer k;
3633
private final Query innerQuery;
3734

3835
private QueryProfilerProvider vectorProfiling;
@@ -41,13 +38,11 @@ public RescoreKnnVectorQuery(
4138
String fieldName,
4239
float[] floatTarget,
4340
VectorSimilarityFunction vectorSimilarityFunction,
44-
Integer k,
4541
Query innerQuery
4642
) {
4743
this.fieldName = fieldName;
4844
this.floatTarget = floatTarget;
4945
this.vectorSimilarityFunction = vectorSimilarityFunction;
50-
this.k = k;
5146
this.innerQuery = innerQuery;
5247
}
5348

@@ -58,34 +53,13 @@ public Query rewrite(IndexSearcher searcher) throws IOException {
5853
// to calculate top k and return directly the query to understand how many comparisons were done
5954
vectorProfiling = (QueryProfilerProvider) valueSource;
6055
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
61-
Query query = searcher.rewrite(functionScoreQuery);
62-
63-
if (k == null) {
64-
// No need to calculate top k - let the request size limit the results.
65-
return query;
66-
}
67-
68-
// Retrieve top k documents from the rescored query
69-
TopDocs topDocs = searcher.search(query, k);
70-
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
71-
int[] docIds = new int[scoreDocs.length];
72-
float[] scores = new float[scoreDocs.length];
73-
for (int i = 0; i < scoreDocs.length; i++) {
74-
docIds[i] = scoreDocs[i].doc;
75-
scores[i] = scoreDocs[i].score;
76-
}
77-
78-
return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader());
56+
return searcher.rewrite(functionScoreQuery);
7957
}
8058

8159
public Query innerQuery() {
8260
return innerQuery;
8361
}
8462

85-
public Integer k() {
86-
return k;
87-
}
88-
8963
@Override
9064
public void profile(QueryProfiler queryProfiler) {
9165
if (innerQuery instanceof QueryProfilerProvider queryProfilerProvider) {
@@ -111,24 +85,27 @@ public boolean equals(Object o) {
11185
return Objects.equals(fieldName, that.fieldName)
11286
&& Objects.deepEquals(floatTarget, that.floatTarget)
11387
&& vectorSimilarityFunction == that.vectorSimilarityFunction
114-
&& Objects.equals(k, that.k)
11588
&& Objects.equals(innerQuery, that.innerQuery);
11689
}
11790

11891
@Override
11992
public int hashCode() {
120-
return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, innerQuery);
93+
return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, innerQuery);
12194
}
12295

12396
@Override
12497
public String toString(String field) {
125-
final StringBuilder sb = new StringBuilder("KnnRescoreVectorQuery{");
126-
sb.append("fieldName='").append(fieldName).append('\'');
127-
sb.append(", floatTarget=").append(floatTarget[0]).append("...");
128-
sb.append(", vectorSimilarityFunction=").append(vectorSimilarityFunction);
129-
sb.append(", k=").append(k);
130-
sb.append(", vectorQuery=").append(innerQuery);
131-
sb.append('}');
132-
return sb.toString();
98+
return "KnnRescoreVectorQuery{"
99+
+ "fieldName='"
100+
+ fieldName
101+
+ '\''
102+
+ ", floatTarget="
103+
+ floatTarget[0]
104+
+ "..."
105+
+ ", vectorSimilarityFunction="
106+
+ vectorSimilarityFunction
107+
+ ", vectorQuery="
108+
+ innerQuery
109+
+ '}';
133110
}
134111
}

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ public void testRescoreOversampleUsedWithoutQuantization() {
443443
}
444444
}
445445

446-
public void testRescoreOversampleModifiesKnnParams() {
446+
public void testRescoreOversampleModifiesNumCandidates() {
447447
DenseVectorFieldType fieldType = new DenseVectorFieldType(
448448
"f",
449449
IndexVersion.current(),
@@ -456,36 +456,38 @@ public void testRescoreOversampleModifiesKnnParams() {
456456
);
457457

458458
// Total results is k, internal k is multiplied by oversample
459-
checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 10, 25, 200);
459+
checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 10, 500);
460460
// If numCands < k, update numCands to k
461-
checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 10, 25, 25);
462-
// Oversampling limit
463-
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, 1000, 10000, 10000);
464-
checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, 5000, 10000, 10000);
461+
checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 10, 50);
462+
// Oversampling limits for num candidates
463+
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, 1000, 10000);
464+
checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, 5000, 10000);
465+
// Oversampling is capped at k as a minimum
466+
checkRescoreQueryParameters(fieldType, 10, 100, 0.01F, 10, 10);
467+
// Oversampling is capped at 1 as a minimum if k is not specified
468+
checkRescoreQueryParameters(fieldType, null, 100, 0.0001F, null, 1);
465469
}
466470

467471
private static void checkRescoreQueryParameters(
468472
DenseVectorFieldType fieldType,
469-
int k,
473+
Integer k,
470474
int candidates,
471-
float oversample,
472-
int expectedResults,
473-
int expectedK,
475+
float numCandsFactor,
476+
Integer expectedK,
474477
int expectedCandidates
475478
) {
476479
Query query = fieldType.createKnnQuery(
477480
VectorData.fromFloats(new float[] { 1, 4, 10 }),
478481
k,
479482
candidates,
480-
oversample,
483+
numCandsFactor,
481484
null,
482485
null,
483486
null
484487
);
485488
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
486489
ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery();
487-
assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults));
488-
assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK));
490+
assertThat("Unexpected total results", esKnnQuery.kParam(), equalTo(expectedK));
489491
assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates));
490492
}
491493
}

server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@ public void testInvalidK() {
257257
public void testInvalidRescoreVectorBuilder() {
258258
IllegalArgumentException e = expectThrows(
259259
IllegalArgumentException.class,
260-
() -> new KnnSearchBuilder("field", randomVector(3), 0, 100, new RescoreVectorBuilder(1.0F), null)
260+
() -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.0F), null)
261261
);
262-
assertThat(e.getMessage(), containsString("[oversample] must be > 1.0"));
262+
assertThat(e.getMessage(), containsString("[num_candidates_factor] must be > 0.0"));
263263
}
264264

265265
public void testRewrite() throws Exception {

server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ public void testRescoreDocs() throws Exception {
7777
FIELD_NAME,
7878
queryVector,
7979
VectorSimilarityFunction.COSINE,
80-
adjustedK,
8180
new MatchAllDocsQuery()
8281
);
8382

@@ -143,7 +142,6 @@ private void checkProfiling(float[] queryVector, IndexReader reader, Query inner
143142
FIELD_NAME,
144143
queryVector,
145144
VectorSimilarityFunction.COSINE,
146-
k,
147145
innerQuery
148146
);
149147
IndexSearcher searcher = newSearcher(reader, true, false);

0 commit comments

Comments
 (0)