Skip to content

Commit a7936da

Browse files
committed
Use KnnRescoreVectorQuery to perform rescoring and limiting the number of results from each shard
1 parent bc1e5c6 commit a7936da

File tree

2 files changed

+172
-42
lines changed

2 files changed

+172
-42
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import org.apache.lucene.index.SegmentWriteState;
3131
import org.apache.lucene.index.VectorEncoding;
3232
import org.apache.lucene.index.VectorSimilarityFunction;
33-
import org.apache.lucene.queries.function.FunctionScoreQuery;
3433
import org.apache.lucene.search.FieldExistsQuery;
3534
import org.apache.lucene.search.Query;
3635
import org.apache.lucene.search.join.BitSetProducer;
@@ -71,6 +70,7 @@
7170
import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
7271
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
7372
import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery;
73+
import org.elasticsearch.search.vectors.KnnRescoreVectorQuery;
7474
import org.elasticsearch.search.vectors.VectorData;
7575
import org.elasticsearch.search.vectors.VectorSimilarityQuery;
7676
import org.elasticsearch.xcontent.ToXContent;
@@ -2019,16 +2019,6 @@ public Query createKnnQuery(
20192019
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
20202020
);
20212021
}
2022-
if (rescoreOversample != null && indexOptions.type.isQuantized() == false) {
2023-
throw new IllegalArgumentException(
2024-
"cannot use rescore oversample on field ["
2025-
+ name()
2026-
+ "], that uses non-quantized type ["
2027-
+ indexOptions.type
2028-
+ "]. "
2029-
+ "Only quantized index option types support rescore oversample."
2030-
);
2031-
}
20322022
return switch (getElementType()) {
20332023
case BYTE -> createKnnByteQuery(
20342024
queryVector.asByteVector(),
@@ -2060,6 +2050,10 @@ public Query createKnnQuery(
20602050
};
20612051
}
20622052

2053+
private boolean needsRescore(Float rescoreOversample) {
2054+
return rescoreOversample != null && (indexOptions == null || indexOptions.type == null || indexOptions.type.isQuantized());
2055+
}
2056+
20632057
private Query createKnnBitQuery(
20642058
byte[] queryVector,
20652059
Integer k,
@@ -2084,17 +2078,6 @@ private Query createKnnBitQuery(
20842078
similarity.score(similarityThreshold, elementType, dims)
20852079
);
20862080
}
2087-
if (rescoreOversample != null) {
2088-
knnQuery = new FunctionScoreQuery(
2089-
knnQuery,
2090-
new VectorSimilarityByteValueSource(
2091-
name(),
2092-
queryVector,
2093-
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE)
2094-
)
2095-
);
2096-
2097-
}
20982081
return knnQuery;
20992082
}
21002083

@@ -2113,7 +2096,7 @@ private Query createKnnByteQuery(
21132096
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
21142097
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
21152098
}
2116-
Integer adjustedK = k == null || rescoreOversample == null
2099+
Integer adjustedK = k == null || needsRescore(rescoreOversample) == false
21172100
? null
21182101
: Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
21192102
int adjustedNumCands = Math.max(adjustedK == null ? 0 : adjustedK, numCands);
@@ -2128,16 +2111,14 @@ private Query createKnnByteQuery(
21282111
similarity.score(similarityThreshold, elementType, dims)
21292112
);
21302113
}
2131-
if (rescoreOversample != null) {
2132-
knnQuery = new FunctionScoreQuery(
2133-
knnQuery,
2134-
new VectorSimilarityByteValueSource(
2135-
name(),
2136-
queryVector,
2137-
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE)
2138-
)
2114+
if (needsRescore(rescoreOversample)) {
2115+
knnQuery = new KnnRescoreVectorQuery(
2116+
name(),
2117+
queryVector,
2118+
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE),
2119+
k,
2120+
knnQuery
21392121
);
2140-
21412122
}
21422123
return knnQuery;
21432124
}
@@ -2167,7 +2148,7 @@ && isNotUnitVector(squaredMagnitude)) {
21672148
}
21682149
}
21692150

2170-
Integer adjustedK = k == null || rescoreOversample == null
2151+
Integer adjustedK = k == null || needsRescore(rescoreOversample) == false
21712152
? k
21722153
: Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)));
21732154
int adjustedNumCands = adjustedK == null ? numCands : Math.max(adjustedK, numCands);
@@ -2181,16 +2162,14 @@ && isNotUnitVector(squaredMagnitude)) {
21812162
similarity.score(similarityThreshold, elementType, dims)
21822163
);
21832164
}
2184-
if (rescoreOversample != null) {
2185-
knnQuery = new FunctionScoreQuery(
2186-
knnQuery,
2187-
new VectorSimilarityFloatValueSource(
2188-
name(),
2189-
queryVector,
2190-
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT)
2191-
)
2165+
if (needsRescore(rescoreOversample)) {
2166+
knnQuery = new KnnRescoreVectorQuery(
2167+
name(),
2168+
queryVector,
2169+
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),
2170+
k,
2171+
knnQuery
21922172
);
2193-
21942173
}
21952174
return knnQuery;
21962175
}
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.apache.lucene.queries.function.FunctionScoreQuery;
14+
import org.apache.lucene.search.DoubleValuesSource;
15+
import org.apache.lucene.search.IndexSearcher;
16+
import org.apache.lucene.search.Query;
17+
import org.apache.lucene.search.QueryVisitor;
18+
import org.apache.lucene.search.ScoreDoc;
19+
import org.apache.lucene.search.TopDocs;
20+
import org.elasticsearch.index.mapper.vectors.VectorSimilarityByteValueSource;
21+
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
22+
import org.elasticsearch.search.profile.query.QueryProfiler;
23+
24+
import java.io.IOException;
25+
import java.util.Arrays;
26+
import java.util.Objects;
27+
28+
/**
29+
* Wraps a kNN vector query to rescore the results using the non-quantized vectors
30+
*/
31+
public class KnnRescoreVectorQuery extends Query implements ProfilingQuery {
32+
private final String fieldName;
33+
private final byte[] byteTarget;
34+
private final float[] floatTarget;
35+
private final VectorSimilarityFunction vectorSimilarityFunction;
36+
private final Integer k;
37+
private final Query vectorQuery;
38+
39+
private long vectorOpsCount;
40+
41+
public KnnRescoreVectorQuery(
42+
String fieldName,
43+
byte[] byteTarget,
44+
VectorSimilarityFunction vectorSimilarityFunction,
45+
Integer k,
46+
Query vectorQuery
47+
) {
48+
this.fieldName = fieldName;
49+
this.byteTarget = byteTarget;
50+
this.floatTarget = null;
51+
this.vectorSimilarityFunction = vectorSimilarityFunction;
52+
this.k = k;
53+
this.vectorQuery = vectorQuery;
54+
}
55+
56+
public KnnRescoreVectorQuery(
57+
String fieldName,
58+
float[] floatTarget,
59+
VectorSimilarityFunction vectorSimilarityFunction,
60+
Integer k,
61+
Query vectorQuery
62+
) {
63+
this.fieldName = fieldName;
64+
this.byteTarget = null;
65+
this.floatTarget = floatTarget;
66+
this.vectorSimilarityFunction = vectorSimilarityFunction;
67+
this.k = k;
68+
this.vectorQuery = vectorQuery;
69+
}
70+
71+
@Override
72+
public Query rewrite(IndexSearcher searcher) throws IOException {
73+
Query rewritten = super.rewrite(searcher);
74+
if (rewritten != this) {
75+
return rewritten;
76+
}
77+
78+
final DoubleValuesSource valueSource;
79+
if (byteTarget != null) {
80+
valueSource = new VectorSimilarityByteValueSource(fieldName, byteTarget, vectorSimilarityFunction);
81+
} else {
82+
valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
83+
}
84+
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(vectorQuery, valueSource);
85+
Query query = searcher.rewrite(functionScoreQuery);
86+
87+
if (k == null) {
88+
// No need to calculate top k - let the request size limit the results
89+
return query;
90+
}
91+
92+
TopDocs topDocs = searcher.search(query, k);
93+
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
94+
int[] docIds = new int[scoreDocs.length];
95+
float[] scores = new float[scoreDocs.length];
96+
for (int i = 0; i < scoreDocs.length; i++) {
97+
docIds[i] = scoreDocs[i].doc;
98+
scores[i] = scoreDocs[i].score;
99+
}
100+
101+
vectorOpsCount = scoreDocs.length;
102+
103+
return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader());
104+
}
105+
106+
@Override
107+
public void profile(QueryProfiler queryProfiler) {
108+
queryProfiler.setVectorOpsCount(vectorOpsCount);
109+
}
110+
111+
@Override
112+
public void visit(QueryVisitor visitor) {
113+
if (visitor.acceptField(fieldName)) {
114+
visitor.visitLeaf(this);
115+
}
116+
}
117+
118+
@Override
119+
public boolean equals(Object o) {
120+
if (this == o) return true;
121+
if (o == null || getClass() != o.getClass()) return false;
122+
KnnRescoreVectorQuery that = (KnnRescoreVectorQuery) o;
123+
return Objects.equals(fieldName, that.fieldName)
124+
&& Objects.deepEquals(byteTarget, that.byteTarget)
125+
&& Objects.deepEquals(floatTarget, that.floatTarget)
126+
&& vectorSimilarityFunction == that.vectorSimilarityFunction
127+
&& Objects.equals(k, that.k)
128+
&& Objects.equals(vectorQuery, that.vectorQuery);
129+
}
130+
131+
@Override
132+
public int hashCode() {
133+
return Objects.hash(fieldName, Arrays.hashCode(byteTarget), Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, vectorQuery);
134+
}
135+
136+
@Override
137+
public String toString(String field) {
138+
final StringBuilder sb = new StringBuilder("KnnRescoreVectorQuery{");
139+
sb.append("fieldName='").append(fieldName).append('\'');
140+
if (byteTarget != null) {
141+
sb.append(", byteTarget=").append(Arrays.toString(byteTarget));
142+
} else {
143+
sb.append(", floatTarget=").append(Arrays.toString(floatTarget));
144+
}
145+
sb.append(", vectorSimilarityFunction=").append(vectorSimilarityFunction);
146+
sb.append(", k=").append(k);
147+
sb.append(", vectorQuery=").append(vectorQuery);
148+
sb.append('}');
149+
return sb.toString();
150+
}
151+
}

0 commit comments

Comments
 (0)