diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 7180995d5f13..c179c654ad62 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -73,21 +73,32 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { final Weight filterWeight; if (filter != null) { BooleanQuery booleanQuery = - new BooleanQuery.Builder() - .add(filter, BooleanClause.Occur.FILTER) - .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER) - .build(); + new BooleanQuery.Builder() + .add(filter, BooleanClause.Occur.FILTER) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER) + .build(); Query rewritten = indexSearcher.rewrite(booleanQuery); filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f); } else { filterWeight = null; } + List leafReaderContexts = reader.leaves(); + + // Use deterministic search strategy for multi-segment cases without filters + if (leafReaderContexts.size() > 1 && filter == null) { + TopDocs topK = deterministicSearch(leafReaderContexts, filterWeight, indexSearcher); + if (topK.scoreDocs.length == 0) { + return new MatchNoDocsQuery(); + } + return createRewrittenQuery(reader, topK); + } + + // Fall back to standard search for single segment or filtered cases TimeLimitingKnnCollectorManager knnCollectorManager = - new TimeLimitingKnnCollectorManager( - getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout()); + new TimeLimitingKnnCollectorManager( + getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout()); TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); - List leafReaderContexts = reader.leaves(); List> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext context : leafReaderContexts) { tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager)); @@ -102,11 +113,40 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return createRewrittenQuery(reader, topK); } + /** + * Implements deterministic KNN search strategy. Each segment searches independently for full k + * without shared state, ensuring deterministic results. + */ + private TopDocs deterministicSearch( + List leafReaderContexts, Weight filterWeight, IndexSearcher indexSearcher) + throws IOException { + TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); + List> tasks = new ArrayList<>(leafReaderContexts.size()); + + // Each segment searches for full k with its own collector (no shared state) + for (LeafReaderContext context : leafReaderContexts) { + tasks.add( + () -> { + // Each segment gets its own TopKnnCollectorManager (no global queue) + KnnCollectorManager ownCollectorManager = new TopKnnCollectorManager(k, indexSearcher); + TimeLimitingKnnCollectorManager timeLimitingManager = + new TimeLimitingKnnCollectorManager(ownCollectorManager, indexSearcher.getTimeout()); + return searchLeaf(context, filterWeight, timeLimitingManager); + }); + } + + // Execute all searches in parallel (no shared state = no locks needed) + TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new); + + // Deterministic merge of all results + return TopDocs.merge(k, perLeafResults); + } + private TopDocs searchLeaf( - LeafReaderContext ctx, - Weight filterWeight, - TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) - throws IOException { + LeafReaderContext ctx, + Weight filterWeight, + TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) + throws IOException { TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager); if (ctx.docBase > 0) { for (ScoreDoc scoreDoc : results.scoreDocs) { @@ -117,10 +157,10 @@ private TopDocs searchLeaf( } private TopDocs getLeafResults( - LeafReaderContext ctx, - Weight filterWeight, - TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) - throws IOException { + LeafReaderContext ctx, + Weight filterWeight, + TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) + throws IOException { final LeafReader reader = ctx.reader(); final Bits liveDocs = reader.getLiveDocs(); @@ -147,8 +187,8 @@ private TopDocs getLeafResults( // We pass cost + 1 here to account for the edge case when we explore exactly cost vectors TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager); if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO - // Return partial results only when timeout is met - || (queryTimeout != null && queryTimeout.shouldExit())) { + // Return partial results only when timeout is met + || (queryTimeout != null && queryTimeout.shouldExit())) { return results; } else { // We stopped the kNN search because it visited too many nodes, so fall back to exact search @@ -157,19 +197,19 @@ private TopDocs getLeafResults( } private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) - throws IOException { + throws IOException { if (liveDocs == null && iterator instanceof BitSetIterator) { // If we already have a BitSet and no deletions, reuse the BitSet return ((BitSetIterator) iterator).getBitSet(); } else { // Create a new BitSet from matching and live docs FilteredDocIdSetIterator filterIterator = - new FilteredDocIdSetIterator(iterator) { - @Override - protected boolean match(int doc) { - return liveDocs == null || liveDocs.get(doc); - } - }; + new FilteredDocIdSetIterator(iterator) { + @Override + protected boolean match(int doc) { + return liveDocs == null || liveDocs.get(doc); + } + }; return BitSet.of(filterIterator, maxDoc); } } @@ -179,19 +219,19 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search } protected abstract TopDocs approximateSearch( - LeafReaderContext context, - Bits acceptDocs, - int visitedLimit, - KnnCollectorManager knnCollectorManager) - throws IOException; + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager) + throws IOException; abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) - throws IOException; + throws IOException; // We allow this to be overridden so that tests can check what search strategy is used protected TopDocs exactSearch( - LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) - throws IOException { + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) + throws IOException { FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0) { // The field does not exist or does not index vectors @@ -208,7 +248,7 @@ protected TopDocs exactSearch( ScoreDoc topDoc = queue.top(); DocIdSetIterator vectorIterator = vectorScorer.iterator(); DocIdSetIterator conjunction = - ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of()); + ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of()); int doc; while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { // Mark results as partial if timeout is met @@ -351,7 +391,7 @@ static class DocAndScoreQuery extends Query { * query */ DocAndScoreQuery( - int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { this.k = k; this.docs = docs; this.scores = scores; @@ -361,7 +401,7 @@ static class DocAndScoreQuery extends Query { @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) - throws IOException { + throws IOException { if (searcher.getIndexReader().getContext().id() != contextIdentity) { throw new IllegalStateException("This DocAndScore query was created by a different reader"); } @@ -494,14 +534,14 @@ public boolean equals(Object obj) { return false; } return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity - && Arrays.equals(docs, ((DocAndScoreQuery) obj).docs) - && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores); + && Arrays.equals(docs, ((DocAndScoreQuery) obj).docs) + && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores); } @Override public int hashCode() { return Objects.hash( - classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores)); + classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores)); } } -} +} \ No newline at end of file