Skip to content

Commit 37be056

Browse files
committed
Use KnnRescoreVectorQuery to perform rescoring and limiting the number of results from each shard
1 parent c96a1dc commit 37be056

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.apache.lucene.queries.function.FunctionScoreQuery;
14+
import org.apache.lucene.search.DoubleValuesSource;
15+
import org.apache.lucene.search.IndexSearcher;
16+
import org.apache.lucene.search.Query;
17+
import org.apache.lucene.search.QueryVisitor;
18+
import org.apache.lucene.search.ScoreDoc;
19+
import org.apache.lucene.search.TopDocs;
20+
import org.elasticsearch.index.mapper.vectors.VectorSimilarityByteValueSource;
21+
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
22+
import org.elasticsearch.search.profile.query.QueryProfiler;
23+
24+
import java.io.IOException;
25+
import java.util.Arrays;
26+
import java.util.Objects;
27+
28+
/**
29+
* Wraps a kNN vector query to rescore the results using the non-quantized vectors
30+
*/
31+
public class KnnRescoreVectorQuery extends Query implements ProfilingQuery {
32+
private final String fieldName;
33+
private final byte[] byteTarget;
34+
private final float[] floatTarget;
35+
private final VectorSimilarityFunction vectorSimilarityFunction;
36+
private final Integer k;
37+
private final Query vectorQuery;
38+
39+
private long vectorOpsCount;
40+
41+
public KnnRescoreVectorQuery(
42+
String fieldName,
43+
byte[] byteTarget,
44+
VectorSimilarityFunction vectorSimilarityFunction,
45+
Integer k,
46+
Query vectorQuery
47+
) {
48+
this.fieldName = fieldName;
49+
this.byteTarget = byteTarget;
50+
this.floatTarget = null;
51+
this.vectorSimilarityFunction = vectorSimilarityFunction;
52+
this.k = k;
53+
this.vectorQuery = vectorQuery;
54+
}
55+
56+
public KnnRescoreVectorQuery(
57+
String fieldName,
58+
float[] floatTarget,
59+
VectorSimilarityFunction vectorSimilarityFunction,
60+
Integer k,
61+
Query vectorQuery
62+
) {
63+
this.fieldName = fieldName;
64+
this.byteTarget = null;
65+
this.floatTarget = floatTarget;
66+
this.vectorSimilarityFunction = vectorSimilarityFunction;
67+
this.k = k;
68+
this.vectorQuery = vectorQuery;
69+
}
70+
71+
@Override
72+
public Query rewrite(IndexSearcher searcher) throws IOException {
73+
Query rewritten = super.rewrite(searcher);
74+
if (rewritten != this) {
75+
return rewritten;
76+
}
77+
78+
final DoubleValuesSource valueSource;
79+
if (byteTarget != null) {
80+
valueSource = new VectorSimilarityByteValueSource(fieldName, byteTarget, vectorSimilarityFunction);
81+
} else {
82+
valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
83+
}
84+
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(vectorQuery, valueSource);
85+
Query query = searcher.rewrite(functionScoreQuery);
86+
87+
if (k == null) {
88+
// No need to calculate top k - let the request size limit the results
89+
return query;
90+
}
91+
92+
TopDocs topDocs = searcher.search(query, k);
93+
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
94+
int[] docIds = new int[scoreDocs.length];
95+
float[] scores = new float[scoreDocs.length];
96+
for (int i = 0; i < scoreDocs.length; i++) {
97+
docIds[i] = scoreDocs[i].doc;
98+
scores[i] = scoreDocs[i].score;
99+
}
100+
101+
vectorOpsCount = scoreDocs.length;
102+
103+
return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader());
104+
}
105+
106+
@Override
107+
public void profile(QueryProfiler queryProfiler) {
108+
queryProfiler.setVectorOpsCount(vectorOpsCount);
109+
}
110+
111+
@Override
112+
public void visit(QueryVisitor visitor) {
113+
if (visitor.acceptField(fieldName)) {
114+
visitor.visitLeaf(this);
115+
}
116+
}
117+
118+
@Override
119+
public boolean equals(Object o) {
120+
if (this == o) return true;
121+
if (o == null || getClass() != o.getClass()) return false;
122+
KnnRescoreVectorQuery that = (KnnRescoreVectorQuery) o;
123+
return Objects.equals(fieldName, that.fieldName)
124+
&& Objects.deepEquals(byteTarget, that.byteTarget)
125+
&& Objects.deepEquals(floatTarget, that.floatTarget)
126+
&& vectorSimilarityFunction == that.vectorSimilarityFunction
127+
&& Objects.equals(k, that.k)
128+
&& Objects.equals(vectorQuery, that.vectorQuery);
129+
}
130+
131+
@Override
132+
public int hashCode() {
133+
return Objects.hash(fieldName, Arrays.hashCode(byteTarget), Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, vectorQuery);
134+
}
135+
136+
@Override
137+
public String toString(String field) {
138+
final StringBuilder sb = new StringBuilder("KnnRescoreVectorQuery{");
139+
sb.append("fieldName='").append(fieldName).append('\'');
140+
if (byteTarget != null) {
141+
sb.append(", byteTarget=").append(Arrays.toString(byteTarget));
142+
} else {
143+
sb.append(", floatTarget=").append(Arrays.toString(floatTarget));
144+
}
145+
sb.append(", vectorSimilarityFunction=").append(vectorSimilarityFunction);
146+
sb.append(", k=").append(k);
147+
sb.append(", vectorQuery=").append(vectorQuery);
148+
sb.append('}');
149+
return sb.toString();
150+
}
151+
}

0 commit comments

Comments
 (0)