Skip to content

Commit 3a65731

Browse files
committed
WIP - add rescore to custom ESKnnFloatVectorQuery to do exact search after an approximate search
1 parent 6182921 commit 3a65731

File tree

9 files changed

+338
-26
lines changed

9 files changed

+338
-26
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ static TransportVersion def(int id) {
185185
public static final TransportVersion INDEX_REQUEST_REMOVE_METERING = def(8_780_00_0);
186186
public static final TransportVersion CPU_STAT_STRING_PARSING = def(8_781_00_0);
187187
public static final TransportVersion QUERY_RULES_RETRIEVER = def(8_782_00_0);
188+
public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_783_00_0);
188189

189190
/*
190191
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,6 +1983,7 @@ public Query createKnnQuery(
19831983
VectorData queryVector,
19841984
Integer k,
19851985
int numCands,
1986+
Float rescoreOversample,
19861987
Query filter,
19871988
Float similarityThreshold,
19881989
BitSetProducer parentFilter
@@ -1994,7 +1995,15 @@ public Query createKnnQuery(
19941995
}
19951996
return switch (getElementType()) {
19961997
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
1997-
case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), k, numCands, filter, similarityThreshold, parentFilter);
1998+
case FLOAT -> createKnnFloatQuery(
1999+
queryVector.asFloatVector(),
2000+
k,
2001+
numCands,
2002+
rescoreOversample,
2003+
filter,
2004+
similarityThreshold,
2005+
parentFilter
2006+
);
19982007
case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
19992008
};
20002009
}
@@ -2052,6 +2061,7 @@ private Query createKnnFloatQuery(
20522061
float[] queryVector,
20532062
Integer k,
20542063
int numCands,
2064+
Float rescoreOversample,
20552065
Query filter,
20562066
Float similarityThreshold,
20572067
BitSetProducer parentFilter
@@ -2073,7 +2083,7 @@ && isNotUnitVector(squaredMagnitude)) {
20732083
}
20742084
Query knnQuery = parentFilter != null
20752085
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
2076-
: new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, filter);
2086+
: new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, rescoreOversample, filter);
20772087
if (similarityThreshold != null) {
20782088
knnQuery = new VectorSimilarityQuery(
20792089
knnQuery,

server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,35 @@
99

1010
package org.elasticsearch.search.vectors;
1111

12+
import org.apache.lucene.index.LeafReaderContext;
13+
import org.apache.lucene.index.QueryTimeout;
1214
import org.apache.lucene.search.KnnFloatVectorQuery;
1315
import org.apache.lucene.search.Query;
16+
import org.apache.lucene.search.ScoreDoc;
17+
import org.apache.lucene.search.TimeLimitingKnnCollectorManager;
1418
import org.apache.lucene.search.TopDocs;
19+
import org.apache.lucene.search.knn.KnnCollectorManager;
20+
import org.apache.lucene.util.BitSet;
21+
import org.apache.lucene.util.BitSetIterator;
22+
import org.apache.lucene.util.Bits;
23+
import org.apache.lucene.util.FixedBitSet;
1524
import org.elasticsearch.search.profile.query.QueryProfiler;
1625

26+
import java.io.IOException;
27+
1728
public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements ProfilingQuery {
1829
private final Integer kParam;
1930
private long vectorOpsCount;
31+
private final Float rescoreOversample;
2032

21-
public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter) {
22-
super(field, target, numCands, filter);
33+
public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Float rescoreOversample, Query filter) {
34+
super(field, target, adjustCandidates(numCands, rescoreOversample), filter);
2335
this.kParam = k;
36+
this.rescoreOversample = rescoreOversample;
37+
}
38+
39+
private static int adjustCandidates(int numCands, Float rescoreOversample) {
40+
return rescoreOversample == null ? numCands : (int) Math.ceil(numCands * rescoreOversample);
2441
}
2542

2643
@Override
@@ -31,8 +48,43 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
3148
return topK;
3249
}
3350

51+
@Override
52+
protected TopDocs approximateSearch(
53+
LeafReaderContext context,
54+
Bits acceptDocs,
55+
int visitedLimit,
56+
KnnCollectorManager knnCollectorManager
57+
) throws IOException {
58+
TopDocs topDocs = super.approximateSearch(context, acceptDocs, visitedLimit, knnCollectorManager);
59+
if (rescoreOversample == null) {
60+
return topDocs;
61+
}
62+
63+
BitSet exactSearchAcceptDocs = topDocsToBitSet(topDocs, acceptDocs.length());
64+
BitSetIterator bitSetIterator = new BitSetIterator(exactSearchAcceptDocs, topDocs.scoreDocs.length);
65+
QueryTimeout queryTimeout = null;
66+
if (knnCollectorManager instanceof TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) {
67+
queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();
68+
}
69+
return exactSearch(context, bitSetIterator, queryTimeout);
70+
}
71+
3472
@Override
3573
public void profile(QueryProfiler queryProfiler) {
3674
queryProfiler.setVectorOpsCount(vectorOpsCount);
3775
}
76+
77+
// Convert TopDocs to BitSet
78+
private static BitSet topDocsToBitSet(TopDocs topDocs, int numBits) {
79+
// Create a FixedBitSet with a size equal to the maximum number of documents
80+
BitSet bitSet = new FixedBitSet(numBits);
81+
82+
// Iterate through each document in TopDocs
83+
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
84+
// Set the corresponding bit for each doc ID
85+
bitSet.set(scoreDoc.doc);
86+
}
87+
88+
return bitSet;
89+
}
3890
}

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import java.util.Objects;
4646
import java.util.function.Supplier;
4747

48+
import static org.elasticsearch.TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE;
4849
import static org.elasticsearch.common.Strings.format;
4950
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
5051
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
@@ -66,6 +67,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
6667
public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
6768
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
6869
public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity");
70+
public static final ParseField RESCORE_VECTOR_OVERSAMPLE = new ParseField("rescore_vector_oversample");
6971
public static final ParseField FILTER_FIELD = new ParseField("filter");
7072
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
7173

@@ -79,7 +81,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
7981
null,
8082
(Integer) args[2],
8183
(Integer) args[3],
82-
(Float) args[4]
84+
(Float) args[4],
85+
(Float) args[5]
8386
)
8487
);
8588

@@ -106,6 +109,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
106109
ObjectParser.ValueType.OBJECT_ARRAY
107110
);
108111
declareStandardFields(PARSER);
112+
PARSER.declareFloat(optionalConstructorArg(), RESCORE_VECTOR_OVERSAMPLE);
109113
}
110114

111115
public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
@@ -120,6 +124,7 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
120124
private final Float vectorSimilarity;
121125
private final QueryVectorBuilder queryVectorBuilder;
122126
private final Supplier<float[]> queryVectorSupplier;
127+
private final Float rescoreOversample;
123128

124129
public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) {
125130
this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity);
@@ -151,6 +156,19 @@ private KnnVectorQueryBuilder(
151156
Integer k,
152157
Integer numCands,
153158
Float vectorSimilarity
159+
) {
160+
this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity, 0F);
161+
}
162+
163+
private KnnVectorQueryBuilder(
164+
String fieldName,
165+
VectorData queryVector,
166+
QueryVectorBuilder queryVectorBuilder,
167+
Supplier<float[]> queryVectorSupplier,
168+
Integer k,
169+
Integer numCands,
170+
Float vectorSimilarity,
171+
Float rescoreOversample
154172
) {
155173
if (k != null && k < 1) {
156174
throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
@@ -187,6 +205,7 @@ private KnnVectorQueryBuilder(
187205
this.vectorSimilarity = vectorSimilarity;
188206
this.queryVectorBuilder = queryVectorBuilder;
189207
this.queryVectorSupplier = queryVectorSupplier;
208+
this.rescoreOversample = rescoreOversample;
190209
}
191210

192211
public KnnVectorQueryBuilder(StreamInput in) throws IOException {
@@ -227,6 +246,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException {
227246
} else {
228247
this.queryVectorBuilder = null;
229248
}
249+
if (in.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) {
250+
this.rescoreOversample = in.readOptionalFloat();
251+
} else {
252+
this.rescoreOversample = null;
253+
}
254+
230255
this.queryVectorSupplier = null;
231256
}
232257

@@ -252,6 +277,10 @@ public Integer numCands() {
252277
return numCands;
253278
}
254279

280+
public Float rescoreOversample() {
281+
return rescoreOversample;
282+
}
283+
255284
public List<QueryBuilder> filterQueries() {
256285
return filterQueries;
257286
}
@@ -327,6 +356,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
327356
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) {
328357
out.writeOptionalNamedWriteable(queryVectorBuilder);
329358
}
359+
if (out.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) {
360+
out.writeOptionalFloat(rescoreOversample);
361+
}
330362
}
331363

332364
@Override
@@ -491,14 +523,31 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
491523
// Now join the filterQuery & parentFilter to provide the matching blocks of children
492524
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
493525
}
494-
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet);
526+
return vectorFieldType.createKnnQuery(
527+
queryVector,
528+
k,
529+
adjustedNumCands,
530+
rescoreOversample,
531+
filterQuery,
532+
vectorSimilarity,
533+
parentBitSet
534+
);
495535
}
496-
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, null);
536+
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, rescoreOversample, filterQuery, vectorSimilarity, null);
497537
}
498538

499539
@Override
500540
protected int doHashCode() {
501-
return Objects.hash(fieldName, Objects.hashCode(queryVector), k, numCands, filterQueries, vectorSimilarity, queryVectorBuilder);
541+
return Objects.hash(
542+
fieldName,
543+
Objects.hashCode(queryVector),
544+
k,
545+
numCands,
546+
filterQueries,
547+
vectorSimilarity,
548+
queryVectorBuilder,
549+
rescoreOversample
550+
);
502551
}
503552

504553
@Override
@@ -509,7 +558,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) {
509558
&& Objects.equals(numCands, other.numCands)
510559
&& Objects.equals(filterQueries, other.filterQueries)
511560
&& Objects.equals(vectorSimilarity, other.vectorSimilarity)
512-
&& Objects.equals(queryVectorBuilder, other.queryVectorBuilder);
561+
&& Objects.equals(queryVectorBuilder, other.queryVectorBuilder)
562+
&& Objects.equals(rescoreOversample, other.rescoreOversample);
513563
}
514564

515565
@Override
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
import org.apache.lucene.search.IndexSearcher;
13+
import org.apache.lucene.search.KnnFloatVectorQuery;
14+
import org.apache.lucene.search.Query;
15+
import org.apache.lucene.search.QueryVisitor;
16+
import org.apache.lucene.search.ScoreMode;
17+
import org.apache.lucene.search.Weight;
18+
19+
import java.io.IOException;
20+
21+
public class VectorRescoreQuery extends Query {
22+
23+
private final KnnFloatVectorQuery knnQuery;
24+
25+
public VectorRescoreQuery(KnnFloatVectorQuery knnQuery) {
26+
this.knnQuery = knnQuery;
27+
}
28+
29+
@Override
30+
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
31+
return super.createWeight(searcher, scoreMode, boost);
32+
}
33+
34+
@Override
35+
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
36+
return super.rewrite(indexSearcher);
37+
}
38+
39+
@Override
40+
public String toString(String field) {
41+
return "";
42+
}
43+
44+
@Override
45+
public void visit(QueryVisitor visitor) {
46+
47+
}
48+
49+
@Override
50+
public boolean equals(Object obj) {
51+
return false;
52+
}
53+
54+
@Override
55+
public int hashCode() {
56+
return 0;
57+
}
58+
}

0 commit comments

Comments
 (0)