Skip to content

Commit 4fcb6d6

Browse files
authored
Add Query for reranking KnnFloatVectorQuery with full-precision vectors (#14860)
1 parent b6fb10d commit 4fcb6d6

File tree

6 files changed

+547
-204
lines changed

6 files changed

+547
-204
lines changed

lucene/CHANGES.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ New Features
2727

2828
* GITHUB#14784: Make pack methods public for BigIntegerPoint and HalfFloatPoint. (Prudhvi Godithi)
2929

30+
* GITHUB#14009: Add a new Query that can rescore another Query based on a generic DoubleValueSource
31+
and trim the results down to top N (Anh Dung Bui)
32+
3033
Improvements
3134
---------------------
3235
* GITHUB#14458: Add an IndexDeletion policy that retains the last N commits. (Owais Kazi)

lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java

Lines changed: 1 addition & 200 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,8 @@
1616
*/
1717
package org.apache.lucene.search;
1818

19-
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
20-
2119
import java.io.IOException;
2220
import java.util.ArrayList;
23-
import java.util.Arrays;
24-
import java.util.Comparator;
2521
import java.util.List;
2622
import java.util.Objects;
2723
import java.util.concurrent.Callable;
@@ -106,7 +102,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
106102
if (topK.scoreDocs.length == 0) {
107103
return new MatchNoDocsQuery();
108104
}
109-
return createRewrittenQuery(reader, topK);
105+
return DocAndScoreQuery.createDocAndScoreQuery(reader, topK);
110106
}
111107

112108
private TopDocs searchLeaf(
@@ -275,41 +271,6 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
275271
return TopDocs.merge(k, perLeafResults);
276272
}
277273

278-
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
279-
int len = topK.scoreDocs.length;
280-
281-
assert len > 0;
282-
float maxScore = topK.scoreDocs[0].score;
283-
284-
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
285-
int[] docs = new int[len];
286-
float[] scores = new float[len];
287-
for (int i = 0; i < len; i++) {
288-
docs[i] = topK.scoreDocs[i].doc;
289-
scores[i] = topK.scoreDocs[i].score;
290-
}
291-
int[] segmentStarts = findSegmentStarts(reader.leaves(), docs);
292-
return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
293-
}
294-
295-
static int[] findSegmentStarts(List<LeafReaderContext> leaves, int[] docs) {
296-
int[] starts = new int[leaves.size() + 1];
297-
starts[starts.length - 1] = docs.length;
298-
if (starts.length == 2) {
299-
return starts;
300-
}
301-
int resultIndex = 0;
302-
for (int i = 1; i < starts.length - 1; i++) {
303-
int upper = leaves.get(i).docBase;
304-
resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
305-
if (resultIndex < 0) {
306-
resultIndex = -1 - resultIndex;
307-
}
308-
starts[i] = resultIndex;
309-
}
310-
return starts;
311-
}
312-
313274
@Override
314275
public void visit(QueryVisitor visitor) {
315276
if (visitor.acceptField(field)) {
@@ -355,166 +316,6 @@ public Query getFilter() {
355316
return filter;
356317
}
357318

358-
/** Caches the results of a KnnVector search: a list of docs and their scores */
359-
static class DocAndScoreQuery extends Query {
360-
361-
private final int[] docs;
362-
private final float[] scores;
363-
private final float maxScore;
364-
private final int[] segmentStarts;
365-
private final Object contextIdentity;
366-
367-
/**
368-
* Constructor
369-
*
370-
* @param docs the global docids of documents that match, in ascending order
371-
* @param scores the scores of the matching documents
372-
* @param segmentStarts the indexes in docs and scores corresponding to the first matching
373-
* document in each segment. If a segment has no matching documents, it should be assigned
374-
* the index of the next segment that does. There should be a final entry that is always
375-
* docs.length.
376-
* @param contextIdentity an object identifying the reader context that was used to build this
377-
* query
378-
*/
379-
DocAndScoreQuery(
380-
int[] docs, float[] scores, float maxScore, int[] segmentStarts, Object contextIdentity) {
381-
this.docs = docs;
382-
this.scores = scores;
383-
this.maxScore = maxScore;
384-
this.segmentStarts = segmentStarts;
385-
this.contextIdentity = contextIdentity;
386-
}
387-
388-
@Override
389-
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
390-
throws IOException {
391-
if (searcher.getIndexReader().getContext().id() != contextIdentity) {
392-
throw new IllegalStateException("This DocAndScore query was created by a different reader");
393-
}
394-
return new Weight(this) {
395-
@Override
396-
public Explanation explain(LeafReaderContext context, int doc) {
397-
int found = Arrays.binarySearch(docs, doc + context.docBase);
398-
if (found < 0) {
399-
return Explanation.noMatch("not in top " + docs.length + " docs");
400-
}
401-
return Explanation.match(scores[found] * boost, "within top " + docs.length + " docs");
402-
}
403-
404-
@Override
405-
public int count(LeafReaderContext context) {
406-
return segmentStarts[context.ord + 1] - segmentStarts[context.ord];
407-
}
408-
409-
@Override
410-
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
411-
if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) {
412-
return null;
413-
}
414-
final var scorer =
415-
new Scorer() {
416-
final int lower = segmentStarts[context.ord];
417-
final int upper = segmentStarts[context.ord + 1];
418-
int upTo = -1;
419-
420-
@Override
421-
public DocIdSetIterator iterator() {
422-
return new DocIdSetIterator() {
423-
@Override
424-
public int docID() {
425-
return docIdNoShadow();
426-
}
427-
428-
@Override
429-
public int nextDoc() {
430-
if (upTo == -1) {
431-
upTo = lower;
432-
} else {
433-
++upTo;
434-
}
435-
return docIdNoShadow();
436-
}
437-
438-
@Override
439-
public int advance(int target) throws IOException {
440-
return slowAdvance(target);
441-
}
442-
443-
@Override
444-
public long cost() {
445-
return upper - lower;
446-
}
447-
};
448-
}
449-
450-
@Override
451-
public float getMaxScore(int docId) {
452-
return maxScore * boost;
453-
}
454-
455-
@Override
456-
public float score() {
457-
return scores[upTo] * boost;
458-
}
459-
460-
/**
461-
* move the implementation of docID() into a differently-named method so we can call
462-
* it from DocIdSetIterator.docID() even though this class is anonymous
463-
*
464-
* @return the current docid
465-
*/
466-
private int docIdNoShadow() {
467-
if (upTo == -1) {
468-
return -1;
469-
}
470-
if (upTo >= upper) {
471-
return NO_MORE_DOCS;
472-
}
473-
return docs[upTo] - context.docBase;
474-
}
475-
476-
@Override
477-
public int docID() {
478-
return docIdNoShadow();
479-
}
480-
};
481-
return new DefaultScorerSupplier(scorer);
482-
}
483-
484-
@Override
485-
public boolean isCacheable(LeafReaderContext ctx) {
486-
return true;
487-
}
488-
};
489-
}
490-
491-
@Override
492-
public String toString(String field) {
493-
return "DocAndScoreQuery[" + docs[0] + ",...][" + scores[0] + ",...]," + maxScore;
494-
}
495-
496-
@Override
497-
public void visit(QueryVisitor visitor) {
498-
visitor.visitLeaf(this);
499-
}
500-
501-
@Override
502-
public boolean equals(Object obj) {
503-
if (sameClassAs(obj) == false) {
504-
return false;
505-
}
506-
return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity
507-
&& Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
508-
&& Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
509-
}
510-
511-
@Override
512-
public int hashCode() {
513-
return Objects.hash(
514-
classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
515-
}
516-
}
517-
518319
public KnnSearchStrategy getSearchStrategy() {
519320
return searchStrategy;
520321
}

0 commit comments

Comments
 (0)