Skip to content

Commit df32233

Browse files
committed
test for lateI rescorer
1 parent 9108ad1 commit df32233

File tree

2 files changed

+197
-0
lines changed

2 files changed

+197
-0
lines changed

lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
118
package org.apache.lucene.search;
219

320
import org.apache.lucene.index.VectorSimilarityFunction;
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.search;
19+
20+
import static org.apache.lucene.search.LateInteractionFloatValuesSource.ScoreFunction;
21+
22+
import java.io.IOException;
23+
import java.util.ArrayList;
24+
import java.util.Arrays;
25+
import java.util.List;
26+
import java.util.Random;
27+
import java.util.Set;
28+
import java.util.stream.Collectors;
29+
import org.apache.lucene.document.Document;
30+
import org.apache.lucene.document.Field;
31+
import org.apache.lucene.document.IntField;
32+
import org.apache.lucene.document.KnnFloatVectorField;
33+
import org.apache.lucene.document.LateInteractionField;
34+
import org.apache.lucene.index.DirectoryReader;
35+
import org.apache.lucene.index.IndexReader;
36+
import org.apache.lucene.index.IndexWriter;
37+
import org.apache.lucene.index.StoredFields;
38+
import org.apache.lucene.index.VectorSimilarityFunction;
39+
import org.apache.lucene.store.Directory;
40+
import org.apache.lucene.tests.util.LuceneTestCase;
41+
42+
public class TestLateInteractionRescorer extends LuceneTestCase {
43+
44+
private final String LATE_I_FIELD = "li_vector";
45+
private final String KNN_FIELD = "knn_vector";
46+
private final int DIMENSION = 128;
47+
48+
public void testBasic() throws Exception {
49+
List<float[][]> corpus = new ArrayList<>();
50+
final VectorSimilarityFunction vectorSimilarityFunction =
51+
VectorSimilarityFunction.values()[
52+
random().nextInt(VectorSimilarityFunction.values().length)];
53+
ScoreFunction scoreFunction = ScoreFunction.SUM_MAX_SIM;
54+
55+
try (Directory dir = newDirectory()) {
56+
indexMultiVectors(dir, corpus);
57+
float[][] lateIQueryVector = createMultiVector(DIMENSION);
58+
float[] knnQueryVector = randomFloatVector(DIMENSION, random());
59+
KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(KNN_FIELD, knnQueryVector, 50);
60+
61+
try (IndexReader reader = DirectoryReader.open(dir)) {
62+
final int topN = 10;
63+
IndexSearcher s = new IndexSearcher(reader);
64+
TopDocs knnHits = s.search(knnQuery, 5 * topN);
65+
LateInteractionRescorer rescorer =
66+
LateInteractionRescorer.create(
67+
LATE_I_FIELD, lateIQueryVector, vectorSimilarityFunction);
68+
TopDocs rerankedHits = rescorer.rescore(s, knnHits, topN);
69+
Set<Integer> knnHitDocs =
70+
Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet());
71+
assertEquals(topN, rerankedHits.scoreDocs.length);
72+
StoredFields storedFields = reader.storedFields();
73+
for (int i = 0; i < rerankedHits.scoreDocs.length; i++) {
74+
assertTrue(knnHitDocs.contains(rerankedHits.scoreDocs[i].doc));
75+
int idValue =
76+
Integer.parseInt(storedFields.document(rerankedHits.scoreDocs[i].doc).get("id"));
77+
float[][] docVector = corpus.get(idValue);
78+
float expected =
79+
scoreFunction.compare(lateIQueryVector, docVector, vectorSimilarityFunction);
80+
assertEquals(expected, rerankedHits.scoreDocs[i].score, 1e-5);
81+
if (i > 0) {
82+
assertTrue(rerankedHits.scoreDocs[i].score <= rerankedHits.scoreDocs[i - 1].score);
83+
}
84+
}
85+
}
86+
}
87+
}
88+
89+
public void testMissingLateIValues() throws Exception {
90+
List<float[][]> corpus = new ArrayList<>();
91+
final VectorSimilarityFunction vectorSimilarityFunction =
92+
VectorSimilarityFunction.values()[
93+
random().nextInt(VectorSimilarityFunction.values().length)];
94+
95+
try (Directory dir = newDirectory()) {
96+
indexMultiVectors(dir, corpus);
97+
float[][] lateIQueryVector = createMultiVector(DIMENSION);
98+
float[] knnQueryVector = randomFloatVector(DIMENSION, random());
99+
KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(KNN_FIELD, knnQueryVector, 50);
100+
101+
try (IndexReader reader = DirectoryReader.open(dir)) {
102+
final int topN = 10;
103+
IndexSearcher s = new IndexSearcher(reader);
104+
TopDocs knnHits = s.search(knnQuery, 5 * topN);
105+
LateInteractionRescorer rescorer =
106+
LateInteractionRescorer.create(
107+
"bad-test-field", lateIQueryVector, vectorSimilarityFunction);
108+
TopDocs rerankedHits = rescorer.rescore(s, knnHits, topN);
109+
Set<Integer> knnHitDocs =
110+
Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet());
111+
assertEquals(topN, rerankedHits.scoreDocs.length);
112+
for (int i = 0; i < rerankedHits.scoreDocs.length; i++) {
113+
assertTrue(knnHitDocs.contains(rerankedHits.scoreDocs[i].doc));
114+
assertEquals(0f, rerankedHits.scoreDocs[i].score, 1e-5);
115+
}
116+
117+
LateInteractionRescorer rescorerWithFallback =
118+
LateInteractionRescorer.withFallbackToFirstPassScore(
119+
"bad-test-field", lateIQueryVector, vectorSimilarityFunction);
120+
knnHits = s.search(knnQuery, 5 * topN);
121+
rerankedHits = rescorerWithFallback.rescore(s, knnHits, topN);
122+
knnHitDocs = Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet());
123+
assertEquals(topN, rerankedHits.scoreDocs.length);
124+
for (int i = 0; i < rerankedHits.scoreDocs.length; i++) {
125+
assertTrue(knnHitDocs.contains(rerankedHits.scoreDocs[i].doc));
126+
assertEquals(knnHits.scoreDocs[i].score, rerankedHits.scoreDocs[i].score, 1e-5);
127+
}
128+
}
129+
}
130+
}
131+
132+
private void indexMultiVectors(Directory dir, List<float[][]> corpus) throws IOException {
133+
final int numDocs = atLeast(1000);
134+
final int numSegments = random().nextInt(2, 10);
135+
int id = 0;
136+
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
137+
for (int j = 0; j < numSegments; j++) {
138+
for (int i = 0; i < numDocs; i++) {
139+
Document doc = new Document();
140+
if (random().nextInt(100) < 30) {
141+
// skip value for some docs to create sparse field
142+
doc.add(new IntField("has_li_vector", 0, Field.Store.YES));
143+
} else {
144+
float[][] value = createMultiVector(DIMENSION);
145+
corpus.add(value);
146+
doc.add(new IntField("id", id++, Field.Store.YES));
147+
doc.add(new LateInteractionField(LATE_I_FIELD, value));
148+
doc.add(new KnnFloatVectorField(KNN_FIELD, randomFloatVector(DIMENSION, random())));
149+
doc.add(new IntField("has_li_vector", 1, Field.Store.YES));
150+
}
151+
w.addDocument(doc);
152+
w.flush();
153+
}
154+
}
155+
// add a segment with no vectors
156+
for (int i = 0; i < 100; i++) {
157+
Document doc = new Document();
158+
doc.add(new IntField("has_li_vector", 0, Field.Store.YES));
159+
w.addDocument(doc);
160+
}
161+
w.flush();
162+
}
163+
}
164+
165+
private float[][] createMultiVector(int dimension) {
166+
float[][] value = new float[random().nextInt(3, 12)][];
167+
for (int i = 0; i < value.length; i++) {
168+
value[i] = randomFloatVector(dimension, random());
169+
}
170+
return value;
171+
}
172+
173+
private float[] randomFloatVector(int dimension, Random random) {
174+
float[] vector = new float[dimension];
175+
for (int i = 0; i < dimension; i++) {
176+
vector[i] = random.nextFloat();
177+
}
178+
return vector;
179+
}
180+
}

0 commit comments

Comments
 (0)