Skip to content

Commit f5e2a92

Browse files
authored
Add rescore knn vector test coverage (#122801)
1 parent 4ca669a commit f5e2a92

File tree

6 files changed

+316
-77
lines changed

6 files changed

+316
-77
lines changed
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.query;
11+
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.elasticsearch.action.index.IndexRequestBuilder;
14+
import org.elasticsearch.action.search.SearchRequestBuilder;
15+
import org.elasticsearch.action.search.SearchResponse;
16+
import org.elasticsearch.cluster.metadata.IndexMetadata;
17+
import org.elasticsearch.common.settings.Settings;
18+
import org.elasticsearch.index.IndexVersion;
19+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
20+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorIndexType;
21+
import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues;
22+
import org.elasticsearch.index.query.MatchAllQueryBuilder;
23+
import org.elasticsearch.index.query.QueryBuilders;
24+
import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder;
25+
import org.elasticsearch.plugins.Plugin;
26+
import org.elasticsearch.script.MockScriptPlugin;
27+
import org.elasticsearch.script.Script;
28+
import org.elasticsearch.script.ScriptType;
29+
import org.elasticsearch.search.SearchHit;
30+
import org.elasticsearch.search.builder.SearchSourceBuilder;
31+
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
32+
import org.elasticsearch.search.vectors.KnnSearchBuilder;
33+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
34+
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
35+
import org.elasticsearch.test.ESIntegTestCase;
36+
import org.elasticsearch.xcontent.XContentBuilder;
37+
import org.elasticsearch.xcontent.XContentFactory;
38+
import org.junit.Before;
39+
40+
import java.io.IOException;
41+
import java.util.ArrayList;
42+
import java.util.Arrays;
43+
import java.util.Collection;
44+
import java.util.Collections;
45+
import java.util.List;
46+
import java.util.Locale;
47+
import java.util.Map;
48+
import java.util.function.BiFunction;
49+
import java.util.function.Function;
50+
import java.util.stream.Collectors;
51+
52+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
53+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
54+
import static org.hamcrest.Matchers.equalTo;
55+
56+
public class RescoreKnnVectorQueryIT extends ESIntegTestCase {
57+
58+
// Name of the index that setup() creates and that every index/search request in this test targets.
public static final String INDEX_NAME = "test";
59+
// Name of the dense_vector field declared in the index mapping and populated for every document.
public static final String VECTOR_FIELD = "vector";
60+
// Key under which CustomScriptPlugin registers its mock scoring script.
public static final String VECTOR_SCORE_SCRIPT = "vector_scoring";
61+
// Name of the script parameter that carries the query vector to the mock scoring script.
public static final String QUERY_VECTOR_PARAM = "query_vector";
62+
63+
@Override
64+
protected Collection<Class<? extends Plugin>> nodePlugins() {
65+
return Collections.singleton(CustomScriptPlugin.class);
66+
}
67+
68+
public static class CustomScriptPlugin extends MockScriptPlugin {
69+
private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM
70+
.vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT);
71+
72+
@Override
73+
protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
74+
return Map.of(VECTOR_SCORE_SCRIPT, vars -> {
75+
Map<?, ?> doc = (Map<?, ?>) vars.get("doc");
76+
return SIMILARITY_FUNCTION.compare(
77+
((DenseVectorScriptDocValues) doc.get(VECTOR_FIELD)).getVectorValue(),
78+
(float[]) vars.get(QUERY_VECTOR_PARAM)
79+
);
80+
});
81+
}
82+
}
83+
84+
@Before
85+
public void setup() throws IOException {
86+
String type = randomFrom(
87+
Arrays.stream(VectorIndexType.values())
88+
.filter(VectorIndexType::isQuantized)
89+
.map(t -> t.name().toLowerCase(Locale.ROOT))
90+
.collect(Collectors.toCollection(ArrayList::new))
91+
);
92+
XContentBuilder mapping = XContentFactory.jsonBuilder()
93+
.startObject()
94+
.startObject("properties")
95+
.startObject(VECTOR_FIELD)
96+
.field("type", "dense_vector")
97+
.field("similarity", "l2_norm")
98+
.startObject("index_options")
99+
.field("type", type)
100+
.endObject()
101+
.endObject()
102+
.endObject()
103+
.endObject();
104+
105+
Settings settings = Settings.builder()
106+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
107+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5))
108+
.build();
109+
prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get();
110+
ensureGreen(INDEX_NAME);
111+
}
112+
113+
private record TestParams(
114+
int numDocs,
115+
int numDims,
116+
float[] queryVector,
117+
int k,
118+
int numCands,
119+
RescoreVectorBuilder rescoreVectorBuilder
120+
) {
121+
public static TestParams generate() {
122+
int numDims = randomIntBetween(32, 512) * 2; // Ensure even dimensions
123+
int numDocs = randomIntBetween(10, 100);
124+
int k = randomIntBetween(1, numDocs - 5);
125+
return new TestParams(
126+
numDocs,
127+
numDims,
128+
randomVector(numDims),
129+
k,
130+
(int) (k * randomFloatBetween(1.0f, 10.0f, true)),
131+
new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true))
132+
);
133+
}
134+
}
135+
136+
public void testKnnSearchRescore() {
137+
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnSearchGenerator = (testParams, requestBuilder) -> {
138+
KnnSearchBuilder knnSearch = new KnnSearchBuilder(
139+
VECTOR_FIELD,
140+
testParams.queryVector,
141+
testParams.k,
142+
testParams.numCands,
143+
testParams.rescoreVectorBuilder,
144+
null
145+
);
146+
return requestBuilder.setKnnSearch(List.of(knnSearch));
147+
};
148+
testKnnRescore(knnSearchGenerator);
149+
}
150+
151+
public void testKnnQueryRescore() {
152+
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnQueryGenerator = (testParams, requestBuilder) -> {
153+
KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(
154+
VECTOR_FIELD,
155+
testParams.queryVector,
156+
testParams.k,
157+
testParams.numCands,
158+
testParams.rescoreVectorBuilder,
159+
null
160+
);
161+
return requestBuilder.setQuery(knnQuery);
162+
};
163+
testKnnRescore(knnQueryGenerator);
164+
}
165+
166+
public void testKnnRetriever() {
167+
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnQueryGenerator = (testParams, requestBuilder) -> {
168+
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
169+
VECTOR_FIELD,
170+
testParams.queryVector,
171+
null,
172+
testParams.k,
173+
testParams.numCands,
174+
testParams.rescoreVectorBuilder,
175+
null
176+
);
177+
return requestBuilder.setSource(new SearchSourceBuilder().retriever(knnRetriever));
178+
};
179+
testKnnRescore(knnQueryGenerator);
180+
}
181+
182+
private void testKnnRescore(BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> searchRequestGenerator) {
183+
TestParams testParams = TestParams.generate();
184+
185+
int numDocs = testParams.numDocs;
186+
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
187+
188+
for (int i = 0; i < numDocs; i++) {
189+
docs[i] = prepareIndex(INDEX_NAME).setId("" + i).setSource(VECTOR_FIELD, randomVector(testParams.numDims));
190+
}
191+
indexRandom(true, docs);
192+
193+
float[] queryVector = testParams.queryVector;
194+
float oversample = randomFloatBetween(1.0f, 100f, true);
195+
RescoreVectorBuilder rescoreVectorBuilder = new RescoreVectorBuilder(oversample);
196+
197+
SearchRequestBuilder requestBuilder = searchRequestGenerator.apply(
198+
testParams,
199+
prepareSearch(INDEX_NAME).setSize(numDocs).setTrackTotalHits(randomBoolean())
200+
);
201+
202+
assertNoFailuresAndResponse(requestBuilder, knnResponse -> { compareWithExactSearch(knnResponse, queryVector, numDocs); });
203+
}
204+
205+
private static void compareWithExactSearch(SearchResponse knnResponse, float[] queryVector, int docCount) {
206+
// Do an exact query and compare
207+
Script script = new Script(
208+
ScriptType.INLINE,
209+
CustomScriptPlugin.NAME,
210+
VECTOR_SCORE_SCRIPT,
211+
Map.of(QUERY_VECTOR_PARAM, queryVector)
212+
);
213+
ScriptScoreQueryBuilder scriptScoreQueryBuilder = QueryBuilders.scriptScoreQuery(new MatchAllQueryBuilder(), script);
214+
assertNoFailuresAndResponse(prepareSearch(INDEX_NAME).setQuery(scriptScoreQueryBuilder).setSize(docCount), exactResponse -> {
215+
assertHitCount(exactResponse, docCount);
216+
217+
int i = 0;
218+
SearchHit[] exactHits = exactResponse.getHits().getHits();
219+
for (SearchHit knnHit : knnResponse.getHits().getHits()) {
220+
while (i < exactHits.length && exactHits[i].getId().equals(knnHit.getId()) == false) {
221+
i++;
222+
}
223+
if (i >= exactHits.length) {
224+
fail("Knn doc not found in exact search");
225+
}
226+
assertThat("Real score is not the same as rescored score", knnHit.getScore(), equalTo(exactHits[i].getScore()));
227+
}
228+
});
229+
}
230+
231+
private static float[] randomVector(int numDimensions) {
232+
float[] vector = new float[numDimensions];
233+
for (int j = 0; j < numDimensions; j++) {
234+
vector[j] = randomFloatBetween(0, 1, true);
235+
}
236+
return vector;
237+
}
238+
}

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1225,7 +1225,7 @@ public final int hashCode() {
12251225
}
12261226
}
12271227

1228-
private enum VectorIndexType {
1228+
public enum VectorIndexType {
12291229
HNSW("hnsw", false) {
12301230
@Override
12311231
public IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {

server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,24 @@
1717
import org.apache.lucene.search.MatchNoDocsQuery;
1818
import org.apache.lucene.search.Query;
1919
import org.apache.lucene.search.QueryVisitor;
20+
import org.apache.lucene.search.ScoreDoc;
2021
import org.apache.lucene.search.ScoreMode;
2122
import org.apache.lucene.search.Scorer;
2223
import org.apache.lucene.search.ScorerSupplier;
2324
import org.apache.lucene.search.Weight;
2425

2526
import java.io.IOException;
2627
import java.util.Arrays;
28+
import java.util.Comparator;
2729
import java.util.Objects;
2830

2931
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
3032

3133
/**
3234
* A query that matches the provided docs with their scores.
3335
*
34-
* Note: this query was adapted from Lucene's DocAndScoreQuery from the class
36+
* Note: this query was originally adapted from Lucene's DocAndScoreQuery from the class
3537
* {@link org.apache.lucene.search.KnnFloatVectorQuery}, which is package-private.
36-
* There are no changes to the behavior, just some renames.
3738
*/
3839
public class KnnScoreDocQuery extends Query {
3940
private final int[] docs;
@@ -50,13 +51,18 @@ public class KnnScoreDocQuery extends Query {
5051
/**
5152
* Creates a query.
5253
*
53-
* @param docs the global doc IDs of documents that match, in ascending order
54-
* @param scores the scores of the matching documents
54+
* @param scoreDocs an array of ScoreDocs to use for the query
5555
* @param reader IndexReader
5656
*/
57-
KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) {
58-
this.docs = docs;
59-
this.scores = scores;
57+
KnnScoreDocQuery(ScoreDoc[] scoreDocs, IndexReader reader) {
58+
// Ensure that the docs are sorted by docId, as they are later searched using binary search
59+
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
60+
this.docs = new int[scoreDocs.length];
61+
this.scores = new float[scoreDocs.length];
62+
for (int i = 0; i < scoreDocs.length; i++) {
63+
docs[i] = scoreDocs[i].doc;
64+
scores[i] = scoreDocs[i].score;
65+
}
6066
this.segmentStarts = findSegmentStarts(reader, docs);
6167
this.contextIdentity = reader.getContext().id();
6268
}

server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
141141

142142
@Override
143143
protected Query doToQuery(SearchExecutionContext context) throws IOException {
144-
int numDocs = scoreDocs.length;
145-
int[] docs = new int[numDocs];
146-
float[] scores = new float[numDocs];
147-
for (int i = 0; i < numDocs; i++) {
148-
docs[i] = scoreDocs[i].doc;
149-
scores[i] = scoreDocs[i].score;
150-
}
151-
152-
return new KnnScoreDocQuery(docs, scores, context.getIndexReader());
144+
return new KnnScoreDocQuery(scoreDocs, context.getIndexReader());
153145
}
154146

155147
@Override

server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,12 @@
1616
import org.apache.lucene.search.IndexSearcher;
1717
import org.apache.lucene.search.Query;
1818
import org.apache.lucene.search.QueryVisitor;
19-
import org.apache.lucene.search.ScoreDoc;
2019
import org.apache.lucene.search.TopDocs;
2120
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
2221
import org.elasticsearch.search.profile.query.QueryProfiler;
2322

2423
import java.io.IOException;
2524
import java.util.Arrays;
26-
import java.util.Comparator;
2725
import java.util.Objects;
2826

2927
/**
@@ -60,16 +58,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException {
6058
// Retrieve top k documents from the rescored query
6159
TopDocs topDocs = searcher.search(query, k);
6260
vectorOperations = topDocs.totalHits.value();
63-
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
64-
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
65-
int[] docIds = new int[scoreDocs.length];
66-
float[] scores = new float[scoreDocs.length];
67-
for (int i = 0; i < scoreDocs.length; i++) {
68-
docIds[i] = scoreDocs[i].doc;
69-
scores[i] = scoreDocs[i].score;
70-
}
71-
72-
return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader());
61+
return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
7362
}
7463

7564
public Query innerQuery() {

0 commit comments

Comments
 (0)