diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index e9246a8b5756..7a51ad3b65b4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -16,6 +16,8 @@ */ package org.apache.lucene.search; +import static org.apache.lucene.search.AnnQueryUtils.createBitSet; +import static org.apache.lucene.search.AnnQueryUtils.createFilterWeight; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; @@ -53,8 +55,13 @@ abstract class AbstractKnnVectorQuery extends Query { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; + /** the KNN vector field to search */ protected final String field; + + /** the number of documents to find */ protected final int k; + + /** the filter to be executed. when the filter is applied is up to the underlying knn index */ protected final Query filter; public AbstractKnnVectorQuery(String field, int k, Query filter) { @@ -68,20 +75,12 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) { @Override public Query rewrite(IndexSearcher indexSearcher) throws IOException { + // we need to perform search inside rewrite() because we need to get top-k + // matches across all segments + IndexReader reader = indexSearcher.getIndexReader(); - final Weight filterWeight; - if (filter != null) { - BooleanQuery booleanQuery = - new BooleanQuery.Builder() - .add(filter, BooleanClause.Occur.FILTER) - .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER) - .build(); - Query rewritten = indexSearcher.rewrite(booleanQuery); - filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f); - } else { - filterWeight = null; - } + final Weight filterWeight = createFilterWeight(indexSearcher, filter, field); TimeLimitingKnnCollectorManager 
knnCollectorManager = new TimeLimitingKnnCollectorManager( @@ -116,6 +115,7 @@ private TopDocs searchLeaf( return results; } + // Perform kNN search for the provided LeafReaderContext applying filterWeight as necessary private TopDocs getLeafResults( LeafReaderContext ctx, Weight filterWeight, @@ -156,24 +156,6 @@ private TopDocs getLeafResults( } } - private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) - throws IOException { - if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) { - // If we already have a BitSet and no deletions, reuse the BitSet - return bitSetIterator.getBitSet(); - } else { - // Create a new BitSet from matching and live docs - FilteredDocIdSetIterator filterIterator = - new FilteredDocIdSetIterator(iterator) { - @Override - protected boolean match(int doc) { - return liveDocs == null || liveDocs.get(doc); - } - }; - return BitSet.of(filterIterator, maxDoc); - } - } - protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { return new TopKnnCollectorManager(k, searcher); } @@ -188,6 +170,8 @@ protected abstract TopDocs approximateSearch( abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException; + // Perform a brute-force search by computing the vector score for each accepted doc and try to + // take the top k docs. // We allow this to be overridden so that tests can check what search strategy is used protected TopDocs exactSearch( LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) @@ -255,6 +239,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { return TopDocs.merge(k, perLeafResults); } + // At this point we already collected top k matching docs, thus we only wrap the cached docs with + // their scores here. 
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) { int len = topK.scoreDocs.length; @@ -272,6 +258,8 @@ private Query createRewrittenQuery(IndexReader reader, TopDocs topK) { return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id()); } + // For each segment, find the first index in docs belong to that segment. + // This method essentially partitions docs by segments static int[] findSegmentStarts(List leaves, int[] docs) { int[] starts = new int[leaves.size() + 1]; starts[starts.length - 1] = docs.length; diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java index dde9ce76ac3d..71379268a47a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java @@ -16,6 +16,9 @@ */ package org.apache.lucene.search; +import static org.apache.lucene.search.AnnQueryUtils.createBitSet; +import static org.apache.lucene.search.AnnQueryUtils.createFilterWeight; + import java.io.IOException; import java.util.Arrays; import java.util.Comparator; @@ -78,10 +81,7 @@ protected abstract TopDocs approximateSearch( public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { return new Weight(this) { - final Weight filterWeight = - filter == null - ? 
null - : searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1); + final Weight filterWeight = createFilterWeight(searcher, filter, field); final QueryTimeout queryTimeout = searcher.getTimeout(); final TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager = @@ -133,21 +133,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti return null; } - BitSet acceptDocs; - if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) { - // If there are no deletions, and matching docs are already cached - acceptDocs = bitSetIterator.getBitSet(); - } else { - // Else collect all matching docs - FilteredDocIdSetIterator filtered = - new FilteredDocIdSetIterator(scorer.iterator()) { - @Override - protected boolean match(int doc) { - return liveDocs == null || liveDocs.get(doc); - } - }; - acceptDocs = BitSet.of(filtered, leafReader.maxDoc()); - } + BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, leafReader.maxDoc()); int cardinality = acceptDocs.cardinality(); if (cardinality == 0) { diff --git a/lucene/core/src/java/org/apache/lucene/search/AnnQueryUtils.java b/lucene/core/src/java/org/apache/lucene/search/AnnQueryUtils.java new file mode 100644 index 000000000000..6ab4bb1d127c --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/AnnQueryUtils.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.Bits; + +/** Common utilities for ANN queries. */ +final class AnnQueryUtils { + + /** private constructor */ + private AnnQueryUtils() {} + + /** + * Create a bit set for a set of matching docs which are also not deleted. + * + *

If there is no deleted doc, it will use the matching docs bit set. Otherwise, it will return + * the bit set from matching docs which are also not deleted. + * + * @param iterator the matching doc iterator + * @param liveDocs the segment's live (non-deleted) docs, or null if the segment has no deletions + * @param maxDoc one greater than the largest doc id in the segment, used to size the bit set + * @return a bit set over the matching docs + */ + static BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) + throws IOException { + if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) { + // If we already have a BitSet and no deletions, reuse the BitSet + return bitSetIterator.getBitSet(); + } else { + // Create a new BitSet from matching and live docs + FilteredDocIdSetIterator filterIterator = + new FilteredDocIdSetIterator(iterator) { + @Override + protected boolean match(int doc) { + return liveDocs == null || liveDocs.get(doc); + } + }; + return BitSet.of(filterIterator, maxDoc); + } + } + + /** + * Create a Weight for the filtered query. The filter will also be enhanced to only match + * documents with a value in the vector field. + * + * @param indexSearcher the index searcher to rewrite and create weight + * @param filter the filter query, or null if there is no filter + * @param field the KNN vector field to check + * @return a Weight for the filter query, or null if filter is null + */ + static Weight createFilterWeight(IndexSearcher indexSearcher, Query filter, String field) + throws IOException { + if (filter == null) { + return null; + } + BooleanQuery booleanQuery = + new BooleanQuery.Builder() + .add(filter, BooleanClause.Occur.FILTER) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER) + .build(); + Query rewritten = indexSearcher.rewrite(booleanQuery); + return indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f); + } +}