Skip to content

Commit 2aa011d

Browse files
authored
[8.x] Add rescore knn vector test coverage (#122801) (#123239)
* Add rescore knn vector test coverage (#122801) (cherry picked from commit f5e2a92) # Conflicts: # server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java # server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java * Fix merge for 8.x
1 parent e774703 commit 2aa011d

File tree

6 files changed

+316
-75
lines changed

6 files changed

+316
-75
lines changed
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.query;
11+
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.elasticsearch.action.index.IndexRequestBuilder;
14+
import org.elasticsearch.action.search.SearchRequestBuilder;
15+
import org.elasticsearch.action.search.SearchResponse;
16+
import org.elasticsearch.cluster.metadata.IndexMetadata;
17+
import org.elasticsearch.common.settings.Settings;
18+
import org.elasticsearch.index.IndexVersion;
19+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
20+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorIndexType;
21+
import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues;
22+
import org.elasticsearch.index.query.MatchAllQueryBuilder;
23+
import org.elasticsearch.index.query.QueryBuilders;
24+
import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder;
25+
import org.elasticsearch.plugins.Plugin;
26+
import org.elasticsearch.script.MockScriptPlugin;
27+
import org.elasticsearch.script.Script;
28+
import org.elasticsearch.script.ScriptType;
29+
import org.elasticsearch.search.SearchHit;
30+
import org.elasticsearch.search.builder.SearchSourceBuilder;
31+
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
32+
import org.elasticsearch.search.vectors.KnnSearchBuilder;
33+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
34+
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
35+
import org.elasticsearch.test.ESIntegTestCase;
36+
import org.elasticsearch.xcontent.XContentBuilder;
37+
import org.elasticsearch.xcontent.XContentFactory;
38+
import org.junit.Before;
39+
40+
import java.io.IOException;
41+
import java.util.ArrayList;
42+
import java.util.Arrays;
43+
import java.util.Collection;
44+
import java.util.Collections;
45+
import java.util.List;
46+
import java.util.Locale;
47+
import java.util.Map;
48+
import java.util.function.BiFunction;
49+
import java.util.function.Function;
50+
import java.util.stream.Collectors;
51+
52+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
53+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
54+
import static org.hamcrest.Matchers.equalTo;
55+
56+
public class RescoreKnnVectorQueryIT extends ESIntegTestCase {
57+
58+
public static final String INDEX_NAME = "test";
59+
public static final String VECTOR_FIELD = "vector";
60+
public static final String VECTOR_SCORE_SCRIPT = "vector_scoring";
61+
public static final String QUERY_VECTOR_PARAM = "query_vector";
62+
63+
@Override
64+
protected Collection<Class<? extends Plugin>> nodePlugins() {
65+
return Collections.singleton(CustomScriptPlugin.class);
66+
}
67+
68+
public static class CustomScriptPlugin extends MockScriptPlugin {
69+
private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM
70+
.vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT);
71+
72+
@Override
73+
protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
74+
return Map.of(VECTOR_SCORE_SCRIPT, vars -> {
75+
Map<?, ?> doc = (Map<?, ?>) vars.get("doc");
76+
return SIMILARITY_FUNCTION.compare(
77+
((DenseVectorScriptDocValues) doc.get(VECTOR_FIELD)).getVectorValue(),
78+
(float[]) vars.get(QUERY_VECTOR_PARAM)
79+
);
80+
});
81+
}
82+
}
83+
84+
@Before
85+
public void setup() throws IOException {
86+
String type = randomFrom(
87+
Arrays.stream(VectorIndexType.values())
88+
.filter(VectorIndexType::isQuantized)
89+
.map(t -> t.name().toLowerCase(Locale.ROOT))
90+
.collect(Collectors.toCollection(ArrayList::new))
91+
);
92+
XContentBuilder mapping = XContentFactory.jsonBuilder()
93+
.startObject()
94+
.startObject("properties")
95+
.startObject(VECTOR_FIELD)
96+
.field("type", "dense_vector")
97+
.field("similarity", "l2_norm")
98+
.startObject("index_options")
99+
.field("type", type)
100+
.endObject()
101+
.endObject()
102+
.endObject()
103+
.endObject();
104+
105+
Settings settings = Settings.builder()
106+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
107+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5))
108+
.build();
109+
prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get();
110+
ensureGreen(INDEX_NAME);
111+
}
112+
113+
private record TestParams(
114+
int numDocs,
115+
int numDims,
116+
float[] queryVector,
117+
int k,
118+
int numCands,
119+
RescoreVectorBuilder rescoreVectorBuilder
120+
) {
121+
public static TestParams generate() {
122+
int numDims = randomIntBetween(32, 512) * 2; // Ensure even dimensions
123+
int numDocs = randomIntBetween(10, 100);
124+
int k = randomIntBetween(1, numDocs - 5);
125+
return new TestParams(
126+
numDocs,
127+
numDims,
128+
randomVector(numDims),
129+
k,
130+
(int) (k * randomFloatBetween(1.0f, 10.0f, true)),
131+
new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true))
132+
);
133+
}
134+
}
135+
136+
public void testKnnSearchRescore() {
137+
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnSearchGenerator = (testParams, requestBuilder) -> {
138+
KnnSearchBuilder knnSearch = new KnnSearchBuilder(
139+
VECTOR_FIELD,
140+
testParams.queryVector,
141+
testParams.k,
142+
testParams.numCands,
143+
testParams.rescoreVectorBuilder,
144+
null
145+
);
146+
return requestBuilder.setKnnSearch(List.of(knnSearch));
147+
};
148+
testKnnRescore(knnSearchGenerator);
149+
}
150+
151+
public void testKnnQueryRescore() {
152+
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnQueryGenerator = (testParams, requestBuilder) -> {
153+
KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(
154+
VECTOR_FIELD,
155+
testParams.queryVector,
156+
testParams.k,
157+
testParams.numCands,
158+
testParams.rescoreVectorBuilder,
159+
null
160+
);
161+
return requestBuilder.setQuery(knnQuery);
162+
};
163+
testKnnRescore(knnQueryGenerator);
164+
}
165+
166+
public void testKnnRetriever() {
167+
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnQueryGenerator = (testParams, requestBuilder) -> {
168+
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
169+
VECTOR_FIELD,
170+
testParams.queryVector,
171+
null,
172+
testParams.k,
173+
testParams.numCands,
174+
testParams.rescoreVectorBuilder,
175+
null
176+
);
177+
return requestBuilder.setSource(new SearchSourceBuilder().retriever(knnRetriever));
178+
};
179+
testKnnRescore(knnQueryGenerator);
180+
}
181+
182+
private void testKnnRescore(BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> searchRequestGenerator) {
183+
TestParams testParams = TestParams.generate();
184+
185+
int numDocs = testParams.numDocs;
186+
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
187+
188+
for (int i = 0; i < numDocs; i++) {
189+
docs[i] = prepareIndex(INDEX_NAME).setId("" + i).setSource(VECTOR_FIELD, randomVector(testParams.numDims));
190+
}
191+
indexRandom(true, docs);
192+
193+
float[] queryVector = testParams.queryVector;
194+
float oversample = randomFloatBetween(1.0f, 100f, true);
195+
RescoreVectorBuilder rescoreVectorBuilder = new RescoreVectorBuilder(oversample);
196+
197+
SearchRequestBuilder requestBuilder = searchRequestGenerator.apply(
198+
testParams,
199+
prepareSearch(INDEX_NAME).setSize(numDocs).setTrackTotalHits(randomBoolean())
200+
);
201+
202+
assertNoFailuresAndResponse(requestBuilder, knnResponse -> { compareWithExactSearch(knnResponse, queryVector, numDocs); });
203+
}
204+
205+
private static void compareWithExactSearch(SearchResponse knnResponse, float[] queryVector, int docCount) {
206+
// Do an exact query and compare
207+
Script script = new Script(
208+
ScriptType.INLINE,
209+
CustomScriptPlugin.NAME,
210+
VECTOR_SCORE_SCRIPT,
211+
Map.of(QUERY_VECTOR_PARAM, queryVector)
212+
);
213+
ScriptScoreQueryBuilder scriptScoreQueryBuilder = QueryBuilders.scriptScoreQuery(new MatchAllQueryBuilder(), script);
214+
assertNoFailuresAndResponse(prepareSearch(INDEX_NAME).setQuery(scriptScoreQueryBuilder).setSize(docCount), exactResponse -> {
215+
assertHitCount(exactResponse, docCount);
216+
217+
int i = 0;
218+
SearchHit[] exactHits = exactResponse.getHits().getHits();
219+
for (SearchHit knnHit : knnResponse.getHits().getHits()) {
220+
while (i < exactHits.length && exactHits[i].getId().equals(knnHit.getId()) == false) {
221+
i++;
222+
}
223+
if (i >= exactHits.length) {
224+
fail("Knn doc not found in exact search");
225+
}
226+
assertThat("Real score is not the same as rescored score", knnHit.getScore(), equalTo(exactHits[i].getScore()));
227+
}
228+
});
229+
}
230+
231+
private static float[] randomVector(int numDimensions) {
232+
float[] vector = new float[numDimensions];
233+
for (int j = 0; j < numDimensions; j++) {
234+
vector[j] = randomFloatBetween(0, 1, true);
235+
}
236+
return vector;
237+
}
238+
}

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ public final int hashCode() {
12281228
}
12291229
}
12301230

1231-
private enum VectorIndexType {
1231+
public enum VectorIndexType {
12321232
HNSW("hnsw", false) {
12331233
@Override
12341234
public IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {

server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,23 @@
1717
import org.apache.lucene.search.MatchNoDocsQuery;
1818
import org.apache.lucene.search.Query;
1919
import org.apache.lucene.search.QueryVisitor;
20+
import org.apache.lucene.search.ScoreDoc;
2021
import org.apache.lucene.search.ScoreMode;
2122
import org.apache.lucene.search.Scorer;
2223
import org.apache.lucene.search.Weight;
2324

2425
import java.io.IOException;
2526
import java.util.Arrays;
27+
import java.util.Comparator;
2628
import java.util.Objects;
2729

2830
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
2931

3032
/**
3133
* A query that matches the provided docs with their scores.
3234
*
33-
* Note: this query was adapted from Lucene's DocAndScoreQuery from the class
35+
* Note: this query was originally adapted from Lucene's DocAndScoreQuery from the class
3436
* {@link org.apache.lucene.search.KnnFloatVectorQuery}, which is package-private.
35-
* There are no changes to the behavior, just some renames.
3637
*/
3738
public class KnnScoreDocQuery extends Query {
3839
private final int[] docs;
@@ -49,13 +50,18 @@ public class KnnScoreDocQuery extends Query {
4950
/**
5051
* Creates a query.
5152
*
52-
* @param docs the global doc IDs of documents that match, in ascending order
53-
* @param scores the scores of the matching documents
53+
* @param scoreDocs an array of ScoreDocs to use for the query
5454
* @param reader IndexReader
5555
*/
56-
KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) {
57-
this.docs = docs;
58-
this.scores = scores;
56+
KnnScoreDocQuery(ScoreDoc[] scoreDocs, IndexReader reader) {
57+
// Ensure that the docs are sorted by docId, as they are later searched using binary search
58+
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
59+
this.docs = new int[scoreDocs.length];
60+
this.scores = new float[scoreDocs.length];
61+
for (int i = 0; i < scoreDocs.length; i++) {
62+
docs[i] = scoreDocs[i].doc;
63+
scores[i] = scoreDocs[i].score;
64+
}
5965
this.segmentStarts = findSegmentStarts(reader, docs);
6066
this.contextIdentity = reader.getContext().id();
6167
}

server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
141141

142142
@Override
143143
protected Query doToQuery(SearchExecutionContext context) throws IOException {
144-
int numDocs = scoreDocs.length;
145-
int[] docs = new int[numDocs];
146-
float[] scores = new float[numDocs];
147-
for (int i = 0; i < numDocs; i++) {
148-
docs[i] = scoreDocs[i].doc;
149-
scores[i] = scoreDocs[i].score;
150-
}
151-
152-
return new KnnScoreDocQuery(docs, scores, context.getIndexReader());
144+
return new KnnScoreDocQuery(scoreDocs, context.getIndexReader());
153145
}
154146

155147
@Override

server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,12 @@
1616
import org.apache.lucene.search.IndexSearcher;
1717
import org.apache.lucene.search.Query;
1818
import org.apache.lucene.search.QueryVisitor;
19-
import org.apache.lucene.search.ScoreDoc;
2019
import org.apache.lucene.search.TopDocs;
2120
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
2221
import org.elasticsearch.search.profile.query.QueryProfiler;
2322

2423
import java.io.IOException;
2524
import java.util.Arrays;
26-
import java.util.Comparator;
2725
import java.util.Objects;
2826

2927
/**
@@ -60,16 +58,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException {
6058
// Retrieve top k documents from the rescored query
6159
TopDocs topDocs = searcher.search(query, k);
6260
vectorOperations = topDocs.totalHits.value;
63-
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
64-
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
65-
int[] docIds = new int[scoreDocs.length];
66-
float[] scores = new float[scoreDocs.length];
67-
for (int i = 0; i < scoreDocs.length; i++) {
68-
docIds[i] = scoreDocs[i].doc;
69-
scores[i] = scoreDocs[i].score;
70-
}
71-
72-
return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader());
61+
return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
7362
}
7463

7564
public Query innerQuery() {

0 commit comments

Comments
 (0)