Skip to content

Commit 3404496

Browse files
authored
Add Query for reranking KnnFloatVectorQuery with full-precision vectors (#14009)
1 parent 8e2c6aa commit 3404496

File tree

6 files changed

+572
-241
lines changed

6 files changed

+572
-241
lines changed

lucene/CHANGES.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ New Features
115115

116116
* GITHUB#14784: Make pack methods public for BigIntegerPoint and HalfFloatPoint. (Prudhvi Godithi)
117117

118+
* GITHUB#14009: Add a new Query that can rescore other Query based on a generic DoubleValueSource
119+
and trim the results down to top N (Anh Dung Bui)
120+
118121
Improvements
119122
---------------------
120123
* GITHUB#14458: Add an IndexDeletion policy that retains the last N commits. (Owais Kazi)

lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java

Lines changed: 1 addition & 238 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,8 @@
1616
*/
1717
package org.apache.lucene.search;
1818

19-
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
20-
2119
import java.io.IOException;
2220
import java.util.ArrayList;
23-
import java.util.Arrays;
24-
import java.util.Comparator;
2521
import java.util.HashMap;
2622
import java.util.Iterator;
2723
import java.util.List;
@@ -142,7 +138,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
142138
if (topK.scoreDocs.length == 0) {
143139
return new MatchNoDocsQuery();
144140
}
145-
return createRewrittenQuery(reader, topK, reentryCount);
141+
return DocAndScoreQuery.createDocAndScoreQuery(reader, topK, reentryCount);
146142
}
147143

148144
private TopDocs runSearchTasks(
@@ -398,46 +394,6 @@ public KnnCollector newCollector(
398394
}
399395
}
400396

401-
protected Query createRewrittenQuery(IndexReader reader, TopDocs topK, int reentryCount) {
402-
int len = topK.scoreDocs.length;
403-
assert len > 0;
404-
float maxScore = topK.scoreDocs[0].score;
405-
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
406-
int[] docs = new int[len];
407-
float[] scores = new float[len];
408-
for (int i = 0; i < len; i++) {
409-
docs[i] = topK.scoreDocs[i].doc;
410-
scores[i] = topK.scoreDocs[i].score;
411-
}
412-
int[] segmentStarts = findSegmentStarts(reader.leaves(), docs);
413-
return new DocAndScoreQuery(
414-
docs,
415-
scores,
416-
maxScore,
417-
segmentStarts,
418-
topK.totalHits.value(),
419-
reader.getContext().id(),
420-
reentryCount);
421-
}
422-
423-
static int[] findSegmentStarts(List<LeafReaderContext> leaves, int[] docs) {
424-
int[] starts = new int[leaves.size() + 1];
425-
starts[starts.length - 1] = docs.length;
426-
if (starts.length == 2) {
427-
return starts;
428-
}
429-
int resultIndex = 0;
430-
for (int i = 1; i < starts.length - 1; i++) {
431-
int upper = leaves.get(i).docBase;
432-
resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
433-
if (resultIndex < 0) {
434-
resultIndex = -1 - resultIndex;
435-
}
436-
starts[i] = resultIndex;
437-
}
438-
return starts;
439-
}
440-
441397
@Override
442398
public void visit(QueryVisitor visitor) {
443399
if (visitor.acceptField(field)) {
@@ -483,199 +439,6 @@ public Query getFilter() {
483439
return filter;
484440
}
485441

486-
/** Caches the results of a KnnVector search: a list of docs and their scores */
487-
static class DocAndScoreQuery extends Query {
488-
489-
private final int[] docs;
490-
private final float[] scores;
491-
private final float maxScore;
492-
private final int[] segmentStarts;
493-
private final long visited;
494-
private final Object contextIdentity;
495-
private final int reentryCount;
496-
497-
/**
498-
* Constructor
499-
*
500-
* @param docs the global docids of documents that match, in ascending order
501-
* @param scores the scores of the matching documents
502-
* @param maxScore the max of those scores? why do we need to pass in?
503-
* @param segmentStarts the indexes in docs and scores corresponding to the first matching
504-
* document in each segment. If a segment has no matching documents, it should be assigned
505-
* the index of the next segment that does. There should be a final entry that is always
506-
* docs.length-1.
507-
* @param visited the number of graph nodes that were visited, and for which vector distance
508-
* scores were evaluated.
509-
* @param contextIdentity an object identifying the reader context that was used to build this
510-
* query
511-
*/
512-
DocAndScoreQuery(
513-
int[] docs,
514-
float[] scores,
515-
float maxScore,
516-
int[] segmentStarts,
517-
long visited,
518-
Object contextIdentity,
519-
int reentryCount) {
520-
this.docs = docs;
521-
this.scores = scores;
522-
this.maxScore = maxScore;
523-
this.segmentStarts = segmentStarts;
524-
this.visited = visited;
525-
this.contextIdentity = contextIdentity;
526-
this.reentryCount = reentryCount;
527-
}
528-
529-
/*
530-
DocAndScoreQuery(DocAndScoreQuery other) {
531-
this.docs = other.docs;
532-
this.scores = other.scores;
533-
this.maxScore = other.maxScore;
534-
this.segmentStarts = other.segmentStarts;
535-
this.visited = other.visited;
536-
this.contextIdentity = other.contextIdentity;
537-
this.reentryCount = other.reentryCount;
538-
}
539-
*/
540-
541-
int reentryCount() {
542-
return reentryCount;
543-
}
544-
545-
@Override
546-
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
547-
throws IOException {
548-
if (searcher.getIndexReader().getContext().id() != contextIdentity) {
549-
throw new IllegalStateException("This DocAndScore query was created by a different reader");
550-
}
551-
return new Weight(this) {
552-
@Override
553-
public Explanation explain(LeafReaderContext context, int doc) {
554-
int found = Arrays.binarySearch(docs, doc + context.docBase);
555-
if (found < 0) {
556-
return Explanation.noMatch("not in top " + docs.length + " docs");
557-
}
558-
return Explanation.match(scores[found] * boost, "within top " + docs.length + " docs");
559-
}
560-
561-
@Override
562-
public int count(LeafReaderContext context) {
563-
return segmentStarts[context.ord + 1] - segmentStarts[context.ord];
564-
}
565-
566-
@Override
567-
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
568-
if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) {
569-
return null;
570-
}
571-
final var scorer =
572-
new Scorer() {
573-
final int lower = segmentStarts[context.ord];
574-
final int upper = segmentStarts[context.ord + 1];
575-
int upTo = -1;
576-
577-
@Override
578-
public DocIdSetIterator iterator() {
579-
return new DocIdSetIterator() {
580-
@Override
581-
public int docID() {
582-
return docIdNoShadow();
583-
}
584-
585-
@Override
586-
public int nextDoc() {
587-
if (upTo == -1) {
588-
upTo = lower;
589-
} else {
590-
++upTo;
591-
}
592-
return docIdNoShadow();
593-
}
594-
595-
@Override
596-
public int advance(int target) throws IOException {
597-
return slowAdvance(target);
598-
}
599-
600-
@Override
601-
public long cost() {
602-
return upper - lower;
603-
}
604-
};
605-
}
606-
607-
@Override
608-
public float getMaxScore(int docId) {
609-
return maxScore * boost;
610-
}
611-
612-
@Override
613-
public float score() {
614-
return scores[upTo] * boost;
615-
}
616-
617-
/**
618-
* move the implementation of docID() into a differently-named method so we can call
619-
* it from DocIDSetIterator.docID() even though this class is anonymous
620-
*
621-
* @return the current docid
622-
*/
623-
private int docIdNoShadow() {
624-
if (upTo == -1) {
625-
return -1;
626-
}
627-
if (upTo >= upper) {
628-
return NO_MORE_DOCS;
629-
}
630-
return docs[upTo] - context.docBase;
631-
}
632-
633-
@Override
634-
public int docID() {
635-
return docIdNoShadow();
636-
}
637-
};
638-
return new DefaultScorerSupplier(scorer);
639-
}
640-
641-
@Override
642-
public boolean isCacheable(LeafReaderContext ctx) {
643-
return true;
644-
}
645-
};
646-
}
647-
648-
@Override
649-
public String toString(String field) {
650-
return "DocAndScoreQuery[" + docs[0] + ",...][" + scores[0] + ",...]," + maxScore;
651-
}
652-
653-
@Override
654-
public void visit(QueryVisitor visitor) {
655-
visitor.visitLeaf(this);
656-
}
657-
658-
public long visited() {
659-
return visited;
660-
}
661-
662-
@Override
663-
public boolean equals(Object obj) {
664-
if (sameClassAs(obj) == false) {
665-
return false;
666-
}
667-
return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity
668-
&& Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
669-
&& Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
670-
}
671-
672-
@Override
673-
public int hashCode() {
674-
return Objects.hash(
675-
classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
676-
}
677-
}
678-
679442
public KnnSearchStrategy getSearchStrategy() {
680443
return searchStrategy;
681444
}

0 commit comments

Comments
 (0)