Skip to content

Commit 9946e8d

Browse files
committed
Add tests for RescoreKnnVectorQuery
1 parent 39e1676 commit 9946e8d

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed
Lines changed: 126 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.vectors;
11+
12+
13+
import org.apache.lucene.document.Document;
14+
import org.apache.lucene.document.KnnFloatVectorField;
15+
import org.apache.lucene.index.DirectoryReader;
16+
import org.apache.lucene.index.FloatVectorValues;
17+
import org.apache.lucene.index.IndexReader;
18+
import org.apache.lucene.index.IndexWriter;
19+
import org.apache.lucene.index.KnnVectorValues;
20+
import org.apache.lucene.index.LeafReaderContext;
21+
import org.apache.lucene.index.VectorSimilarityFunction;
22+
import org.apache.lucene.search.IndexSearcher;
23+
import org.apache.lucene.search.MatchAllDocsQuery;
24+
import org.apache.lucene.search.TopDocs;
25+
import org.apache.lucene.store.Directory;
26+
import org.elasticsearch.test.ESTestCase;
27+
28+
import java.util.ArrayList;
29+
import java.util.Arrays;
30+
import java.util.Collection;
31+
import java.util.Map;
32+
import java.util.PriorityQueue;
33+
import java.util.stream.Collectors;
34+
35+
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
36+
import static org.hamcrest.Matchers.equalTo;
37+
38+
public class RescoreKnnVectorQueryTests extends ESTestCase {
39+
40+
public static final String FIELD_NAME = "float_vector";
41+
42+
public void testRescoresTopK() throws Exception {
43+
int numDocs = randomIntBetween(10, 100);
44+
testRescoreDocs(numDocs, randomIntBetween(5, numDocs - 1));
45+
}
46+
47+
public void testRescoresNoKParameter() throws Exception {
48+
testRescoreDocs(randomIntBetween(10, 100), null);
49+
}
50+
51+
private void testRescoreDocs(int numDocs, Integer k) throws Exception {
52+
int numDims = randomIntBetween(5, 100);
53+
54+
if (k == null) {
55+
k = numDocs;
56+
}
57+
58+
try (Directory d = newDirectory()) {
59+
try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) {
60+
for (int i = 0; i < numDocs; i++) {
61+
Document document = new Document();
62+
float[] vector = randomVector(numDims);
63+
KnnFloatVectorField vectorField = new KnnFloatVectorField(
64+
FIELD_NAME, vector);
65+
document.add(vectorField);
66+
w.addDocument(document);
67+
}
68+
w.commit();
69+
w.forceMerge(1);
70+
}
71+
72+
try (IndexReader reader = DirectoryReader.open(d)) {
73+
float[] queryVector = randomVector(numDims);
74+
75+
RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
76+
FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, k, new MatchAllDocsQuery());
77+
78+
IndexSearcher searcher = newSearcher(reader, true, false);
79+
TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs);
80+
Map<Integer, Float> rescoredDocs = Arrays.stream(docs.scoreDocs).collect(Collectors.toMap(
81+
scoreDoc -> scoreDoc.doc,
82+
scoreDoc -> scoreDoc.score)
83+
);
84+
85+
assertThat(rescoredDocs.size(), equalTo(k));
86+
87+
Collection<Float> rescoredScores = new ArrayList<>(rescoredDocs.values());
88+
PriorityQueue<Float> topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1));
89+
90+
for (LeafReaderContext leafReaderContext : reader.leaves()) {
91+
FloatVectorValues floatVectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME);
92+
KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
93+
while (iterator.nextDoc() != NO_MORE_DOCS) {
94+
float[] vector = floatVectorValues.vectorValue(iterator.index());
95+
float score = VectorSimilarityFunction.COSINE.compare(queryVector, vector);
96+
topK.add(score);
97+
int docId = iterator.docID();
98+
if (rescoredDocs.containsKey(docId)) {
99+
assertThat(rescoredDocs.get(docId), equalTo(score));
100+
rescoredDocs.remove(docId);
101+
}
102+
}
103+
}
104+
105+
assertThat(rescoredDocs.size(), equalTo(0));
106+
107+
// Check top scoring docs are contained in rescored docs
108+
for (int i = 0; i < k; i++) {
109+
Float topScore = topK.poll();
110+
if (rescoredScores.contains(topScore) == false ) {
111+
fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores);
112+
}
113+
}
114+
}
115+
}
116+
}
117+
118+
private static float[] randomVector(int numDims) {
119+
float[] vector = new float[numDims];
120+
for (int j = 0; j < numDims; j++) {
121+
vector[j] = randomFloatBetween(0, 1, true);
122+
}
123+
return vector;
124+
}
125+
126+
}

0 commit comments

Comments
 (0)