docs/changelog/133446.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
+pr: 133446
+summary: Do not share Weight between Drivers
+area: ES|QL
+type: bug
+issues: []
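
Background for this change: a Lucene Query is immutable and safe to share across threads, but the Weight built from it caches lazily initialized scoring state, so two drivers scoring through the same Weight instance on different threads can race. The sketch below shows the safe pattern this fix moves toward; it is illustrative only, and the PerDriverWeights class and its method names are invented, not code from this PR.

// Illustrative sketch, assuming only stock Lucene APIs: share the Query,
// build a private Weight per driver on the driver's own thread.
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;

final class PerDriverWeights {
    private final Query query; // immutable, safe to share across drivers

    PerDriverWeights(Query query) {
        this.query = query;
    }

    // Called once per driver, on that driver's thread, instead of reusing a
    // Weight created elsewhere; the Weight's lazily built scorer state then
    // stays confined to the driver that owns it.
    Weight createForDriver(IndexSearcher searcher) throws IOException {
        Query rewritten = searcher.rewrite(query);
        return searcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
    }
}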
@@ -162,36 +162,56 @@ public final void close() {
     protected void additionalClose() { /* Override this method to add any additional cleanup logic if needed */ }

     LuceneScorer getCurrentOrLoadNextScorer() {
-        while (currentScorer == null || currentScorer.isDone()) {
-            if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
-                sliceIndex = 0;
-                currentSlice = sliceQueue.nextSlice(currentSlice);
-                if (currentSlice == null) {
-                    doneCollecting = true;
+        while (true) {
+            while (currentScorer == null || currentScorer.isDone()) {
+                var partialLeaf = nextPartialLeaf();
+                if (partialLeaf == null) {
+                    assert doneCollecting;
                     return null;
                 }
-                processedSlices++;
-                processedShards.add(currentSlice.shardContext().shardIdentifier());
+                logger.trace("Starting {}", partialLeaf);
+                loadScorerForNewPartialLeaf(partialLeaf);
             }
-            final PartialLeafReaderContext partialLeaf = currentSlice.getLeaf(sliceIndex++);
-            logger.trace("Starting {}", partialLeaf);
-            final LeafReaderContext leaf = partialLeaf.leafReaderContext();
-            if (currentScorer == null // First time
-                || currentScorer.leafReaderContext() != leaf // Moved to a new leaf
-                || currentScorer.weight != currentSlice.weight() // Moved to a new query
-            ) {
-                final Weight weight = currentSlice.weight();
-                processedQueries.add(weight.getQuery());
-                currentScorer = new LuceneScorer(currentSlice.shardContext(), weight, currentSlice.tags(), leaf);
+            // Has the executing thread changed? If so, we need to reinitialize the scorer. The reinitialized bulkScorer
+            // can be null even if it was non-null previously, due to lazy initialization in Weight#bulkScorer.
+            // Hence, we need to check the previous condition again.
+            if (currentScorer.executingThread == Thread.currentThread()) {
+                return currentScorer;
+            } else {
+                currentScorer.reinitialize();
             }
-            assert currentScorer.maxPosition <= partialLeaf.maxDoc() : currentScorer.maxPosition + ">" + partialLeaf.maxDoc();
-            currentScorer.maxPosition = partialLeaf.maxDoc();
-            currentScorer.position = Math.max(currentScorer.position, partialLeaf.minDoc());
         }
-        if (Thread.currentThread() != currentScorer.executingThread) {
-            currentScorer.reinitialize();
-        }
+    }

+    private PartialLeafReaderContext nextPartialLeaf() {
+        if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
+            sliceIndex = 0;
+            currentSlice = sliceQueue.nextSlice(currentSlice);
+            if (currentSlice == null) {
+                doneCollecting = true;
+                return null;
+            }
+            processedSlices++;
+            processedShards.add(currentSlice.shardContext().shardIdentifier());
+        }
-        return currentScorer;
+        return currentSlice.getLeaf(sliceIndex++);
     }

+    private void loadScorerForNewPartialLeaf(PartialLeafReaderContext partialLeaf) {
+        final LeafReaderContext leaf = partialLeaf.leafReaderContext();
+        // the current Weight can be reused with the current slice
+        if (currentScorer != null && currentSlice.isWeightCompatible(currentScorer.weight)) {
+            if (currentScorer.leafReaderContext != leaf) {
+                currentScorer = new LuceneScorer(currentSlice.shardContext(), currentScorer.weight, currentSlice.tags(), leaf);
+            }
+        } else {
+            final var weight = currentSlice.createWeight();
+            processedQueries.add(weight.getQuery());
+            currentScorer = new LuceneScorer(currentSlice.shardContext(), weight, currentSlice.tags(), leaf);
+        }
+        assert currentScorer.maxPosition <= partialLeaf.maxDoc() : currentScorer.maxPosition + ">" + partialLeaf.maxDoc();
+        currentScorer.maxPosition = partialLeaf.maxDoc();
+        currentScorer.position = Math.max(currentScorer.position, partialLeaf.minDoc());
+    }

     /**
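
The retry loop above is the subtle part of this hunk: reinitialize() rebuilds the scorer for the new thread, but Weight#bulkScorer is lazy and can come back null, which makes the scorer report isDone() again, so control must fall back into the inner loop rather than return. A reduced control-flow sketch follows; LoopScorer and ScorerLoop are hypothetical stand-ins, not the ES|QL classes.

// Hypothetical sketch: loop until we hold a scorer that is both usable and
// owned by the current thread; a rebuild can invalidate it, so re-check.
import java.util.Iterator;

final class ScorerLoop {
    interface LoopScorer {
        boolean isDone();

        Thread executingThread();

        // May leave the scorer "done" again, because the underlying bulk
        // scorer is created lazily and can be null after the rebuild.
        void reinitialize();
    }

    private LoopScorer current;

    LoopScorer nextUsableScorer(Iterator<LoopScorer> source) {
        while (true) {
            while (current == null || current.isDone()) {
                if (source.hasNext() == false) {
                    return null; // nothing left to score
                }
                current = source.next();
            }
            if (current.executingThread() == Thread.currentThread()) {
                return current;
            }
            current.reinitialize(); // thread changed: rebuild, then re-check
        }
    }
}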
@@ -7,8 +7,13 @@

 package org.elasticsearch.compute.lucene;

+import org.apache.lucene.search.FilterWeight;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Weight;

+import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.List;

 /**
@@ -19,14 +24,39 @@ public record LuceneSlice(
     boolean queryHead,
     ShardContext shardContext,
     List<PartialLeafReaderContext> leaves,
-    Weight weight,
+    Query query,
+    ScoreMode scoreMode,
     List<Object> tags
 ) {

     int numLeaves() {
         return leaves.size();
     }

     PartialLeafReaderContext getLeaf(int index) {
         return leaves.get(index);
     }

+    Weight createWeight() {
+        var searcher = shardContext.searcher();
+        try {
+            Weight w = searcher.createWeight(query, scoreMode, 1);
+            return new OwningWeight(query, w);
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }

+    private static class OwningWeight extends FilterWeight {
Review comment (Member):
Could we put the Query into LuceneScorer instead of into the Weight? There's already a bunch of useful stuff. We could do:

+        private final Query query;
         private final Weight weight;

and check there.

Reply (Member, author):
Sure, I've updated the LuceneScorer to keep both the query and tags. (A hypothetical sketch of this shape follows this file's diff, below.)

+        final Query originalQuery;
+
+        protected OwningWeight(Query originalQuery, Weight weight) {
+            super(weight);
+            this.originalQuery = originalQuery;
+        }
+    }
+
+    boolean isWeightCompatible(Weight weight) {
+        return weight instanceof OwningWeight ow && ow.originalQuery == query;
+    }
 }
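
Two notes on the new LuceneSlice API. First, isWeightCompatible deliberately uses reference equality (ow.originalQuery == query): slices created from the same query object carry that exact Query instance, so identity is a cheap and sufficient check. Second, the reviewer's suggestion above would move that check onto the scorer itself; below is a hypothetical sketch of that shape, where QueryAwareScorer and its members are invented for illustration and are not this PR's code.

// Hypothetical sketch of the suggested alternative: keep the Query on the
// scorer and compare it there, instead of wrapping the Weight.
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Weight;

final class QueryAwareScorer {
    private final Query query;   // the shared, immutable query this scorer was built from
    private final Weight weight; // this driver's private Weight

    QueryAwareScorer(Query query, Weight weight) {
        this.query = query;
        this.weight = weight;
    }

    Weight weight() {
        return weight;
    }

    // Reference equality suffices: slices built from the same query share
    // the exact same Query instance.
    boolean builtFrom(Query sliceQuery) {
        return this.query == sliceQuery;
    }
}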
@@ -12,7 +12,6 @@
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
-import org.apache.lucene.search.Weight;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -221,12 +220,11 @@ public static LuceneSliceQueue create(
             PartitioningStrategy partitioning = PartitioningStrategy.pick(dataPartitioning, autoStrategy, ctx, query);
             partitioningStrategies.put(ctx.shardIdentifier(), partitioning);
             List<List<PartialLeafReaderContext>> groups = partitioning.groups(ctx.searcher(), taskConcurrency);
-            Weight weight = weight(ctx, query, scoreMode);
             boolean queryHead = true;
             for (List<PartialLeafReaderContext> group : groups) {
                 if (group.isEmpty() == false) {
                     final int slicePosition = nextSliceId++;
-                    slices.add(new LuceneSlice(slicePosition, queryHead, ctx, group, weight, queryAndExtra.tags));
+                    slices.add(new LuceneSlice(slicePosition, queryHead, ctx, group, query, scoreMode, queryAndExtra.tags));
                     queryHead = false;
                 }
             }
@@ -328,16 +326,6 @@ private static PartitioningStrategy forAuto(Function<Query, PartitioningStrategy
         }
     }

-    static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
-        var searcher = ctx.searcher();
-        try {
-            Query actualQuery = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
-            return searcher.createWeight(actualQuery, scoreMode, 1);
-        } catch (IOException e) {
-            throw new UncheckedIOException(e);
-        }
-    }
-
     static final class AdaptivePartitioner {
         final int desiredDocsPerSlice;
         final int maxDocsPerSlice;
@@ -154,7 +154,7 @@ protected boolean lessThan(LeafIterator a, LeafIterator b) {
                 return a.timeSeriesHash.compareTo(b.timeSeriesHash) < 0;
             }
         };
-        Weight weight = luceneSlice.weight();
+        Weight weight = luceneSlice.createWeight();
         processedQueries.add(weight.getQuery());
         int maxSegmentOrd = 0;
         for (var leafReaderContext : luceneSlice.leaves()) {
@@ -24,6 +24,7 @@
 import org.apache.lucene.index.TermVectors;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.util.Bits;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.test.ESTestCase;
@@ -50,27 +51,28 @@ public void testBasics() {
         LeafReaderContext leaf2 = new MockLeafReader(1000).getContext();
         LeafReaderContext leaf3 = new MockLeafReader(1000).getContext();
         LeafReaderContext leaf4 = new MockLeafReader(1000).getContext();
-        List<Object> query1 = List.of("1");
-        List<Object> query2 = List.of("q2");
+        List<Object> t1 = List.of("1");
+        List<Object> t2 = List.of("q2");
+        var scoreMode = ScoreMode.COMPLETE_NO_SCORES;
         List<LuceneSlice> sliceList = List.of(
             // query1: new segment
-            new LuceneSlice(0, true, null, List.of(new PartialLeafReaderContext(leaf1, 0, 10)), null, query1),
-            new LuceneSlice(1, false, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, query1),
-            new LuceneSlice(2, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, query1),
+            new LuceneSlice(0, true, null, List.of(new PartialLeafReaderContext(leaf1, 0, 10)), null, scoreMode, t1),
+            new LuceneSlice(1, false, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, scoreMode, t1),
+            new LuceneSlice(2, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, scoreMode, t1),
             // query1: new segment
-            new LuceneSlice(3, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, query1),
-            new LuceneSlice(4, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, query1),
-            new LuceneSlice(5, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, query1),
+            new LuceneSlice(3, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, scoreMode, t1),
+            new LuceneSlice(4, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, scoreMode, t1),
+            new LuceneSlice(5, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, scoreMode, t1),
             // query1: new segment
-            new LuceneSlice(6, false, null, List.of(new PartialLeafReaderContext(leaf4, 0, 10)), null, query1),
-            new LuceneSlice(7, false, null, List.of(new PartialLeafReaderContext(leaf4, 10, 20)), null, query1),
+            new LuceneSlice(6, false, null, List.of(new PartialLeafReaderContext(leaf4, 0, 10)), null, scoreMode, t1),
+            new LuceneSlice(7, false, null, List.of(new PartialLeafReaderContext(leaf4, 10, 20)), null, scoreMode, t1),
             // query2: new segment
-            new LuceneSlice(8, true, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, query2),
-            new LuceneSlice(9, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, query2),
+            new LuceneSlice(8, true, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, scoreMode, t2),
+            new LuceneSlice(9, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, scoreMode, t2),
             // query1: new segment
-            new LuceneSlice(10, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, query2),
-            new LuceneSlice(11, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, query2),
-            new LuceneSlice(12, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, query2)
+            new LuceneSlice(10, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, scoreMode, t2),
+            new LuceneSlice(11, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, scoreMode, t2),
+            new LuceneSlice(12, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, scoreMode, t2)
         );
         // single driver
         {
@@ -140,6 +142,7 @@ public void testRandom() throws Exception {
                 mock(ShardContext.class),
                 List.of(new PartialLeafReaderContext(leafContext, minDoc, maxDoc)),
                 null,
+                ScoreMode.COMPLETE_NO_SCORES,
                 null
             );
             sliceList.add(slice);