Skip to content

Commit 9108ad1

Browse files
committed
late I rescore query with test and tidy
1 parent 9b7f9bb commit 9108ad1

File tree

5 files changed

+124
-146
lines changed

5 files changed

+124
-146
lines changed

lucene/core/src/java/org/apache/lucene/search/LateInteractionRescorer.java

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* Rescores top N results from a first pass query using a {@link LateInteractionFloatValuesSource}
77
*
88
* <p>Typically, you run a low-cost first pass query to collect results from across the index, then
9-
* use this rescorer to rerank top N hits using multi-vectors, usually from a late interaction model.
10-
* Multi-vectors should be indexed in the {@link org.apache.lucene.document.LateInteractionField}
11-
* provided to rescorer.
9+
* use this rescorer to rerank top N hits using multi-vectors, usually from a late interaction
10+
* model. Multi-vectors should be indexed in the {@link
11+
* org.apache.lucene.document.LateInteractionField} provided to the rescorer.
1212
*
1313
* @lucene.experimental
1414
*/
@@ -18,26 +18,27 @@ public LateInteractionRescorer(LateInteractionFloatValuesSource valuesSource) {
1818
super(valuesSource);
1919
}
2020

21-
/**
22-
* Creates a LateInteractionRescorer for provided query vector.
23-
*/
21+
/** Creates a LateInteractionRescorer for provided query vector. */
2422
public static LateInteractionRescorer create(String fieldName, float[][] queryVector) {
2523
return create(fieldName, queryVector, VectorSimilarityFunction.COSINE);
2624
}
2725

2826
/**
2927
* Creates a LateInteractionRescorer for provided query vector.
3028
*
31-
* <p>Top N results from a first pass query are rescored based on the similarity between {@code queryVector} and
32-
* the multi-vector indexed in {@code fieldName}. If document does not have a value indexed in {@code fieldName},
33-
* a 0f score is assigned.
29+
* <p>Top N results from a first pass query are rescored based on the similarity between {@code
30+
* queryVector} and the multi-vector indexed in {@code fieldName}. If a document does not have a
31+
* value indexed in {@code fieldName}, a 0f score is assigned.
3432
*
35-
* @param fieldName the {@link org.apache.lucene.document.LateInteractionField} used for reranking.
33+
* @param fieldName the {@link org.apache.lucene.document.LateInteractionField} used for
34+
* reranking.
3635
* @param queryVector query multi-vector to use for similarity comparison
3736
* @param vectorSimilarityFunction function used for vector similarity comparisons
3837
*/
39-
public static LateInteractionRescorer create(String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) {
40-
final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction);
38+
public static LateInteractionRescorer create(
39+
String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) {
40+
final LateInteractionFloatValuesSource valuesSource =
41+
new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction);
4142
return new LateInteractionRescorer(valuesSource);
4243
}
4344

@@ -49,16 +50,19 @@ protected float combine(float firstPassScore, boolean valuePresent, double sourc
4950
/**
5051
* Creates a LateInteractionRescorer for provided query vector.
5152
*
52-
* <p>Top N results from a first pass query are rescored based on the similarity between {@code queryVector} and
53-
* the multi-vector indexed in {@code fieldName}. Falls back to score from the first pass query if a document
54-
* does not have a value indexed in {@code fieldName}.
53+
* <p>Top N results from a first pass query are rescored based on the similarity between {@code
54+
* queryVector} and the multi-vector indexed in {@code fieldName}. Falls back to score from the
55+
* first pass query if a document does not have a value indexed in {@code fieldName}.
5556
*
56-
* @param fieldName the {@link org.apache.lucene.document.LateInteractionField} used for reranking.
57+
* @param fieldName the {@link org.apache.lucene.document.LateInteractionField} used for
58+
* reranking.
5759
* @param queryVector query multi-vector to use for similarity comparison
5860
* @param vectorSimilarityFunction function used for vector similarity comparisons.
5961
*/
60-
public static LateInteractionRescorer withFallbackToFirstPassScore(String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) {
61-
final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction);
62+
public static LateInteractionRescorer withFallbackToFirstPassScore(
63+
String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) {
64+
final LateInteractionFloatValuesSource valuesSource =
65+
new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction);
6266
return new LateInteractionRescorer(valuesSource) {
6367
@Override
6468
protected float combine(float firstPassScore, boolean valuePresent, double sourceValue) {

lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,26 @@ public static Query createFullPrecisionRescorerQuery(
154154
* Creates a {@code RescoreTopNQuery} that computes top N results using multi-vector similarity
155155
* comparisons against a late interaction field.
156156
*
157+
* <p>Note: This query computes late interaction field similarity for the entire match-set of
158+
* wrapped query, and returns a new query with only top-N hits in the match-set. This is typically
159+
* useful in combining a query's results with other queries for hybrid search. To simply rerank
160+
* the top N hits without scoring entire match-set, see {@link LateInteractionRescorer}.
161+
*
157162
* @param in the inner Query to rescore
158163
* @param n number of results to keep
159-
* @param fieldName the {@link org.apache.lucene.document.LateInteractionField} for recomputing top N hits
164+
* @param fieldName the {@link org.apache.lucene.document.LateInteractionField} for recomputing
165+
* top N hits
160166
* @param queryVector query multi-vector to use for similarity comparisons
161167
* @param vectorSimilarityFunction function to use for vector similarity comparisons.
162168
*/
163169
public static Query createLateInteractionQuery(
164-
Query in, int n, String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) {
165-
final LateInteractionFloatValuesSource valuesSource = new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction);
170+
Query in,
171+
int n,
172+
String fieldName,
173+
float[][] queryVector,
174+
VectorSimilarityFunction vectorSimilarityFunction) {
175+
final LateInteractionFloatValuesSource valuesSource =
176+
new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction);
166177
return new RescoreTopNQuery(in, valuesSource, n);
167178
}
168179
}

lucene/core/src/test/org/apache/lucene/search/TestRescoreTopNQuery.java

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,25 @@
1717
package org.apache.lucene.search;
1818

1919
import java.io.IOException;
20+
import java.util.ArrayList;
21+
import java.util.Arrays;
2022
import java.util.HashMap;
23+
import java.util.List;
2124
import java.util.Map;
2225
import java.util.Random;
26+
import java.util.Set;
27+
import java.util.stream.Collectors;
2328
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
2429
import org.apache.lucene.document.Document;
2530
import org.apache.lucene.document.Field;
2631
import org.apache.lucene.document.IntField;
2732
import org.apache.lucene.document.KnnFloatVectorField;
33+
import org.apache.lucene.document.LateInteractionField;
2834
import org.apache.lucene.index.DirectoryReader;
2935
import org.apache.lucene.index.IndexReader;
3036
import org.apache.lucene.index.IndexWriter;
3137
import org.apache.lucene.index.IndexWriterConfig;
38+
import org.apache.lucene.index.StoredFields;
3239
import org.apache.lucene.index.Term;
3340
import org.apache.lucene.index.VectorSimilarityFunction;
3441
import org.apache.lucene.store.ByteBuffersDirectory;
@@ -156,6 +163,87 @@ public void testMissingDoubleValues() throws IOException {
156163
}
157164
}
158165

166+
public void testLateInteractionQuery() throws Exception {
167+
final String LATE_I_FIELD = "li_vector";
168+
final String KNN_FIELD = "knn_vector";
169+
List<float[][]> corpus = new ArrayList<>();
170+
final int numDocs = atLeast(1000);
171+
final int numSegments = random().nextInt(2, 10);
172+
final int dim = 128;
173+
final VectorSimilarityFunction vectorSimilarityFunction =
174+
VectorSimilarityFunction.values()[
175+
random().nextInt(VectorSimilarityFunction.values().length)];
176+
LateInteractionFloatValuesSource.ScoreFunction scoreFunction =
177+
LateInteractionFloatValuesSource.ScoreFunction.values()[
178+
random().nextInt(LateInteractionFloatValuesSource.ScoreFunction.values().length)];
179+
180+
try (Directory dir = newDirectory()) {
181+
int id = 0;
182+
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
183+
for (int j = 0; j < numSegments; j++) {
184+
for (int i = 0; i < numDocs; i++) {
185+
Document doc = new Document();
186+
if (random().nextInt(100) < 30) {
187+
// skip value for some docs to create sparse field
188+
doc.add(new IntField("has_li_vector", 0, Field.Store.YES));
189+
} else {
190+
float[][] value = createMultiVector(dim);
191+
corpus.add(value);
192+
doc.add(new IntField("id", id++, Field.Store.YES));
193+
doc.add(new LateInteractionField(LATE_I_FIELD, value));
194+
doc.add(new KnnFloatVectorField(KNN_FIELD, randomFloatVector(dim, random())));
195+
doc.add(new IntField("has_li_vector", 1, Field.Store.YES));
196+
}
197+
w.addDocument(doc);
198+
w.flush();
199+
}
200+
}
201+
// add a segment with no vectors
202+
for (int i = 0; i < 100; i++) {
203+
Document doc = new Document();
204+
doc.add(new IntField("has_li_vector", 0, Field.Store.YES));
205+
w.addDocument(doc);
206+
}
207+
w.flush();
208+
}
209+
210+
float[][] lateIQueryVector = createMultiVector(dim);
211+
float[] knnQueryVector = randomFloatVector(dim, random());
212+
KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(KNN_FIELD, knnQueryVector, 50);
213+
214+
try (IndexReader reader = DirectoryReader.open(dir)) {
215+
final int topN = 10;
216+
IndexSearcher s = new IndexSearcher(reader);
217+
TopDocs knnHits = s.search(knnQuery, 5 * topN);
218+
Set<Integer> knnHitDocs =
219+
Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet());
220+
Query lateIQuery =
221+
RescoreTopNQuery.createLateInteractionQuery(
222+
knnQuery, topN, LATE_I_FIELD, lateIQueryVector, vectorSimilarityFunction);
223+
TopDocs lateIHits = s.search(lateIQuery, 3 * topN);
224+
// total match-set for RescoreTopNQuery is topN
225+
assertEquals(topN, lateIHits.scoreDocs.length);
226+
StoredFields storedFields = reader.storedFields();
227+
for (ScoreDoc hit : lateIHits.scoreDocs) {
228+
assertTrue(knnHitDocs.contains(hit.doc));
229+
int idValue = Integer.parseInt(storedFields.document(hit.doc).get("id"));
230+
float[][] docVector = corpus.get(idValue);
231+
float expected =
232+
scoreFunction.compare(lateIQueryVector, docVector, vectorSimilarityFunction);
233+
assertEquals(expected, hit.score, 1e-5);
234+
}
235+
}
236+
}
237+
}
238+
239+
private float[][] createMultiVector(int dimension) {
240+
float[][] value = new float[random().nextInt(3, 12)][];
241+
for (int i = 0; i < value.length; i++) {
242+
value[i] = randomFloatVector(dimension, random());
243+
}
244+
return value;
245+
}
246+
159247
private float[] randomFloatVector(int dimension, Random random) {
160248
float[] vector = new float[dimension];
161249
for (int i = 0; i < dimension; i++) {

lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,12 @@
2020
import java.io.IOException;
2121
import java.util.Objects;
2222
import org.apache.lucene.index.LeafReaderContext;
23-
import org.apache.lucene.index.VectorSimilarityFunction;
2423
import org.apache.lucene.search.BooleanClause;
2524
import org.apache.lucene.search.DoubleValues;
2625
import org.apache.lucene.search.DoubleValuesSource;
2726
import org.apache.lucene.search.Explanation;
2827
import org.apache.lucene.search.FilterScorer;
2928
import org.apache.lucene.search.IndexSearcher;
30-
import org.apache.lucene.search.LateInteractionFloatValuesSource;
3129
import org.apache.lucene.search.Matches;
3230
import org.apache.lucene.search.Query;
3331
import org.apache.lucene.search.QueryVisitor;
@@ -73,28 +71,6 @@ public DoubleValuesSource getSource() {
7371
return source;
7472
}
7573

76-
/**
77-
* Returns a FunctionScoreQuery that re-scores hits from the wrapped query using late-interaction
78-
* scores between provided query and indexed document multi-vectors.
79-
*
80-
* <p>Document multi-vectors are indexed using {@link
81-
* org.apache.lucene.document.LateInteractionField}.
82-
*
83-
* @param in the query to re-score
84-
* @param fieldName field containing document multi-vectors for re-scoring
85-
* @param queryVector query multi-vector
86-
* @param vectorSimilarityFunction vector similarity function used for computing scores
87-
*/
88-
public static FunctionScoreQuery lateInteractionFloatRerankQuery(
89-
Query in,
90-
String fieldName,
91-
float[][] queryVector,
92-
VectorSimilarityFunction vectorSimilarityFunction) {
93-
LateInteractionFloatValuesSource scoreSource =
94-
new LateInteractionFloatValuesSource(fieldName, queryVector, vectorSimilarityFunction);
95-
return new FunctionScoreQuery(in, scoreSource);
96-
}
97-
9874
/**
9975
* Returns a FunctionScoreQuery where the scores of a wrapped query are multiplied by the value of
10076
* a DoubleValuesSource.

lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionScoreQuery.java

Lines changed: 0 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,9 @@
1818
package org.apache.lucene.queries.function;
1919

2020
import java.io.IOException;
21-
import java.util.ArrayList;
22-
import java.util.Arrays;
23-
import java.util.List;
24-
import java.util.Random;
25-
import java.util.Set;
2621
import java.util.concurrent.atomic.AtomicReference;
27-
import java.util.stream.Collectors;
2822
import org.apache.lucene.document.Document;
2923
import org.apache.lucene.document.Field;
30-
import org.apache.lucene.document.IntField;
31-
import org.apache.lucene.document.KnnFloatVectorField;
32-
import org.apache.lucene.document.LateInteractionField;
3324
import org.apache.lucene.document.NumericDocValuesField;
3425
import org.apache.lucene.document.TextField;
3526
import org.apache.lucene.expressions.Expression;
@@ -40,21 +31,16 @@
4031
import org.apache.lucene.index.IndexWriter;
4132
import org.apache.lucene.index.IndexWriterConfig;
4233
import org.apache.lucene.index.LeafReaderContext;
43-
import org.apache.lucene.index.StoredFields;
4434
import org.apache.lucene.index.Term;
45-
import org.apache.lucene.index.VectorSimilarityFunction;
4635
import org.apache.lucene.search.BooleanClause;
4736
import org.apache.lucene.search.BooleanQuery;
4837
import org.apache.lucene.search.BoostQuery;
4938
import org.apache.lucene.search.DoubleValuesSource;
5039
import org.apache.lucene.search.Explanation;
5140
import org.apache.lucene.search.IndexSearcher;
52-
import org.apache.lucene.search.KnnFloatVectorQuery;
53-
import org.apache.lucene.search.LateInteractionFloatValuesSource;
5441
import org.apache.lucene.search.MatchAllDocsQuery;
5542
import org.apache.lucene.search.PhraseQuery;
5643
import org.apache.lucene.search.Query;
57-
import org.apache.lucene.search.ScoreDoc;
5844
import org.apache.lucene.search.ScoreMode;
5945
import org.apache.lucene.search.TermQuery;
6046
import org.apache.lucene.search.TopDocs;
@@ -391,91 +377,4 @@ public void testQueryMatchesCount() throws Exception {
391377
}
392378
assertEquals(searchCount, weightCount);
393379
}
394-
395-
public void testLateInteractionQuery() throws Exception {
396-
final String LATE_I_FIELD = "li_vector";
397-
final String KNN_FIELD = "knn_vector";
398-
List<float[][]> corpus = new ArrayList<>();
399-
final int numDocs = atLeast(1000);
400-
final int numSegments = random().nextInt(2, 10);
401-
final int dim = 128;
402-
final VectorSimilarityFunction vectorSimilarityFunction =
403-
VectorSimilarityFunction.values()[
404-
random().nextInt(VectorSimilarityFunction.values().length)];
405-
LateInteractionFloatValuesSource.ScoreFunction scoreFunction =
406-
LateInteractionFloatValuesSource.ScoreFunction.values()[
407-
random().nextInt(LateInteractionFloatValuesSource.ScoreFunction.values().length)];
408-
409-
try (Directory dir = newDirectory()) {
410-
int id = 0;
411-
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
412-
for (int j = 0; j < numSegments; j++) {
413-
for (int i = 0; i < numDocs; i++) {
414-
Document doc = new Document();
415-
if (random().nextInt(100) < 30) {
416-
// skip value for some docs to create sparse field
417-
doc.add(new IntField("has_li_vector", 0, Field.Store.YES));
418-
} else {
419-
float[][] value = createMultiVector(dim);
420-
corpus.add(value);
421-
doc.add(new IntField("id", id++, Field.Store.YES));
422-
doc.add(new LateInteractionField(LATE_I_FIELD, value));
423-
doc.add(new KnnFloatVectorField(KNN_FIELD, randomVector(dim)));
424-
doc.add(new IntField("has_li_vector", 1, Field.Store.YES));
425-
}
426-
w.addDocument(doc);
427-
w.flush();
428-
}
429-
}
430-
// add a segment with no vectors
431-
for (int i = 0; i < 100; i++) {
432-
Document doc = new Document();
433-
doc.add(new IntField("has_li_vector", 0, Field.Store.YES));
434-
w.addDocument(doc);
435-
}
436-
w.flush();
437-
}
438-
439-
float[][] lateIQueryVector = createMultiVector(dim);
440-
float[] knnQueryVector = randomVector(dim);
441-
KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(KNN_FIELD, knnQueryVector, 50);
442-
443-
try (IndexReader reader = DirectoryReader.open(dir)) {
444-
IndexSearcher s = new IndexSearcher(reader);
445-
TopDocs knnHits = s.search(knnQuery, 50);
446-
Set<Integer> knnHitDocs =
447-
Arrays.stream(knnHits.scoreDocs).map(k -> k.doc).collect(Collectors.toSet());
448-
FunctionScoreQuery lateIQuery =
449-
FunctionScoreQuery.lateInteractionFloatRerankQuery(
450-
knnQuery, LATE_I_FIELD, lateIQueryVector, vectorSimilarityFunction);
451-
TopDocs lateIHits = s.search(lateIQuery, 10);
452-
StoredFields storedFields = reader.storedFields();
453-
for (ScoreDoc hit : lateIHits.scoreDocs) {
454-
assertTrue(knnHitDocs.contains(hit.doc));
455-
int idValue = Integer.parseInt(storedFields.document(hit.doc).get("id"));
456-
float[][] docVector = corpus.get(idValue);
457-
float expected =
458-
scoreFunction.compare(lateIQueryVector, docVector, vectorSimilarityFunction);
459-
assertEquals(expected, hit.score, 1e-5);
460-
}
461-
}
462-
}
463-
}
464-
465-
private float[] randomVector(int dim) {
466-
float[] v = new float[dim];
467-
Random random = random();
468-
for (int i = 0; i < dim; i++) {
469-
v[i] = random.nextFloat();
470-
}
471-
return v;
472-
}
473-
474-
private float[][] createMultiVector(int dimension) {
475-
float[][] value = new float[random().nextInt(3, 12)][];
476-
for (int i = 0; i < value.length; i++) {
477-
value[i] = randomVector(dimension);
478-
}
479-
return value;
480-
}
481380
}

0 commit comments

Comments
 (0)