Skip to content

Commit 81384f2

Browse files
committed
Parameterize recore knn vector query tests
1 parent 257b75d commit 81384f2

File tree

1 file changed

+181
-40
lines changed

1 file changed

+181
-40
lines changed

server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java

Lines changed: 181 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,18 @@
99

1010
package org.elasticsearch.search.vectors;
1111

12+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
13+
1214
import org.apache.lucene.document.Document;
15+
import org.apache.lucene.document.KnnByteVectorField;
1316
import org.apache.lucene.document.KnnFloatVectorField;
17+
import org.apache.lucene.index.ByteVectorValues;
1418
import org.apache.lucene.index.DirectoryReader;
1519
import org.apache.lucene.index.FloatVectorValues;
1620
import org.apache.lucene.index.IndexReader;
1721
import org.apache.lucene.index.IndexWriter;
1822
import org.apache.lucene.index.KnnVectorValues;
23+
import org.apache.lucene.index.LeafReader;
1924
import org.apache.lucene.index.LeafReaderContext;
2025
import org.apache.lucene.index.VectorSimilarityFunction;
2126
import org.apache.lucene.search.IndexSearcher;
@@ -24,9 +29,12 @@
2429
import org.apache.lucene.store.Directory;
2530
import org.elasticsearch.test.ESTestCase;
2631

32+
import java.io.IOException;
2733
import java.util.ArrayList;
2834
import java.util.Arrays;
2935
import java.util.Collection;
36+
import java.util.HashSet;
37+
import java.util.List;
3038
import java.util.Map;
3139
import java.util.PriorityQueue;
3240
import java.util.stream.Collectors;
@@ -37,65 +45,56 @@
3745
public class RescoreKnnVectorQueryTests extends ESTestCase {
3846

3947
public static final String FIELD_NAME = "float_vector";
48+
private final int numDocs;
49+
private final VectorProvider vectorProvider;
50+
private final Integer k;
4051

41-
public void testRescoresTopK() throws Exception {
42-
int numDocs = randomIntBetween(10, 100);
43-
testRescoreDocs(numDocs, randomIntBetween(5, numDocs - 1));
44-
}
45-
46-
public void testRescoresNoKParameter() throws Exception {
47-
testRescoreDocs(randomIntBetween(10, 100), null);
52+
public RescoreKnnVectorQueryTests(VectorProvider vectorProvider, boolean useK) {
53+
this.vectorProvider = vectorProvider;
54+
this.numDocs = randomIntBetween(10, 100);;
55+
this.k = useK ? randomIntBetween(1, numDocs - 1) : null;
4856
}
4957

50-
private void testRescoreDocs(int numDocs, Integer k) throws Exception {
58+
public void testRescoreDocs() throws Exception {
5159
int numDims = randomIntBetween(5, 100);
5260

61+
Integer adjustedK = k;
5362
if (k == null) {
54-
k = numDocs;
63+
adjustedK = numDocs;
5564
}
5665

5766
try (Directory d = newDirectory()) {
58-
try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) {
59-
for (int i = 0; i < numDocs; i++) {
60-
Document document = new Document();
61-
float[] vector = randomVector(numDims);
62-
KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector);
63-
document.add(vectorField);
64-
w.addDocument(document);
65-
}
66-
w.commit();
67-
w.forceMerge(1);
68-
}
67+
addRandomDocuments(numDocs, d, numDims, vectorProvider);
6968

7069
try (IndexReader reader = DirectoryReader.open(d)) {
71-
float[] queryVector = randomVector(numDims);
7270

73-
RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
74-
FIELD_NAME,
75-
queryVector,
76-
VectorSimilarityFunction.COSINE,
77-
k,
78-
new MatchAllDocsQuery()
79-
);
71+
// Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query
72+
// and thus we're rescoring the top k docs.
73+
VectorData queryVector = vectorProvider.randomVector(numDims);
74+
RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider.createRescoreQuery(queryVector, adjustedK);
8075

8176
IndexSearcher searcher = newSearcher(reader, true, false);
8277
TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs);
8378
Map<Integer, Float> rescoredDocs = Arrays.stream(docs.scoreDocs)
8479
.collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score));
8580

86-
assertThat(rescoredDocs.size(), equalTo(k));
81+
assertThat(rescoredDocs.size(), equalTo(adjustedK));
82+
83+
Collection<Float> rescoredScores = new HashSet<>(rescoredDocs.values());
8784

88-
Collection<Float> rescoredScores = new ArrayList<>(rescoredDocs.values());
85+
// Collect all docs sequentially, and score them using the similarity function to get the top K scores
8986
PriorityQueue<Float> topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1));
9087

9188
for (LeafReaderContext leafReaderContext : reader.leaves()) {
92-
FloatVectorValues floatVectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME);
93-
KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
89+
KnnVectorValues vectorValues = vectorProvider.vectorValues(leafReaderContext.reader());
90+
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
9491
while (iterator.nextDoc() != NO_MORE_DOCS) {
95-
float[] vector = floatVectorValues.vectorValue(iterator.index());
96-
float score = VectorSimilarityFunction.COSINE.compare(queryVector, vector);
92+
VectorData vectorData = vectorProvider.dataVectorForDoc(vectorValues, iterator.docID());
93+
float score = vectorProvider.score(queryVector, vectorData);
9794
topK.add(score);
9895
int docId = iterator.docID();
96+
// If the doc has been retrieved from the RescoreKnnVectorQuery, check the score is the same and remove it
97+
// to ensure we found them all
9998
if (rescoredDocs.containsKey(docId)) {
10099
assertThat(rescoredDocs.get(docId), equalTo(score));
101100
rescoredDocs.remove(docId);
@@ -106,7 +105,7 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception {
106105
assertThat(rescoredDocs.size(), equalTo(0));
107106

108107
// Check top scoring docs are contained in rescored docs
109-
for (int i = 0; i < k; i++) {
108+
for (int i = 0; i < adjustedK; i++) {
110109
Float topScore = topK.poll();
111110
if (rescoredScores.contains(topScore) == false) {
112111
fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores);
@@ -116,12 +115,154 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception {
116115
}
117116
}
118117

119-
private static float[] randomVector(int numDims) {
120-
float[] vector = new float[numDims];
121-
for (int j = 0; j < numDims; j++) {
122-
vector[j] = randomFloatBetween(0, 1, true);
118+
private interface VectorProvider {
119+
VectorData randomVector(int numDimensions);
120+
121+
RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k);
122+
123+
KnnVectorValues vectorValues(LeafReader leafReader) throws IOException;
124+
125+
void addVectorField(Document document, VectorData vector);
126+
127+
VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException;
128+
129+
float score(VectorData queryVector, VectorData dataVector);
130+
}
131+
132+
private static class FloatVectorProvider implements VectorProvider {
133+
@Override
134+
public VectorData randomVector(int numDimensions) {
135+
float[] vector = new float[numDimensions];
136+
for (int j = 0; j < numDimensions; j++) {
137+
vector[j] = randomFloatBetween(0, 1, true);
138+
}
139+
return VectorData.fromFloats(vector);
140+
}
141+
142+
@Override
143+
public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k) {
144+
return new RescoreKnnVectorQuery(
145+
FIELD_NAME,
146+
queryVector.floatVector(),
147+
VectorSimilarityFunction.COSINE,
148+
k,
149+
new MatchAllDocsQuery()
150+
);
151+
}
152+
153+
@Override
154+
public KnnVectorValues vectorValues(LeafReader leafReader) throws IOException {
155+
return leafReader.getFloatVectorValues(FIELD_NAME);
156+
}
157+
158+
@Override
159+
public void addVectorField(Document document, VectorData vector) {
160+
KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector.floatVector());
161+
document.add(vectorField);
162+
}
163+
164+
@Override
165+
public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException {
166+
return VectorData.fromFloats(((FloatVectorValues)vectorValues).vectorValue(docId));
167+
}
168+
169+
@Override
170+
public float score(VectorData queryVector, VectorData dataVector) {
171+
return VectorSimilarityFunction.COSINE.compare(queryVector.floatVector(), dataVector.floatVector());
123172
}
124-
return vector;
125173
}
126174

175+
private static class ByteVectorProvider implements VectorProvider {
176+
@Override
177+
public VectorData randomVector(int numDimensions) {
178+
byte[] vector = new byte[numDimensions];
179+
for (int j = 0; j < numDimensions; j++) {
180+
vector[j] = randomByte();
181+
}
182+
return VectorData.fromBytes(vector);
183+
}
184+
185+
@Override
186+
public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k) {
187+
return new RescoreKnnVectorQuery(
188+
FIELD_NAME,
189+
queryVector.byteVector(),
190+
VectorSimilarityFunction.COSINE,
191+
k,
192+
new MatchAllDocsQuery()
193+
);
194+
}
195+
196+
@Override
197+
public KnnVectorValues vectorValues(LeafReader leafReader) throws IOException {
198+
return leafReader.getByteVectorValues(FIELD_NAME);
199+
}
200+
201+
@Override
202+
public void addVectorField(Document document, VectorData vector) {
203+
KnnByteVectorField vectorField = new KnnByteVectorField(FIELD_NAME, vector.byteVector());
204+
document.add(vectorField);
205+
}
206+
207+
@Override
208+
public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException {
209+
return VectorData.fromBytes(((ByteVectorValues)vectorValues).vectorValue(docId));
210+
}
211+
212+
@Override
213+
public float score(VectorData queryVector, VectorData dataVector) {
214+
return VectorSimilarityFunction.COSINE.compare(queryVector.byteVector(), dataVector.byteVector());
215+
}
216+
}
217+
218+
private static void addRandomDocuments(int numDocs, Directory d, int numDims, VectorProvider vectorProvider) throws IOException {
219+
try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) {
220+
for (int i = 0; i < numDocs; i++) {
221+
Document document = new Document();
222+
VectorData vector = vectorProvider.randomVector(numDims);
223+
vectorProvider.addVectorField(document, vector);
224+
w.addDocument(document);
225+
}
226+
w.commit();
227+
w.forceMerge(1);
228+
}
229+
}
230+
231+
@ParametersFactory
232+
public static Iterable<Object[]> parameters() {
233+
234+
List<Object[]> params = new ArrayList<>();
235+
params.add(new Object[] {new FloatVectorProvider(), true});
236+
params.add(new Object[] {new FloatVectorProvider(), false});
237+
params.add(new Object[] {new ByteVectorProvider(), true});
238+
params.add(new Object[] {new ByteVectorProvider(), false});
239+
240+
return params;
241+
}
242+
243+
// public void testProfiling() throws Exception {
244+
// int numDocs = randomIntBetween(10, 100);
245+
// int numDims = randomIntBetween(5, 100);
246+
//
247+
// try (Directory d = newDirectory()) {
248+
// addRandomDocuments(numDocs, d, numDims, vectorProvider);
249+
//
250+
// try (IndexReader reader = DirectoryReader.open(d)) {
251+
// float[] queryVector = randomVector(numDims);
252+
//
253+
// RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
254+
// FIELD_NAME,
255+
// queryVector,
256+
// VectorSimilarityFunction.COSINE,
257+
// randomIntBetween(5, numDocs - 1),
258+
// new MatchAllDocsQuery()
259+
// );
260+
//
261+
// IndexSearcher searcher = newSearcher(reader, true, false);
262+
// QueryProfiler queryProfiler = new QueryProfiler();
263+
// rescoreKnnVectorQuery.profile(queryProfiler);
264+
// }
265+
// }
266+
// }
267+
127268
}

0 commit comments

Comments
 (0)