diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index e90ca7c9edb0..f90d97414f12 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -128,6 +128,10 @@ API Changes * GITHUB#14978: Add a bulk scoring interface to RandomVectorScorer (Trevor McCulloch, Chris Hegarty) +* GITHUB#15011: LeafReader#searchNearestVectors now accepts an AcceptDocs + instance instead of a Bits instance to identify document IDs to filter. + (Shubham Chaudhary, Adrien Grand) + New Features --------------------- * GITHUB#15015: MultiIndexMergeScheduler: a production multi-tenant merge scheduler (Shawn Yarbrough) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index deb0607f5bdb..351c96bad60b 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -35,6 +35,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; @@ -242,7 +243,7 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldEntry fieldEntry = getFieldEntry(field); if (fieldEntry.size() == 0) { @@ -260,7 +261,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits vectorValues, fieldEntry.similarityFunction, getGraphValues(fieldEntry), - getAcceptOrds(acceptDocs, fieldEntry), + getAcceptOrds(acceptDocs.bits(), fieldEntry), knnCollector.visitLimit(), random); knnCollector.incVisitedCount(results.visitedCount()); @@ -273,7 +274,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 472c2a959440..9bb320e71287 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -37,6 +37,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; @@ -238,7 +239,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldEntry fieldEntry = getFieldEntry(field); if (fieldEntry.size() == 0) { @@ -253,11 +254,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - getAcceptOrds(acceptDocs, fieldEntry)); + getAcceptOrds(acceptDocs.bits(), fieldEntry)); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java index df46cf2f2769..32dab1b6c7c7 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java @@ -36,11 +36,11 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; @@ -236,7 +236,7 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldEntry fieldEntry = getFieldEntry(field); if (fieldEntry.size() == 0) { @@ -251,11 +251,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); + vectorValues.getAcceptOrds(acceptDocs.bits())); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index 6bf791a309cc..5b937375f1af 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -37,11 +37,11 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; @@ -270,7 +270,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); if (fieldEntry.size() == 0 || knnCollector.k() == 0) { @@ -285,11 +285,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); + vectorValues.getAcceptOrds(acceptDocs.bits())); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); if (fieldEntry.size() == 0 || knnCollector.k() == 0) { @@ -304,7 +304,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); + vectorValues.getAcceptOrds(acceptDocs.bits())); } private HnswGraph getGraph(FieldEntry entry) throws IOException { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java index bbeda6930f30..7fa6ffbaedad 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java @@ -41,13 +41,13 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; @@ -290,7 +290,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); if (fieldEntry.size() == 0 || knnCollector.k() == 0) { @@ -314,11 +314,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); + vectorValues.getAcceptOrds(acceptDocs.bits())); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); if (fieldEntry.size() == 0 || knnCollector.k() == 0) { @@ -342,7 +342,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); + vectorValues.getAcceptOrds(acceptDocs.bits())); } /** Get knn graph values; used for testing */ diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index be2896d1cda9..4927bd73042b 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -37,6 +37,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; @@ -44,7 +45,6 @@ import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefBuilder; import org.apache.lucene.util.IOUtils; @@ -178,7 +178,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { FloatVectorValues values = getFloatVectorValues(field); if (target.length != values.dimension()) { @@ -192,7 +192,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction(); for (int ord = 0; ord < values.size(); ord++) { int doc = values.ordToDoc(ord); - if (acceptDocs != null && acceptDocs.get(doc) == false) { + if (acceptDocs.bits() != null && acceptDocs.bits().get(doc) == false) { continue; } @@ -208,7 +208,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { ByteVectorValues values = getByteVectorValues(field); if (target.length != values.dimension()) { @@ -223,7 +223,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits for (int ord = 0; ord < values.size(); ord++) { int doc = values.ordToDoc(ord); - if (acceptDocs != null && acceptDocs.get(doc) == false) { + if (acceptDocs.bits() != null && acceptDocs.bits().get(doc) == false) { continue; } diff --git a/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java b/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java index 388f08792565..49dd2dc68e00 100644 --- a/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java +++ b/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java @@ -33,6 +33,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.store.Directory; @@ -84,7 +85,8 @@ public void testIndexAndSearchBitVectors() throws IOException { try (IndexReader reader = DirectoryReader.open(w)) { LeafReader r = getOnlyLeafReader(reader); TopKnnCollector collector = new TopKnnCollector(3, Integer.MAX_VALUE); - r.searchNearestVectors("v1", vectors[0], collector, null); + r.searchNearestVectors( + "v1", vectors[0], collector, AcceptDocs.fromLiveDocs(null, r.maxDoc())); TopDocs topDocs = collector.topDocs(); assertEquals(3, topDocs.scoreDocs.length); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java index 99439f387808..636a869e7ba3 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java @@ -25,8 +25,8 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.NamedSPILoader; /** @@ -140,13 +140,13 @@ public ByteVectorValues getByteVectorValues(String field) { @Override public void search( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) { + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) { throw new UnsupportedOperationException(); } @Override public void search( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) { + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java index 2a34aa693136..99b5c81fdff6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java @@ -25,6 +25,7 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -88,7 +89,8 @@ protected KnnVectorsReader() {} * if they are all allowed to match. */ public abstract void search( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException; /** * Return the k nearest neighbor documents as determined by comparison of their vector values for @@ -116,7 +118,8 @@ public abstract void search( * if they are all allowed to match. */ public abstract void search( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException; /** * Returns an instance optimized for merging. This instance may only be consumed in the thread diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java index 4816f8ad4ff1..5dfa5e1ec2e8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java @@ -19,9 +19,9 @@ import java.io.IOException; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.RandomVectorScorer; /** @@ -56,13 +56,13 @@ public FlatVectorsScorer getFlatVectorScorer() { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { // don't scan stored field data. If we didn't index it, produce no search results } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { // don't scan stored field data. If we didn't index it, produce no search results } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java index e84c04b5e898..80e532df0f99 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102BinaryQuantizedVectorsReader.java @@ -38,6 +38,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; @@ -223,20 +224,20 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { rawVectorsReader.search(field, target, knnCollector, acceptDocs); } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { if (knnCollector.k() == 0) return; final RandomVectorScorer scorer = getRandomVectorScorer(field, target); if (scorer == null) return; OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); - Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); + Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits()); for (int i = 0; i < scorer.maxOrd(); i++) { if (acceptedOrds == null || acceptedOrds.get(i)) { collector.collect(i, scorer.score(i)); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index 55a6b98c370f..cf2856023e91 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -38,6 +38,7 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataAccessHint; @@ -48,7 +49,6 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.PreloadHint; import org.apache.lucene.store.RandomAccessInput; -import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.GroupVIntUtil; import org.apache.lucene.util.IOSupplier; @@ -299,7 +299,7 @@ private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); search( @@ -310,7 +310,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE); search( @@ -323,7 +323,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits private void search( FieldEntry fieldEntry, KnnCollector knnCollector, - Bits acceptDocs, + AcceptDocs acceptDocs, IOSupplier scorerSupplier) throws IOException { if (fieldEntry.size() == 0 || knnCollector.k() == 0) { @@ -332,20 +332,18 @@ private void search( final RandomVectorScorer scorer = scorerSupplier.get(); final KnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); - final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); HnswGraph graph = getGraph(fieldEntry); - boolean doHnsw = knnCollector.k() < scorer.maxOrd(); // Take into account if quantized? E.g. some scorer cost? - int filteredDocCount = 0; + // Use approximate cardinality as this is good enough, but ensure we don't exceed the graph + // size as that is illogical + int filteredDocCount = Math.min(acceptDocs.cost(), graph.size()); + Bits accepted = acceptDocs.bits(); + final Bits acceptedOrds = scorer.getAcceptOrds(accepted); + boolean doHnsw = knnCollector.k() < scorer.maxOrd(); // The approximate number of vectors that would be visited if we did not filter int unfilteredVisit = HnswGraphSearcher.expectedVisitedNodes(knnCollector.k(), graph.size()); - if (acceptDocs instanceof BitSet bitSet) { - // Use approximate cardinality as this is good enough, but ensure we don't exceed the graph - // size as that is illogical - filteredDocCount = Math.min(bitSet.approximateCardinality(), graph.size()); - if (unfilteredVisit >= filteredDocCount) { - doHnsw = false; - } + if (unfilteredVisit >= filteredDocCount) { + doHnsw = false; } if (doHnsw) { HnswGraphSearcher.search( diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java index 1df714857135..6bf0a8888589 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java @@ -39,8 +39,8 @@ import org.apache.lucene.index.Sorter; import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.internal.hppc.ObjectCursor; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.HnswGraph; @@ -300,7 +300,8 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldInfo info = fieldInfos.fieldInfo(field); final KnnVectorsReader reader; @@ -311,7 +312,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldInfo info = fieldInfos.fieldInfo(field); final KnnVectorsReader reader; diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java index 7e98e51bf69b..fc8ec5dee352 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java +++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java @@ -63,6 +63,7 @@ import org.apache.lucene.index.PointValues.IntersectVisitor; import org.apache.lucene.index.PointValues.Relation; import org.apache.lucene.internal.hppc.IntIntHashMap; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.DocAndFloatFeatureBuffer; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FieldExistsQuery; @@ -3066,7 +3067,11 @@ private static void checkFloatVectorValues( if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) { codecReader .getVectorReader() - .search(fieldInfo.name, values.vectorValue(count), collector, null); + .search( + fieldInfo.name, + values.vectorValue(count), + collector, + AcceptDocs.fromLiveDocs(null, codecReader.maxDoc())); TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( @@ -3114,7 +3119,11 @@ private static void checkByteVectorValues( KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE); codecReader .getVectorReader() - .search(fieldInfo.name, values.vectorValue(count), collector, null); + .search( + fieldInfo.name, + values.vectorValue(count), + collector, + AcceptDocs.fromLiveDocs(null, codecReader.maxDoc())); TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( diff --git a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java index 20be7e1a45a8..a39b05ee9829 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java @@ -25,8 +25,8 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; -import org.apache.lucene.util.Bits; /** LeafReader implemented by codec APIs. */ public abstract class CodecReader extends LeafReader { @@ -260,7 +260,8 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti @Override public final void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null @@ -274,7 +275,8 @@ public final void searchNearestVectors( @Override public final void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null diff --git a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java index 3504c7429a5e..2c985b121f63 100644 --- a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java @@ -18,6 +18,7 @@ package org.apache.lucene.index; import java.io.IOException; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -59,13 +60,15 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { throw new UnsupportedOperationException(); } @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index 614a652cd35a..9b6843e778d2 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -20,7 +20,9 @@ import java.util.Objects; import org.apache.lucene.index.FilterLeafReader.FilterTerms; import org.apache.lucene.index.FilterLeafReader.FilterTermsEnum; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.FilterDocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.Bits; @@ -331,68 +333,92 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { return new ExitableByteVectorValues(vectorValues); } - @Override - public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) - throws IOException { + private class ExitableAcceptDocs extends AcceptDocs { - // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would - // match all docs to allow timeout checking. - final Bits updatedAcceptDocs = - acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs; + private final AcceptDocs in; + private Bits bits; - Bits timeoutCheckingAcceptDocs = - new Bits() { - private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 16; - private int calls; + ExitableAcceptDocs(AcceptDocs in) { + this.in = in; + } - @Override - public boolean get(int index) { - if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) { - checkAndThrowForSearchVectors(); - } + @Override + public Bits bits() throws IOException { + if (bits == null) { + // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would + // match all docs to allow timeout checking. + final Bits updatedAcceptDocs = + in.bits() == null ? new Bits.MatchAllBits(maxDoc()) : in.bits(); + bits = + new Bits() { + private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 16; + private int calls; + + @Override + public boolean get(int index) { + if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) { + checkAndThrowForSearchVectors(); + } + + return updatedAcceptDocs.get(index); + } + + @Override + public int length() { + return updatedAcceptDocs.length(); + } + }; + } + return bits; + } - return updatedAcceptDocs.get(index); + @Override + public DocIdSetIterator iterator() throws IOException { + return new FilterDocIdSetIterator(in.iterator()) { + private int docToCheck = 0; + + @Override + public int advance(int target) throws IOException { + final int advance = super.advance(target); + if (advance >= docToCheck) { + checkAndThrow(in); + docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK; } + return advance; + } - @Override - public int length() { - return updatedAcceptDocs.length(); + @Override + public int nextDoc() throws IOException { + final int nextDoc = super.nextDoc(); + if (nextDoc >= docToCheck) { + checkAndThrow(in); + docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK; } - }; + return nextDoc; + } + }; + } - in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs); + @Override + public int cost() throws IOException { + return in.cost(); + } } @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { - // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would - // match all docs to allow timeout checking. - final Bits updatedAcceptDocs = - acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs; - - Bits timeoutCheckingAcceptDocs = - new Bits() { - private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 16; - private int calls; - - @Override - public boolean get(int index) { - if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) { - checkAndThrowForSearchVectors(); - } - - return updatedAcceptDocs.get(index); - } - @Override - public int length() { - return updatedAcceptDocs.length(); - } - }; + AcceptDocs timeoutCheckingAcceptDocs = new ExitableAcceptDocs(acceptDocs); + in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs); + } + @Override + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { + AcceptDocs timeoutCheckingAcceptDocs = new ExitableAcceptDocs(acceptDocs); in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs); } diff --git a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java index 8f1edf4962ee..174072407a8f 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.Iterator; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.AttributeSource; import org.apache.lucene.util.Bits; @@ -364,13 +365,15 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { in.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { in.searchNearestVectors(field, target, knnCollector, acceptDocs); } diff --git a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java index 0f39d1ae1e8d..39ea90e182e1 100644 --- a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java @@ -17,6 +17,7 @@ package org.apache.lucene.index; import java.io.IOException; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -244,14 +245,14 @@ public final PostingsEnum postings(Term term) throws IOException { * @param field the vector field to search * @param target the vector-valued query * @param k the number of docs to return - * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} - * if they are all allowed to match. + * @param acceptDocs {@link AcceptDocs} that represents the allowed documents to match * @param visitedLimit the maximum number of nodes that the search is allowed to visit * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores. * @lucene.experimental */ public final TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + String field, float[] target, int k, AcceptDocs acceptDocs, int visitedLimit) + throws IOException { FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0) { return TopDocsCollector.EMPTY_TOPDOCS; @@ -288,14 +289,14 @@ public final TopDocs searchNearestVectors( * @param field the vector field to search * @param target the vector-valued query * @param k the number of docs to return - * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} - * if they are all allowed to match. + * @param acceptDocs {@link AcceptDocs} that represents the allowed documents to match * @param visitedLimit the maximum number of nodes that the search is allowed to visit * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores. * @lucene.experimental */ public final TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + String field, byte[] target, int k, AcceptDocs acceptDocs, int visitedLimit) + throws IOException { FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0) { return TopDocsCollector.EMPTY_TOPDOCS; @@ -335,12 +336,12 @@ public final TopDocs searchNearestVectors( * @param field the vector field to search * @param target the vector-valued query * @param knnCollector collector with settings for gathering the vector results. - * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} - * if they are all allowed to match. + * @param acceptDocs {@link AcceptDocs} that represents the allowed documents to match * @lucene.experimental */ public abstract void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException; /** * Return the k nearest neighbor documents as determined by comparison of their vector values for @@ -364,12 +365,12 @@ public abstract void searchNearestVectors( * @param field the vector field to search * @param target the vector-valued query * @param knnCollector collector with settings for gathering the vector results. - * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} - * if they are all allowed to match. + * @param acceptDocs {@link AcceptDocs} that represents the allowed documents to match * @lucene.experimental */ public abstract void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException; /** * Get the {@link FieldInfos} describing all fields in this reader. diff --git a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java index f96c5174f375..b714ef4e7c31 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java @@ -26,6 +26,7 @@ import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.util.Bits; @@ -464,7 +465,7 @@ public ByteVectorValues getByteVectorValues(String fieldName) throws IOException @Override public void searchNearestVectors( - String fieldName, float[] target, KnnCollector knnCollector, Bits acceptDocs) + String fieldName, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { ensureOpen(); LeafReader reader = fieldToReader.get(fieldName); @@ -475,7 +476,7 @@ public void searchNearestVectors( @Override public void searchNearestVectors( - String fieldName, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + String fieldName, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { ensureOpen(); LeafReader reader = fieldToReader.get(fieldName); diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java index d7ed3d2f06f4..23c4276a0df1 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java @@ -30,6 +30,7 @@ import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -175,13 +176,15 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { reader.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { reader.searchNearestVectors(field, target, knnCollector, acceptDocs); } diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index d15d5de1f51a..0d0c82a755a3 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -34,6 +34,7 @@ import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.index.MultiDocValues.MultiSortedDocValues; import org.apache.lucene.index.MultiDocValues.MultiSortedSetDocValues; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; @@ -1026,13 +1027,15 @@ private static int binarySearchStarts(int[] starts, int ord, int from, int to) { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { throw new UnsupportedOperationException(); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index 29b8fbae2429..90c09dae8db1 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -33,6 +33,7 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; @@ -581,12 +582,14 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) { + public void search( + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) { throw new UnsupportedOperationException(); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) { + public void search( + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index e4115766ab5b..2d94514364f8 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -35,10 +35,7 @@ import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.search.knn.TopKnnCollectorManager; -import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.FixedBitSet; /** * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search. @@ -193,16 +190,23 @@ private TopDocs getLeafResults( final Bits liveDocs = reader.getLiveDocs(); if (filterWeight == null) { - return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager); + AcceptDocs acceptDocs = AcceptDocs.fromLiveDocs(liveDocs, reader.maxDoc()); + return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager); } - Scorer scorer = filterWeight.scorer(ctx); - if (scorer == null) { - return NO_RESULTS; - } - - BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc()); - final int cost = acceptDocs.cardinality(); + AcceptDocs acceptDocs = + AcceptDocs.fromIteratorSupplier( + () -> { + Scorer scorer = filterWeight.scorer(ctx); + if (scorer == null) { + return DocIdSetIterator.empty(); + } else { + return scorer.iterator(); + } + }, + liveDocs, + reader.maxDoc()); + final int cost = acceptDocs.cost(); QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout(); float leafProportion = ctx.reader().maxDoc() / (float) ctx.parent.reader().maxDoc(); @@ -211,7 +215,7 @@ private TopDocs getLeafResults( if (cost <= perLeafTopK) { // If there are <= perLeafTopK possible matches, short-circuit and perform exact search, since // HNSW must always visit at least perLeafTopK documents - return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout); + return exactSearch(ctx, acceptDocs.iterator(), queryTimeout); } // Perform the approximate kNN search @@ -227,35 +231,7 @@ private TopDocs getLeafResults( return results; } else { // We stopped the kNN search because it visited too many nodes, so fall back to exact search - return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout); - } - } - - private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) - throws IOException { - if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) { - // If we already have a BitSet and no deletions, reuse the BitSet - return bitSetIterator.getBitSet(); - } else { - int threshold = maxDoc >> 7; // same as BitSet#of - if (iterator.cost() >= threshold) { - // take advantage of Disi#intoBitset and Bits#applyMask - FixedBitSet bitSet = new FixedBitSet(maxDoc); - bitSet.or(iterator); - if (liveDocs != null) { - liveDocs.applyMask(bitSet, 0); - } - return bitSet; - } else { - FilteredDocIdSetIterator filterIterator = - new FilteredDocIdSetIterator(iterator) { - @Override - protected boolean match(int doc) { - return liveDocs == null || liveDocs.get(doc); - } - }; - return BitSet.of(filterIterator, maxDoc); // create a sparse bitset - } + return exactSearch(ctx, acceptDocs.iterator(), queryTimeout); } } @@ -297,7 +273,7 @@ private static int perLeafTopKCalculation(int k, float leafProportion) { protected abstract TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException; diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java index f80ebd34c655..6a8c0d19e41f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java @@ -26,10 +26,7 @@ import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; -import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.FixedBitSet; /** * Search for all (approximate) vectors above a similarity threshold. @@ -73,7 +70,7 @@ protected KnnCollectorManager getKnnCollectorManager() { protected abstract TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitLimit, KnnCollectorManager knnCollectorManager) throws IOException; @@ -128,45 +125,26 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti // Return exhaustive results TopDocs results = approximateSearch( - context, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager); + context, + AcceptDocs.fromLiveDocs(liveDocs, leafReader.maxDoc()), + Integer.MAX_VALUE, + timeLimitingKnnCollectorManager); return VectorSimilarityScorerSupplier.fromScoreDocs(boost, results.scoreDocs); } else { - Scorer scorer = filterWeight.scorer(context); - if (scorer == null) { - // If the filter does not match any documents - return null; - } - - BitSet acceptDocs; - if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) { - // If there are no deletions, and matching docs are already cached - acceptDocs = bitSetIterator.getBitSet(); - } else { - // Else collect all matching docs - DocIdSetIterator iterator = scorer.iterator(); - final int maxDoc = leafReader.maxDoc(); - int threshold = maxDoc >> 7; // same as BitSet#of - if (iterator.cost() >= threshold) { - // take advantage of Disi#intoBitset and Bits#applyMask - FixedBitSet bitSet = new FixedBitSet(maxDoc); - bitSet.or(iterator); - if (liveDocs != null) { - liveDocs.applyMask(bitSet, 0); - } - acceptDocs = bitSet; - } else { - FilteredDocIdSetIterator filterIterator = - new FilteredDocIdSetIterator(iterator) { - @Override - protected boolean match(int doc) { - return liveDocs == null || liveDocs.get(doc); + AcceptDocs acceptDocs = + AcceptDocs.fromIteratorSupplier( + () -> { + Scorer scorer = filterWeight.scorer(context); + if (scorer == null) { + return DocIdSetIterator.empty(); + } else { + return scorer.iterator(); } - }; - acceptDocs = BitSet.of(filterIterator, maxDoc); // create a sparse bitset - } - } + }, + liveDocs, + leafReader.maxDoc()); - int cardinality = acceptDocs.cardinality(); + int cardinality = acceptDocs.cost(); if (cardinality == 0) { // If there are no live matching docs return null; @@ -184,10 +162,7 @@ protected boolean match(int doc) { } else { // Return a lazy-loading iterator return VectorSimilarityScorerSupplier.fromAcceptDocs( - boost, - createVectorScorer(context), - new BitSetIterator(acceptDocs, cardinality), - resultSimilarity); + boost, createVectorScorer(context), acceptDocs.iterator(), resultSimilarity); } } } diff --git a/lucene/core/src/java/org/apache/lucene/search/AcceptDocs.java b/lucene/core/src/java/org/apache/lucene/search/AcceptDocs.java new file mode 100644 index 000000000000..175078434efe --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/AcceptDocs.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.IOSupplier; + +/** + * Higher-level abstraction for document acceptance filtering. Can be consumed in either + * random-access (Bits) or sequential (DocIdSetIterator) pattern. + * + * @lucene.experimental + */ +public abstract class AcceptDocs { + + /** + * Random access to the accepted documents. + * + * @return Bits instance for random access, or null if all documents are accepted + * @throws IOException if an I/O error occurs + */ + public abstract Bits bits() throws IOException; + + /** + * Create a new iterator of accepted docs. There accepted docs already ignore deleted docs. + * + *

NOTE: If you also plan on calling {@link #bits()} or {@link #cost()}, it is + * recommended to call these methods before {@link #iterator()} for better performance. + * + * @return DocIdSetIterator for sequential access + * @throws IOException if an I/O error occurs + */ + public abstract DocIdSetIterator iterator() throws IOException; + + /** + * Return an approximation of the number of accepted documents. This is typically useful to decide + * whether to consume these accept docs using random access ({@link #bits()}) or sequential access + * ({@link #iterator()}). + * + *

NOTE: This must not be called after {@link #iterator()}. + * + * @return approximate cost + */ + public abstract int cost() throws IOException; + + /** + * Create AcceptDocs from a {@link Bits} instance representing live documents. A {@code null} + * instance is interpreted as matching all documents, like in {@link LeafReader#getLiveDocs()}. + * + * @param bits the Bits instance for random access + * @param maxDoc the number of documents in the reader + * @return AcceptDocs wrapping the Bits + */ + public static AcceptDocs fromLiveDocs(Bits bits, int maxDoc) { + return new BitsAcceptDocs(bits, maxDoc); + } + + /** + * Create AcceptDocs from an {@link IOSupplier} of {@link DocIdSetIterator}, optionally filtered + * by live documents. + * + * @param iteratorSupplier a DocIdSetIterator iterator + * @param liveDocs Bits representing live documents, or {@code null} if no deleted docs. + * @param maxDoc the number of documents in the reader + * @return AcceptDocs wrapping the iterator + */ + public static AcceptDocs fromIteratorSupplier( + IOSupplier iteratorSupplier, Bits liveDocs, int maxDoc) { + return new DocIdSetIteratorAcceptDocs(iteratorSupplier, liveDocs, maxDoc); + } + + private static BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) + throws IOException { + if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) { + // If we already have a BitSet and no deletions, reuse the BitSet + return bitSetIterator.getBitSet(); + } else { + int threshold = maxDoc >> 7; // same as BitSet#of + if (iterator.cost() >= threshold) { + // take advantage of Disi#intoBitset and Bits#applyMask + FixedBitSet bitSet = new FixedBitSet(maxDoc); + bitSet.or(iterator); + if (liveDocs != null) { + liveDocs.applyMask(bitSet, 0); + } + return bitSet; + } else { + return BitSet.of( + AcceptDocs.getFilteredDocIdSetIterator(iterator, liveDocs), + maxDoc); // create a sparse bitset + } + } + } + + /** + * Impl backed by Bits, expected to be somewhat dense, except when a {@link BitSet} is provided, + * in which case it's not necessarily dense. + */ + private static class BitsAcceptDocs extends AcceptDocs { + private final Bits bits; + private final int maxDoc; + + BitsAcceptDocs(Bits bits, int maxDoc) { + if (bits != null && bits.length() != maxDoc) { + throw new IllegalArgumentException( + "Bits length = " + bits.length() + " != maxDoc = " + maxDoc); + } + this.bits = bits; + if (bits instanceof BitSet bitSet) { + this.maxDoc = Objects.requireNonNull(bitSet).cardinality(); + } else { + this.maxDoc = maxDoc; + } + } + + @Override + public Bits bits() { + return bits; + } + + @Override + public DocIdSetIterator iterator() { + if (bits instanceof BitSet bitSet) { + return new BitSetIterator(bitSet, maxDoc); + } + return AcceptDocs.getFilteredDocIdSetIterator(DocIdSetIterator.all(maxDoc), bits); + } + + @Override + public int cost() { + // We have no better estimate. This should be ok in practice since background merges should + // keep the number of deletes under control (< 20% by default). + return maxDoc; + } + } + + /** + * Impl backed by a {@link DocIdSetIterator}, which lazily creates a {@link BitSet} if {@link + * #cost()} or {@link #bits()} are called. + */ + private static class DocIdSetIteratorAcceptDocs extends AcceptDocs { + + private final IOSupplier iteratorSupplier; + private final Bits liveDocs; + private final int maxDoc; + private BitSet acceptBitSet; + private int cardinality; + + DocIdSetIteratorAcceptDocs( + IOSupplier iteratorSupplier, Bits liveDocs, int maxDoc) { + this.iteratorSupplier = Objects.requireNonNull(iteratorSupplier); + this.liveDocs = liveDocs; + this.maxDoc = maxDoc; + } + + private void createBitSetAcceptDocsIfNecessary() throws IOException { + if (acceptBitSet == null) { + acceptBitSet = Objects.requireNonNull(createBitSet(iterator(), liveDocs, maxDoc)); + cardinality = acceptBitSet.cardinality(); + } + } + + @Override + public Bits bits() throws IOException { + createBitSetAcceptDocsIfNecessary(); + return acceptBitSet; + } + + @Override + public int cost() throws IOException { + createBitSetAcceptDocsIfNecessary(); + return acceptBitSet.cardinality(); + } + + @Override + public DocIdSetIterator iterator() throws IOException { + if (acceptBitSet != null) { + return new BitSetIterator(acceptBitSet, cardinality); + } + DocIdSetIterator iterator = Objects.requireNonNull(iteratorSupplier.get()); + return AcceptDocs.getFilteredDocIdSetIterator(iterator, liveDocs); + } + } + + private static DocIdSetIterator getFilteredDocIdSetIterator( + DocIdSetIterator iterator, Bits liveDocs) { + if (liveDocs != null) { + iterator = + new FilteredDocIdSetIterator(iterator) { + @Override + protected boolean match(int doc) { + return liveDocs.get(doc); + } + }; + } + return iterator; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java index 90d7b1ee16ef..0c66d3345588 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java @@ -24,7 +24,6 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.knn.KnnCollectorManager; -import org.apache.lucene.util.Bits; /** * Search for all (approximate) byte vectors above a similarity threshold. @@ -109,7 +108,7 @@ VectorScorer createVectorScorer(LeafReaderContext context) throws IOException { @SuppressWarnings("resource") protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java index 0a440faefce6..6f04b76cb4e5 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java @@ -24,7 +24,6 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.knn.KnnCollectorManager; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; /** @@ -111,7 +110,7 @@ VectorScorer createVectorScorer(LeafReaderContext context) throws IOException { @SuppressWarnings("resource") protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 231bcaed1910..cea0e1bec658 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -30,10 +30,9 @@ import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.Bits; /** - * Uses {@link KnnVectorsReader#search(String, byte[], KnnCollector, Bits)} to perform nearest + * Uses {@link KnnVectorsReader#search(String, byte[], KnnCollector, AcceptDocs)} to perform nearest * neighbour search. * *

This query also allows for performing a kNN search subject to a filter. In this case, it first @@ -100,7 +99,7 @@ public KnnByteVectorQuery( @Override protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 5156ff2816fa..9c0f22625dc8 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -30,12 +30,11 @@ import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; /** - * Uses {@link KnnVectorsReader#search(String, float[], KnnCollector, Bits)} to perform nearest - * neighbour search. + * Uses {@link KnnVectorsReader#search(String, float[], KnnCollector, AcceptDocs)} to perform + * nearest neighbour search. * *

This query also allows for performing a kNN search subject to a filter. In this case, it first * executes the filter for each leaf, then chooses a strategy dynamically: @@ -101,7 +100,7 @@ public KnnFloatVectorQuery( @Override protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java index 170a26a714ec..f911161da0f0 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java @@ -23,7 +23,6 @@ import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; -import org.apache.lucene.util.Bits; /** * This is a version of knn vector query that exits early when HNSW queue saturates over a {@code @@ -197,7 +196,7 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search @Override protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnVectorQuery.java index 2b43b6c0b638..8ee73c37a598 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SeededKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnVectorQuery.java @@ -27,7 +27,6 @@ import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; -import org.apache.lucene.util.Bits; /** * This is a version of knn vector query that provides a query seed to initiate the vector search. @@ -149,7 +148,7 @@ Weight createSeedWeight(IndexSearcher indexSearcher) throws IOException { @Override protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java index cadfd9439edd..31799f096690 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene102/TestLucene102HnswBinaryQuantizedVectorsFormat.java @@ -42,6 +42,7 @@ import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; @@ -101,7 +102,13 @@ public void testSingleVectorCase() throws Exception { } float[] randomVector = randomVector(vector.length); float trueScore = similarityFunction.compare(vector, randomVector); - TopDocs td = r.searchNearestVectors("f", randomVector, 1, null, Integer.MAX_VALUE); + TopDocs td = + r.searchNearestVectors( + "f", + randomVector, + 1, + AcceptDocs.fromLiveDocs(null, r.maxDoc()), + Integer.MAX_VALUE); assertEquals(1, td.totalHits.value()); assertTrue(td.scoreDocs[0].score >= 0); // When it's the only vector in a segment, the score should be very close to the true diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index d2aa9b8d0542..2c6c54cece73 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -40,6 +40,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.ScoreDoc; @@ -197,7 +198,11 @@ public void testQuantizationScoringEdgeCase() throws Exception { try (IndexReader reader = DirectoryReader.open(w)) { LeafReader r = getOnlyLeafReader(reader); TopKnnCollector topKnnCollector = new TopKnnCollector(5, Integer.MAX_VALUE); - r.searchNearestVectors("f", new float[] {0.6f, 0.8f}, topKnnCollector, null); + r.searchNearestVectors( + "f", + new float[] {0.6f, 0.8f}, + topKnnCollector, + AcceptDocs.fromLiveDocs(null, r.maxDoc())); TopDocs topDocs = topKnnCollector.topDocs(); assertEquals(3, topDocs.totalHits.value()); for (ScoreDoc scoreDoc : topDocs.scoreDocs) { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java index 3ffeef501e87..778111fc1f0f 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java @@ -37,6 +37,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; @@ -294,7 +295,9 @@ private void testSingleVectorPerSegment(VectorSimilarityFunction sim) throws IOE LeafReader leafReader = getOnlyLeafReader(reader); StoredFields storedFields = reader.storedFields(); float[] queryVector = new float[] {0.6f, 0.8f}; - var hits = leafReader.searchNearestVectors("field", queryVector, 3, null, 100); + var hits = + leafReader.searchNearestVectors( + "field", queryVector, 3, AcceptDocs.fromLiveDocs(null, leafReader.maxDoc()), 100); assertEquals(hits.scoreDocs.length, 3); assertEquals("B", storedFields.document(hits.scoreDocs[0].doc).get("id")); assertEquals("A", storedFields.document(hits.scoreDocs[1].doc).get("id")); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java index 45814144d10a..5d7ccdb3055d 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java @@ -44,6 +44,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; @@ -94,12 +95,16 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { "missing_field", new float[] {1, 2, 3}, 10, - reader.getLiveDocs(), + AcceptDocs.fromLiveDocs(reader.getLiveDocs(), reader.maxDoc()), Integer.MAX_VALUE); assertEquals(0, hits.scoreDocs.length); hits = reader.searchNearestVectors( - "id", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE); + "id", + new float[] {1, 2, 3}, + 10, + AcceptDocs.fromLiveDocs(reader.getLiveDocs(), reader.maxDoc()), + Integer.MAX_VALUE); assertEquals(0, hits.scoreDocs.length); } } @@ -146,12 +151,20 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { LeafReader reader = ireader.leaves().get(0).reader(); TopDocs hits1 = reader.searchNearestVectors( - "field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE); + "field1", + new float[] {1, 2, 3}, + 10, + AcceptDocs.fromLiveDocs(reader.getLiveDocs(), reader.maxDoc()), + Integer.MAX_VALUE); assertEquals(1, hits1.scoreDocs.length); TopDocs hits2 = reader.searchNearestVectors( - "field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE); + "field2", + new float[] {1, 2, 3}, + 10, + AcceptDocs.fromLiveDocs(reader.getLiveDocs(), reader.maxDoc()), + Integer.MAX_VALUE); assertEquals(1, hits2.scoreDocs.length); } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index 046179eb82e2..0b57a63e641f 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -33,6 +33,7 @@ import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.document.StringField; import org.apache.lucene.index.ExitableDirectoryReader.ExitingReaderException; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.PrefixQuery; @@ -466,7 +467,7 @@ public void testFloatVectorValues() throws IOException { "vector", TestVectorUtil.randomVector(dimension), 5, - leaf.getLiveDocs(), + AcceptDocs.fromLiveDocs(leaf.getLiveDocs(), leaf.maxDoc()), Integer.MAX_VALUE)); } else { KnnVectorValues values = leaf.getFloatVectorValues("vector"); @@ -476,7 +477,7 @@ public void testFloatVectorValues() throws IOException { "vector", TestVectorUtil.randomVector(dimension), 5, - leaf.getLiveDocs(), + AcceptDocs.fromLiveDocs(leaf.getLiveDocs(), leaf.maxDoc()), Integer.MAX_VALUE); } @@ -541,7 +542,7 @@ public void testByteVectorValues() throws IOException { "vector", TestVectorUtil.randomVectorBytes(dimension), 5, - leaf.getLiveDocs(), + AcceptDocs.fromLiveDocs(leaf.getLiveDocs(), leaf.maxDoc()), Integer.MAX_VALUE)); } else { @@ -552,7 +553,7 @@ public void testByteVectorValues() throws IOException { "vector", TestVectorUtil.randomVectorBytes(dimension), 5, - leaf.getLiveDocs(), + AcceptDocs.fromLiveDocs(leaf.getLiveDocs(), leaf.maxDoc()), Integer.MAX_VALUE); } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java index 2badd9d6334d..8d41236e2e2b 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.List; import org.apache.lucene.document.Document; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.Directory; @@ -118,11 +119,11 @@ public ByteVectorValues getByteVectorValues(String field) { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {} @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {} @Override protected void doClose() {} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestAcceptDocs.java b/lucene/core/src/test/org/apache/lucene/search/TestAcceptDocs.java new file mode 100644 index 000000000000..b8a8e7ba51bf --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestAcceptDocs.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Arrays; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; + +public class TestAcceptDocs extends LuceneTestCase { + + public void testValidation() { + // iterator supplier must be non-null + expectThrows(NullPointerException.class, () -> AcceptDocs.fromIteratorSupplier(null, null, 1)); + + // iterator supplier may not produce null iterators + expectThrows( + NullPointerException.class, + () -> AcceptDocs.fromIteratorSupplier(() -> null, null, 1).iterator()); + + // Bits length != maxDoc + expectThrows( + IllegalArgumentException.class, () -> AcceptDocs.fromLiveDocs(new Bits.MatchNoBits(3), 4)); + } + + public void testIteratorIgnoresDeletedDocs() throws IOException { + int maxDoc = 5; + int deletedDoc = 3; + FixedBitSet liveDocs = new FixedBitSet(maxDoc); + liveDocs.set(0, liveDocs.length()); + liveDocs.clear(deletedDoc); + + Bits liveDocsBits = liveDocs.asReadOnlyBits(); + + AcceptDocs bitsAcceptDocs = AcceptDocs.fromLiveDocs(liveDocsBits, maxDoc); + AcceptDocs iteratorAcceptDocs = + AcceptDocs.fromIteratorSupplier(() -> DocIdSetIterator.all(maxDoc), liveDocsBits, maxDoc); + + for (AcceptDocs acceptDocs : Arrays.asList(bitsAcceptDocs, iteratorAcceptDocs)) { + Bits acceptBits = acceptDocs.bits(); + assertEquals(maxDoc, acceptBits.length()); + for (int i = 0; i < maxDoc; ++i) { + assertEquals(i != deletedDoc, acceptBits.get(i)); + } + + DocIdSetIterator iterator = acceptDocs.iterator(); + for (int i = 0; i < maxDoc; ++i) { + if (i != deletedDoc) { + assertEquals(i, iterator.nextDoc()); + } + } + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } + } + + public void testIteratorIsNew() throws IOException { + int maxDoc = 5; + AcceptDocs bitsAcceptDocs = AcceptDocs.fromLiveDocs(null, maxDoc); + AcceptDocs iteratorAcceptDocs = + AcceptDocs.fromIteratorSupplier(() -> DocIdSetIterator.all(maxDoc), null, maxDoc); + + for (AcceptDocs acceptDocs : Arrays.asList(bitsAcceptDocs, iteratorAcceptDocs)) { + DocIdSetIterator iterator = acceptDocs.iterator(); + assertEquals(-1, iterator.docID()); + iterator.nextDoc(); + iterator = acceptDocs.iterator(); + assertEquals(-1, iterator.docID()); + + // Triggers lazy loading of matches into a bit set when created from an iterator + acceptDocs.bits(); + + iterator = acceptDocs.iterator(); + assertEquals(-1, iterator.docID()); + iterator.nextDoc(); + iterator = acceptDocs.iterator(); + assertEquals(-1, iterator.docID()); + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index 8d15c9329e1d..f89321ef5b51 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -28,7 +28,6 @@ import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.store.Directory; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.TestVectorUtil; public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { @@ -149,7 +148,7 @@ public CappedResultsThrowingKnnVectorQuery( @Override protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index 48ae27ce0b0e..ad1e73067f6f 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -42,7 +42,6 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.TestVectorUtil; import org.apache.lucene.util.VectorUtil; @@ -307,7 +306,7 @@ public CappedResultsThrowingKnnVectorQuery( @Override protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java index 169f09db2f64..80c95b37972b 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java @@ -35,7 +35,6 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.TestUtil; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.TestVectorUtil; public class TestSeededKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { @@ -256,7 +255,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { @Override protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java index cb8c71a089f0..906001c775fb 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java @@ -42,6 +42,7 @@ import org.apache.lucene.index.Terms; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.Version; @@ -181,11 +182,11 @@ public ByteVectorValues getByteVectorValues(String fieldName) { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {} @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {} @Override public void checkIntegrity() throws IOException {} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index f3e1c518bdd7..b77cf14bb739 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -24,6 +24,7 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.IndexSearcher; @@ -38,7 +39,6 @@ import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.Bits; /** * kNN byte vector query that joins matching children vector documents with their parent doc id. The @@ -162,7 +162,7 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search @Override protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index d98fc62fdbe9..22f5333102e9 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -24,6 +24,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.IndexSearcher; @@ -38,7 +39,6 @@ import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.KnnSearchStrategy; import org.apache.lucene.util.BitSet; -import org.apache.lucene.util.Bits; /** * kNN float vector query that joins matching children vector documents with their parent doc id. @@ -161,7 +161,7 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search @Override protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index b6b1db2637a5..145a6c8ce6f0 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -69,6 +69,7 @@ import org.apache.lucene.index.TermVectors; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.DocIdSetIterator; @@ -1757,11 +1758,11 @@ public ByteVectorValues getByteVectorValues(String fieldName) { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {} @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {} @Override public void checkIntegrity() throws IOException { diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java index c596bbc61b19..02275e5b5963 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java @@ -38,12 +38,12 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataAccessHint; import org.apache.lucene.store.FileTypeHint; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; /** @@ -165,7 +165,8 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public void search(String field, float[] vector, KnnCollector knnCollector, Bits acceptDocs) { + public void search( + String field, float[] vector, KnnCollector knnCollector, AcceptDocs acceptDocs) { FaissLibrary.Index index = indexMap.get(field); if (index != null) { index.search(vector, knnCollector, acceptDocs); @@ -173,7 +174,8 @@ public void search(String field, float[] vector, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] vector, KnnCollector knnCollector, Bits acceptDocs) { + public void search( + String field, byte[] vector, KnnCollector knnCollector, AcceptDocs acceptDocs) { // TODO: Support using SQ8 quantization, see: // - https://github.com/opensearch-project/k-NN/pull/2425 throw new UnsupportedOperationException("Byte vectors not supported"); diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java index e7837692222d..ad0715bcb333 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java @@ -19,10 +19,10 @@ import java.io.Closeable; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.IntToIntFunction; /** @@ -42,7 +42,7 @@ interface FaissLibrary { String VERSION = "1.11.0"; interface Index extends Closeable { - void search(float[] query, KnnCollector knnCollector, Bits acceptDocs); + void search(float[] query, KnnCollector knnCollector, AcceptDocs acceptDocs); void write(IndexOutput output); } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java index d72e65eca860..08ec8264131c 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java @@ -40,6 +40,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; @@ -313,10 +314,10 @@ public void close() { } @Override - public void search(float[] query, KnnCollector knnCollector, Bits acceptDocs) { + public void search(float[] query, KnnCollector knnCollector, AcceptDocs acceptDocs) { try (Arena temp = Arena.ofConfined()) { FixedBitSet fixedBitSet = - switch (acceptDocs) { + switch (acceptDocs.bits()) { case null -> null; case FixedBitSet bitSet -> bitSet; // TODO: Add optimized case for SparseFixedBitSet @@ -384,6 +385,8 @@ public void search(float[] query, KnnCollector knnCollector, Bits acceptDocs) { float distance = distancesPointer.getAtIndex(JAVA_FLOAT, i); knnCollector.collect(id, scaler.scale(distance)); } + } catch (IOException e) { + throw new RuntimeException(e); } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java index c588713869c7..1be5fb3cdb58 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java @@ -34,9 +34,10 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.tests.search.AssertingAcceptDocs; import org.apache.lucene.tests.util.TestUtil; -import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.HnswGraph; /** Wraps the default KnnVectorsFormat and provides additional assertions. */ @@ -157,22 +158,26 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { FieldInfo fi = fis.fieldInfo(field); assert fi != null && fi.getVectorDimension() > 0 && fi.getVectorEncoding() == VectorEncoding.FLOAT32; + acceptDocs = AssertingAcceptDocs.wrap(acceptDocs); delegate.search(field, target, knnCollector, acceptDocs); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { FieldInfo fi = fis.fieldInfo(field); assert fi != null && fi.getVectorDimension() > 0 && fi.getVectorEncoding() == VectorEncoding.BYTE; + acceptDocs = AssertingAcceptDocs.wrap(acceptDocs); delegate.search(field, target, knnCollector, acceptDocs); } @@ -188,13 +193,13 @@ public KnnVectorsReader getMergeInstance() throws IOException { @Override public void search( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) { + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) { assert false : "This instance should only be used for merging"; } @Override public void search( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) { + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) { assert false : "This instance should only be used for merging"; } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java index 256e13c514e8..6715c157a4d5 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java @@ -50,8 +50,11 @@ import org.apache.lucene.index.TermsEnum; import org.apache.lucene.internal.tests.IndexPackageAccess; import org.apache.lucene.internal.tests.TestSecrets; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.DocAndFloatFeatureBuffer; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.tests.search.AssertingAcceptDocs; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; @@ -1843,6 +1846,22 @@ public Bits getLiveDocs() { return liveDocs; } + @Override + public void searchNearestVectors( + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { + acceptDocs = AssertingAcceptDocs.wrap(acceptDocs); + super.searchNearestVectors(field, target, knnCollector, acceptDocs); + } + + @Override + public void searchNearestVectors( + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { + acceptDocs = AssertingAcceptDocs.wrap(acceptDocs); + super.searchNearestVectors(field, target, knnCollector, acceptDocs); + } + // we don't change behavior of the reader: just validate the API. @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 17dcc1ef705b..eaff7a0ca9d5 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -79,6 +79,7 @@ import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -832,7 +833,11 @@ public void testDeleteAllVectorDocs() throws Exception { // assert that knn search doesn't fail on a field with all deleted docs TopDocs results = leafReader.searchNearestVectors( - "v", randomNormalizedVector(4), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE); + "v", + randomNormalizedVector(4), + 1, + AcceptDocs.fromLiveDocs(leafReader.getLiveDocs(), leafReader.maxDoc()), + Integer.MAX_VALUE); assertEquals(0, results.scoreDocs.length); } } @@ -1491,7 +1496,11 @@ public void testSearchWithVisitedLimit() throws Exception { TopDocs results = ctx.reader() .searchNearestVectors( - fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit); + fieldName, + randomNormalizedVector(dimension), + k, + AcceptDocs.fromLiveDocs(liveDocs, ctx.reader().maxDoc()), + visitedLimit); assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation()); int size = Lucene99HnswVectorsReader.EXHAUSTIVE_BULK_SCORE_ORDS; assertTrue( @@ -1505,7 +1514,11 @@ public void testSearchWithVisitedLimit() throws Exception { results = ctx.reader() .searchNearestVectors( - fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit); + fieldName, + randomNormalizedVector(dimension), + k, + AcceptDocs.fromLiveDocs(liveDocs, ctx.reader().maxDoc()), + visitedLimit); assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation()); assertTrue(results.totalHits.value() <= visitedLimit); assertOffHeapByteSize(ctx.reader(), fieldName); @@ -1587,7 +1600,11 @@ public void testRandomWithUpdatesAndGraph() throws Exception { TopDocs results = ctx.reader() .searchNearestVectors( - fieldName, randomNormalizedVector(dimension), k, liveDocs, Integer.MAX_VALUE); + fieldName, + randomNormalizedVector(dimension), + k, + AcceptDocs.fromLiveDocs(liveDocs, ctx.reader().maxDoc()), + Integer.MAX_VALUE); assertEquals(Math.min(k, size), results.scoreDocs.length); for (int i = 0; i < k - 1; i++) { assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score); @@ -2213,7 +2230,7 @@ private static Query buildExactKnnQuery(String fieldName, float[] queryVector, i @Override protected TopDocs approximateSearch( LeafReaderContext context, - Bits acceptDocs, + AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java index 3fee110f7836..385ad78d3f19 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java @@ -41,6 +41,7 @@ import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.TermVectors; import org.apache.lucene.index.Terms; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -241,13 +242,15 @@ public ByteVectorValues getByteVectorValues(String fieldName) throws IOException @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { in.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) + throws IOException { in.searchNearestVectors(field, target, knnCollector, acceptDocs); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingAcceptDocs.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingAcceptDocs.java new file mode 100644 index 000000000000..2a32b50b163a --- /dev/null +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingAcceptDocs.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.tests.search; + +import java.io.IOException; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.Bits; + +/** Wraps {@link AcceptDocs} with assertions. */ +public final class AssertingAcceptDocs extends AcceptDocs { + + /** Wrap the given {@link AcceptDocs} with assertions. */ + public static AcceptDocs wrap(AcceptDocs acceptDocs) { + if (acceptDocs instanceof AssertingAcceptDocs assertingAcceptDocs) { + return assertingAcceptDocs; + } else { + return new AssertingAcceptDocs(acceptDocs); + } + } + + private final AcceptDocs acceptDocs; + private final Thread creationThread = Thread.currentThread(); + + private AssertingAcceptDocs(AcceptDocs acceptDocs) { + this.acceptDocs = acceptDocs; + } + + @Override + public Bits bits() throws IOException { + assert Thread.currentThread() == creationThread + : "Usage of AcceptDocs should be confined to a single thread"; + return acceptDocs.bits(); + } + + @Override + public DocIdSetIterator iterator() throws IOException { + assert Thread.currentThread() == creationThread + : "Usage of AcceptDocs should be confined to a single thread"; + DocIdSetIterator iterator = acceptDocs.iterator(); + assert iterator.docID() == -1 : "Iterator must be unpositioned"; + return iterator; + } + + @Override + public int cost() throws IOException { + assert Thread.currentThread() == creationThread + : "Usage of AcceptDocs should be confined to a single thread"; + return acceptDocs.cost(); + } +} diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java index efd13121d930..c0102570c73d 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java @@ -42,6 +42,7 @@ import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.TermVectors; import org.apache.lucene.index.Terms; +import org.apache.lucene.search.AcceptDocs; import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdStream; @@ -245,11 +246,11 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {} @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {} @Override public FieldInfos getFieldInfos() {