diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index ea219e798988..fbf1b5191298 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -111,6 +111,9 @@ New Features * GITHUB#14784: Make pack methods public for BigIntegerPoint and HalfFloatPoint. (Prudhvi Godithi) +* GITHUB#14009: Add a new Query that can rescore other Query based on a generic DoubleValueSource + and trim the results down to top N (Anh Dung Bui) + Improvements --------------------- * GITHUB#14458: Add an IndexDeletion policy that retains the last N commits. (Owais Kazi) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 5590643535ff..510e5cf72e57 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -16,12 +16,8 @@ */ package org.apache.lucene.search; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -142,7 +138,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (topK.scoreDocs.length == 0) { return new MatchNoDocsQuery(); } - return createRewrittenQuery(reader, topK, reentryCount); + return DocAndScoreQuery.createDocAndScoreQuery(reader, topK, reentryCount); } private TopDocs runSearchTasks( @@ -398,46 +394,6 @@ public KnnCollector newCollector( } } - protected Query createRewrittenQuery(IndexReader reader, TopDocs topK, int reentryCount) { - int len = topK.scoreDocs.length; - assert len > 0; - float maxScore = topK.scoreDocs[0].score; - Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); - int[] docs = new int[len]; - float[] scores = new float[len]; - for (int i = 0; i < len; i++) { - docs[i] = 
topK.scoreDocs[i].doc; - scores[i] = topK.scoreDocs[i].score; - } - int[] segmentStarts = findSegmentStarts(reader.leaves(), docs); - return new DocAndScoreQuery( - docs, - scores, - maxScore, - segmentStarts, - topK.totalHits.value(), - reader.getContext().id(), - reentryCount); - } - - static int[] findSegmentStarts(List leaves, int[] docs) { - int[] starts = new int[leaves.size() + 1]; - starts[starts.length - 1] = docs.length; - if (starts.length == 2) { - return starts; - } - int resultIndex = 0; - for (int i = 1; i < starts.length - 1; i++) { - int upper = leaves.get(i).docBase; - resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); - if (resultIndex < 0) { - resultIndex = -1 - resultIndex; - } - starts[i] = resultIndex; - } - return starts; - } - @Override public void visit(QueryVisitor visitor) { if (visitor.acceptField(field)) { @@ -483,199 +439,6 @@ public Query getFilter() { return filter; } - /** Caches the results of a KnnVector search: a list of docs and their scores */ - static class DocAndScoreQuery extends Query { - - private final int[] docs; - private final float[] scores; - private final float maxScore; - private final int[] segmentStarts; - private final long visited; - private final Object contextIdentity; - private final int reentryCount; - - /** - * Constructor - * - * @param docs the global docids of documents that match, in ascending order - * @param scores the scores of the matching documents - * @param maxScore the max of those scores? why do we need to pass in? - * @param segmentStarts the indexes in docs and scores corresponding to the first matching - * document in each segment. If a segment has no matching documents, it should be assigned - * the index of the next segment that does. There should be a final entry that is always - * docs.length-1. - * @param visited the number of graph nodes that were visited, and for which vector distance - * scores were evaluated. 
- * @param contextIdentity an object identifying the reader context that was used to build this - * query - */ - DocAndScoreQuery( - int[] docs, - float[] scores, - float maxScore, - int[] segmentStarts, - long visited, - Object contextIdentity, - int reentryCount) { - this.docs = docs; - this.scores = scores; - this.maxScore = maxScore; - this.segmentStarts = segmentStarts; - this.visited = visited; - this.contextIdentity = contextIdentity; - this.reentryCount = reentryCount; - } - - /* - DocAndScoreQuery(DocAndScoreQuery other) { - this.docs = other.docs; - this.scores = other.scores; - this.maxScore = other.maxScore; - this.segmentStarts = other.segmentStarts; - this.visited = other.visited; - this.contextIdentity = other.contextIdentity; - this.reentryCount = other.reentryCount; - } - */ - - int reentryCount() { - return reentryCount; - } - - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) - throws IOException { - if (searcher.getIndexReader().getContext().id() != contextIdentity) { - throw new IllegalStateException("This DocAndScore query was created by a different reader"); - } - return new Weight(this) { - @Override - public Explanation explain(LeafReaderContext context, int doc) { - int found = Arrays.binarySearch(docs, doc + context.docBase); - if (found < 0) { - return Explanation.noMatch("not in top " + docs.length + " docs"); - } - return Explanation.match(scores[found] * boost, "within top " + docs.length + " docs"); - } - - @Override - public int count(LeafReaderContext context) { - return segmentStarts[context.ord + 1] - segmentStarts[context.ord]; - } - - @Override - public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { - if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { - return null; - } - final var scorer = - new Scorer() { - final int lower = segmentStarts[context.ord]; - final int upper = segmentStarts[context.ord + 1]; - int upTo = -1; - - 
@Override - public DocIdSetIterator iterator() { - return new DocIdSetIterator() { - @Override - public int docID() { - return docIdNoShadow(); - } - - @Override - public int nextDoc() { - if (upTo == -1) { - upTo = lower; - } else { - ++upTo; - } - return docIdNoShadow(); - } - - @Override - public int advance(int target) throws IOException { - return slowAdvance(target); - } - - @Override - public long cost() { - return upper - lower; - } - }; - } - - @Override - public float getMaxScore(int docId) { - return maxScore * boost; - } - - @Override - public float score() { - return scores[upTo] * boost; - } - - /** - * move the implementation of docID() into a differently-named method so we can call - * it from DocIDSetIterator.docID() even though this class is anonymous - * - * @return the current docid - */ - private int docIdNoShadow() { - if (upTo == -1) { - return -1; - } - if (upTo >= upper) { - return NO_MORE_DOCS; - } - return docs[upTo] - context.docBase; - } - - @Override - public int docID() { - return docIdNoShadow(); - } - }; - return new DefaultScorerSupplier(scorer); - } - - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } - }; - } - - @Override - public String toString(String field) { - return "DocAndScoreQuery[" + docs[0] + ",...][" + scores[0] + ",...]," + maxScore; - } - - @Override - public void visit(QueryVisitor visitor) { - visitor.visitLeaf(this); - } - - public long visited() { - return visited; - } - - @Override - public boolean equals(Object obj) { - if (sameClassAs(obj) == false) { - return false; - } - return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity - && Arrays.equals(docs, ((DocAndScoreQuery) obj).docs) - && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores); - } - - @Override - public int hashCode() { - return Objects.hash( - classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores)); - } - } - public KnnSearchStrategy getSearchStrategy() { return 
searchStrategy; } diff --git a/lucene/core/src/java/org/apache/lucene/search/DocAndScoreQuery.java b/lucene/core/src/java/org/apache/lucene/search/DocAndScoreQuery.java new file mode 100644 index 000000000000..cf71ca8d8c20 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/DocAndScoreQuery.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; + +/** A query that wraps precomputed documents and scores */ +class DocAndScoreQuery extends Query { + + private final int[] docs; + private final float[] scores; + private final float maxScore; + private final int[] segmentStarts; + private final long visited; + private final Object contextIdentity; + private final int reentryCount; + + /** + * Constructor + * + * @param docs the global docids of documents that match, in ascending order + * @param scores the scores of the matching documents + * @param maxScore the maximum of those scores, passed in so it does not need to be recomputed + * @param segmentStarts the indexes in docs and scores corresponding to the first matching + * document in each segment. If a segment has no matching documents, it should be assigned the + * index of the next segment that does. There should be a final entry that is always + * docs.length-1. + * @param visited the number of graph nodes that were visited, and for which vector distance + * scores were evaluated. 
+ * @param contextIdentity an object identifying the reader context that was used to build this + * query + */ + DocAndScoreQuery( + int[] docs, + float[] scores, + float maxScore, + int[] segmentStarts, + long visited, + Object contextIdentity, + int reentryCount) { + this.docs = docs; + this.scores = scores; + this.maxScore = maxScore; + this.segmentStarts = segmentStarts; + this.visited = visited; + this.contextIdentity = contextIdentity; + this.reentryCount = reentryCount; + } + + static Query createDocAndScoreQuery(IndexReader reader, TopDocs topK, int reentryCount) { + int len = topK.scoreDocs.length; + assert len > 0; + float maxScore = topK.scoreDocs[0].score; + Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); + int[] docs = new int[len]; + float[] scores = new float[len]; + for (int i = 0; i < len; i++) { + docs[i] = topK.scoreDocs[i].doc; + scores[i] = topK.scoreDocs[i].score; + } + int[] segmentStarts = findSegmentStarts(reader.leaves(), docs); + return new DocAndScoreQuery( + docs, + scores, + maxScore, + segmentStarts, + topK.totalHits.value(), + reader.getContext().id(), + reentryCount); + } + + static int[] findSegmentStarts(List leaves, int[] docs) { + int[] starts = new int[leaves.size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = leaves.get(i).docBase; + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; + } + return starts; + } + + int reentryCount() { + return reentryCount; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + if (searcher.getIndexReader().getContext().id() != contextIdentity) { + throw new IllegalStateException("This DocAndScore query was created by a different reader"); + } + return new 
Weight(this) { + @Override + public Explanation explain(LeafReaderContext context, int doc) { + int found = Arrays.binarySearch(docs, doc + context.docBase); + if (found < 0) { + return Explanation.noMatch("not in top " + docs.length + " docs"); + } + return Explanation.match(scores[found] * boost, "within top " + docs.length + " docs"); + } + + @Override + public int count(LeafReaderContext context) { + return segmentStarts[context.ord + 1] - segmentStarts[context.ord]; + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { + return null; + } + final var scorer = + new Scorer() { + final int lower = segmentStarts[context.ord]; + final int upper = segmentStarts[context.ord + 1]; + int upTo = -1; + + @Override + public DocIdSetIterator iterator() { + return new DocIdSetIterator() { + @Override + public int docID() { + return docIdNoShadow(); + } + + @Override + public int nextDoc() { + if (upTo == -1) { + upTo = lower; + } else { + ++upTo; + } + return docIdNoShadow(); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return upper - lower; + } + }; + } + + @Override + public float getMaxScore(int docId) { + return maxScore * boost; + } + + @Override + public float score() { + return scores[upTo] * boost; + } + + /** + * move the implementation of docID() into a differently-named method so we can call + * it from DocIDSetIterator.docID() even though this class is anonymous + * + * @return the current docid + */ + private int docIdNoShadow() { + if (upTo == -1) { + return -1; + } + if (upTo >= upper) { + return NO_MORE_DOCS; + } + return docs[upTo] - context.docBase; + } + + @Override + public int docID() { + return docIdNoShadow(); + } + }; + return new DefaultScorerSupplier(scorer); + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + 
return true; + } + }; + } + + @Override + public String toString(String field) { + return "DocAndScoreQuery[" + docs[0] + ",...][" + scores[0] + ",...]," + maxScore; + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + + public long visited() { + return visited; + } + + @Override + public boolean equals(Object obj) { + if (sameClassAs(obj) == false) { + return false; + } + return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity + && Arrays.equals(docs, ((DocAndScoreQuery) obj).docs) + && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores); + } + + @Override + public int hashCode() { + return Objects.hash( + classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores)); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java b/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java new file mode 100644 index 000000000000..3b9a8fa045c6 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/RescoreTopNQuery.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.index.IndexReader; + +/** + * A Query that re-scores another Query with a {@link DoubleValuesSource} function and cuts off the + * results at top N. Unlike {@link Rescorer} which does rescoring at post-collection phase, this + * Query does the rescoring at rewrite() phase. The reason it operates in rewrite phase is to be + * compatible with KNN vector query, where the results are collected upfront, but it can work with + * any type of Query. Unlike {@link FunctionScoreQuery}, this Query will work even with the + * no-scoring {@link ScoreMode}. + * + * @lucene.experimental + */ +public class RescoreTopNQuery extends Query { + + private final int n; + private final Query query; + private final DoubleValuesSource valuesSource; + + /** + * Execute the inner Query, re-score using a customizable DoubleValueSource and trim down the + * result to the top n + * + * @param query the query to execute as initial phase + * @param valuesSource the double value source to re-score + * @param n the number of documents to find + * @throws IllegalArgumentException if n is less than 1 + */ + public RescoreTopNQuery(Query query, DoubleValuesSource valuesSource, int n) { + if (n < 1) { + throw new IllegalArgumentException("n must be >= 1"); + } + this.query = query; + this.valuesSource = valuesSource; + this.n = n; + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + DoubleValuesSource rewrittenValueSource = valuesSource.rewrite(indexSearcher); + IndexReader reader = indexSearcher.getIndexReader(); + Query rewritten = indexSearcher.rewrite(query); + Weight weight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1.0f); + HitQueue queue = new HitQueue(n, false); + int originalCount = 0; + for (var leaf : reader.leaves()) { + Scorer innerScorer = weight.scorer(leaf); + if (innerScorer == null) { + continue; + } + 
DoubleValues rescores = rewrittenValueSource.getValues(leaf, getDoubleValues(innerScorer)); + DocIdSetIterator iterator = innerScorer.iterator(); + while (iterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + int docId = iterator.docID(); + if (rescores.advanceExact(docId)) { + double v = rescores.doubleValue(); + queue.insertWithOverflow(new ScoreDoc(leaf.docBase + docId, (float) v)); + } else { + queue.insertWithOverflow(new ScoreDoc(leaf.docBase + docId, 0f)); + } + originalCount++; + } + } + int i = 0; + ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()]; + for (ScoreDoc topDoc : queue) { + scoreDocs[i++] = topDoc; + } + TopDocs topDocs = + new TopDocs(new TotalHits(originalCount, TotalHits.Relation.EQUAL_TO), scoreDocs); + return DocAndScoreQuery.createDocAndScoreQuery(reader, topDocs, 0); + } + + private DoubleValues getDoubleValues(Scorer innerScorer) { + // if the value source doesn't need document score to compute value, return null + if (valuesSource.needsScores() == false) { + return null; + } + return DoubleValuesSource.fromScorer(innerScorer); + } + + @Override + public int hashCode() { + int result = valuesSource.hashCode(); + result = 31 * result + Objects.hash(query, n); + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RescoreTopNQuery that = (RescoreTopNQuery) o; + return Objects.equals(query, that.query) + && Objects.equals(valuesSource, that.valuesSource) + && n == that.n; + } + + @Override + public void visit(QueryVisitor visitor) { + query.visit(visitor); + } + + @Override + public String toString(String field) { + return getClass().getSimpleName() + + ":" + + query.toString(field) + + ":" + + valuesSource.toString() + + "[" + + n + + "]"; + } + + /** + * Utility method to create a new RescoreTopNQuery which uses full-precision vectors for + * rescoring. 
+ * + * @param in the inner Query to rescore + * @param targetVector the target vector to compute score + * @param field the vector field to compute score + * @param n the number of results to keep + * @return the RescoreTopNQuery + */ + public static Query createFullPrecisionRescorerQuery( + Query in, float[] targetVector, String field, int n) { + DoubleValuesSource valueSource = + new FullPrecisionFloatVectorSimilarityValuesSource(targetVector, field); + return new RescoreTopNQuery(in, valueSource, n); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index 5278beb5d97b..48ae27ce0b0e 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -232,10 +232,10 @@ public void testDocAndScoreQueryBasics() throws IOException { maxScore = Math.max(maxScore, scores[i]); } IndexReader indexReader = searcher.getIndexReader(); - int[] segments = AbstractKnnVectorQuery.findSegmentStarts(indexReader.leaves(), docs); + int[] segments = DocAndScoreQuery.findSegmentStarts(indexReader.leaves(), docs); - AbstractKnnVectorQuery.DocAndScoreQuery query = - new AbstractKnnVectorQuery.DocAndScoreQuery( + DocAndScoreQuery query = + new DocAndScoreQuery( docs, scores, maxScore, diff --git a/lucene/core/src/test/org/apache/lucene/search/TestRescoreTopNQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestRescoreTopNQuery.java new file mode 100644 index 000000000000..e1be2c81822a --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestRescoreTopNQuery.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; +import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.ByteBuffersDirectory; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class TestRescoreTopNQuery extends LuceneTestCase { + + private static final String FIELD = "vector"; + private static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION = + VectorSimilarityFunction.COSINE; + private static final int NUM_VECTORS = 1000; + private static final int VECTOR_DIMENSION = 128; + + private Directory directory; + private IndexWriterConfig config; + + @Before + @Override + public void 
setUp() throws Exception { + super.setUp(); + directory = new ByteBuffersDirectory(); + + // Set up the IndexWriterConfig to use quantized vector storage + config = new IndexWriterConfig(); + config.setCodec( + TestUtil.alwaysKnnVectorsFormat(new Lucene99HnswScalarQuantizedVectorsFormat())); + } + + @Test + public void testInvalidN() { + expectThrows( + IllegalArgumentException.class, + () -> + new RescoreTopNQuery( + new TermQuery(new Term("test")), DoubleValuesSource.constant(0), 0)); + } + + @Test + public void testRescoreField() throws Exception { + Map vectors = new HashMap<>(); + + Random random = random(); + + int numVectors = atLeast(NUM_VECTORS); + int numSegments = random.nextInt(2, 10); + + // Step 1: Index random vectors in quantized format + try (IndexWriter writer = new IndexWriter(directory, config)) { + for (int j = 0; j < numSegments; j++) { + for (int i = 0; i < numVectors; i++) { + float[] vector = randomFloatVector(VECTOR_DIMENSION, random); + Document doc = new Document(); + int id = j * numVectors + i; + doc.add(new IntField("id", id, Field.Store.YES)); + doc.add(new KnnFloatVectorField(FIELD, vector, VECTOR_SIMILARITY_FUNCTION)); + writer.addDocument(doc); + vectors.put(id, vector); + + writer.flush(); + } + } + } + + // Step 2: Run TwoPhaseKnnVectorQuery with a random target vector + try (IndexReader reader = DirectoryReader.open(directory)) { + IndexSearcher searcher = new IndexSearcher(reader); + float[] targetVector = randomFloatVector(VECTOR_DIMENSION, random); + int k = 10; + double oversample = random.nextFloat(1.5f, 3.0f); + + KnnFloatVectorQuery knnQuery = + new KnnFloatVectorQuery(FIELD, targetVector, k + (int) (k * oversample)); + + Query query = + RescoreTopNQuery.createFullPrecisionRescorerQuery(knnQuery, targetVector, FIELD, k); + TopDocs topDocs = searcher.search(query, k); + + // Step 3: Verify that TopDocs scores match similarity with unquantized vectors + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + Document retrievedDoc 
= searcher.storedFields().document(scoreDoc.doc); + int id = retrievedDoc.getField("id").numericValue().intValue(); + float[] docVector = vectors.get(id); + assert docVector != null : "Vector for id " + id + " not found"; + float expectedScore = VECTOR_SIMILARITY_FUNCTION.compare(targetVector, docVector); + Assert.assertEquals( + "Score does not match expected similarity for doc ord: " + scoreDoc.doc + ", id: " + id, + expectedScore, + scoreDoc.score, + 1e-5); + } + } + } + + public void testMissingDoubleValues() throws IOException { + Random random = random(); + + try (IndexWriter writer = new IndexWriter(directory, config)) { + float[] vector = randomFloatVector(VECTOR_DIMENSION, random); + Document doc = new Document(); + doc.add(new KnnFloatVectorField(FIELD, vector, VECTOR_SIMILARITY_FUNCTION)); + writer.addDocument(doc); + } + + // Step 2: Run TwoPhaseKnnVectorQuery with a random target vector + try (IndexReader reader = DirectoryReader.open(directory)) { + IndexSearcher searcher = new IndexSearcher(reader); + float[] targetVector = randomFloatVector(VECTOR_DIMENSION, random); + int k = 1; + + KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(FIELD, targetVector, k); + + Query query = + RescoreTopNQuery.createFullPrecisionRescorerQuery(knnQuery, targetVector, "field-1", k); + TopDocs topDocs = searcher.search(query, k); + + // Step 3: The rescoring field is invalid, so the score should be 0 + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + Assert.assertEquals("Score must be 0 for missing DoubleValues", 0, scoreDoc.score, 1e-5); + } + } + } + + private float[] randomFloatVector(int dimension, Random random) { + float[] vector = new float[dimension]; + for (int i = 0; i < dimension; i++) { + vector[i] = random.nextFloat(); + } + return vector; + } +}