Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,32 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
final Weight filterWeight;
if (filter != null) {
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
.build();
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
.build();
Query rewritten = indexSearcher.rewrite(booleanQuery);
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
} else {
filterWeight = null;
}

List<LeafReaderContext> leafReaderContexts = reader.leaves();

// Use deterministic search strategy for multi-segment cases without filters
if (leafReaderContexts.size() > 1 && filter == null) {
TopDocs topK = deterministicSearch(leafReaderContexts, filterWeight, indexSearcher);
if (topK.scoreDocs.length == 0) {
return new MatchNoDocsQuery();
}
return createRewrittenQuery(reader, topK);
}

// Fall back to standard search for single segment or filtered cases
TimeLimitingKnnCollectorManager knnCollectorManager =
new TimeLimitingKnnCollectorManager(
getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout());
new TimeLimitingKnnCollectorManager(
getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout());
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
Expand All @@ -102,11 +113,40 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return createRewrittenQuery(reader, topK);
}

/**
* Implements deterministic KNN search strategy. Each segment searches independently for full k
* without shared state, ensuring deterministic results.
*/
private TopDocs deterministicSearch(
List<LeafReaderContext> leafReaderContexts, Weight filterWeight, IndexSearcher indexSearcher)
throws IOException {
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());

// Each segment searches for full k with its own collector (no shared state)
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(
() -> {
// Each segment gets its own TopKnnCollectorManager (no global queue)
KnnCollectorManager ownCollectorManager = new TopKnnCollectorManager(k, indexSearcher);
TimeLimitingKnnCollectorManager timeLimitingManager =
new TimeLimitingKnnCollectorManager(ownCollectorManager, indexSearcher.getTimeout());
return searchLeaf(context, filterWeight, timeLimitingManager);
});
}

// Execute all searches in parallel (no shared state = no locks needed)
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);

// Deterministic merge of all results
return TopDocs.merge(k, perLeafResults);
}

private TopDocs searchLeaf(
LeafReaderContext ctx,
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
LeafReaderContext ctx,
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
Expand All @@ -117,10 +157,10 @@ private TopDocs searchLeaf(
}

private TopDocs getLeafResults(
LeafReaderContext ctx,
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
LeafReaderContext ctx,
Weight filterWeight,
TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager)
throws IOException {
final LeafReader reader = ctx.reader();
final Bits liveDocs = reader.getLiveDocs();

Expand All @@ -147,8 +187,8 @@ private TopDocs getLeafResults(
// We pass cost + 1 here to account for the edge case when we explore exactly cost vectors
TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO
// Return partial results only when timeout is met
|| (queryTimeout != null && queryTimeout.shouldExit())) {
// Return partial results only when timeout is met
|| (queryTimeout != null && queryTimeout.shouldExit())) {
return results;
} else {
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
Expand All @@ -157,19 +197,19 @@ private TopDocs getLeafResults(
}

private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
throws IOException {
throws IOException {
if (liveDocs == null && iterator instanceof BitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return ((BitSetIterator) iterator).getBitSet();
} else {
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator =
new FilteredDocIdSetIterator(iterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
new FilteredDocIdSetIterator(iterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
return BitSet.of(filterIterator, maxDoc);
}
}
Expand All @@ -179,19 +219,19 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search
}

protected abstract TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException;
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException;

abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi)
throws IOException;
throws IOException;

// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
throws IOException {
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
throws IOException {
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
Expand All @@ -208,7 +248,7 @@ protected TopDocs exactSearch(
ScoreDoc topDoc = queue.top();
DocIdSetIterator vectorIterator = vectorScorer.iterator();
DocIdSetIterator conjunction =
ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of());
ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of());
int doc;
while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
// Mark results as partial if timeout is met
Expand Down Expand Up @@ -351,7 +391,7 @@ static class DocAndScoreQuery extends Query {
* query
*/
DocAndScoreQuery(
int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.k = k;
this.docs = docs;
this.scores = scores;
Expand All @@ -361,7 +401,7 @@ static class DocAndScoreQuery extends Query {

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
throws IOException {
if (searcher.getIndexReader().getContext().id() != contextIdentity) {
throw new IllegalStateException("This DocAndScore query was created by a different reader");
}
Expand Down Expand Up @@ -494,14 +534,14 @@ public boolean equals(Object obj) {
return false;
}
return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity
&& Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
&& Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
&& Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
&& Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
}

@Override
public int hashCode() {
return Objects.hash(
classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
}
}
}
}