docs/changelog/133446.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
+pr: 133446
+summary: Do not share Weight between Drivers
+area: ES|QL
+type: bug
+issues: []
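
Note: the entry above records the gist of the fix. A Lucene Weight can carry lazily initialized, per-search state, so letting several Drivers (each running on its own thread) share one instance invites races. After this change, every Driver builds its own Weight from the slice's Query and ScoreMode. The loop below is a simplified, hypothetical illustration of that lifecycle, not code from this PR; the sliceQueue/slice wiring is assumed from the diffs that follow.

    // Each Driver pulls slices and materializes a fresh, thread-local Weight per slice.
    LuceneSlice slice = sliceQueue.nextSlice(null);
    while (slice != null) {
        Weight weight = slice.createWeight(); // created by and for this Driver's thread
        for (int i = 0; i < slice.numLeaves(); i++) {
            PartialLeafReaderContext leaf = slice.getLeaf(i);
            // ... collect leaf.leafReaderContext() using weight.bulkScorer(...) ...
        }
        slice = sliceQueue.nextSlice(slice);
    }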
@@ -165,40 +165,61 @@ public final void close() {
     protected void additionalClose() { /* Override this method to add any additional cleanup logic if needed */ }

     LuceneScorer getCurrentOrLoadNextScorer() {
-        while (currentScorer == null || currentScorer.isDone()) {
-            if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
-                sliceIndex = 0;
-                currentSlice = sliceQueue.nextSlice(currentSlice);
-                if (currentSlice == null) {
-                    doneCollecting = true;
-                    return null;
-                }
-                processedSlices++;
-                processedShards.add(currentSlice.shardContext().shardIdentifier());
-                int shardId = currentSlice.shardContext().index();
-                if (currentScorerShardRefCounted == null || currentScorerShardRefCounted.index() != shardId) {
-                    currentScorerShardRefCounted = new ShardRefCounted.Single(shardId, shardContextCounters.get(shardId));
-                }
-            }
-            final PartialLeafReaderContext partialLeaf = currentSlice.getLeaf(sliceIndex++);
-            logger.trace("Starting {}", partialLeaf);
-            final LeafReaderContext leaf = partialLeaf.leafReaderContext();
-            if (currentScorer == null // First time
-                || currentScorer.leafReaderContext() != leaf // Moved to a new leaf
-                || currentScorer.weight != currentSlice.weight() // Moved to a new query
-            ) {
-                final Weight weight = currentSlice.weight();
-                processedQueries.add(weight.getQuery());
-                currentScorer = new LuceneScorer(currentSlice.shardContext(), weight, currentSlice.tags(), leaf);
-            }
-            assert currentScorer.maxPosition <= partialLeaf.maxDoc() : currentScorer.maxPosition + ">" + partialLeaf.maxDoc();
-            currentScorer.maxPosition = partialLeaf.maxDoc();
-            currentScorer.position = Math.max(currentScorer.position, partialLeaf.minDoc());
-        }
-        if (Thread.currentThread() != currentScorer.executingThread) {
-            currentScorer.reinitialize();
-        }
-        return currentScorer;
+        while (true) {
+            while (currentScorer == null || currentScorer.isDone()) {
+                var partialLeaf = nextPartialLeaf();
+                if (partialLeaf == null) {
+                    assert doneCollecting;
+                    return null;
+                }
+                logger.trace("Starting {}", partialLeaf);
+                loadScorerForNewPartialLeaf(partialLeaf);
+            }
+            // Has the executing thread changed? If so, we need to reinitialize the scorer. The reinitialized bulkScorer
+            // can be null even if it was non-null previously, due to lazy initialization in Weight#bulkScorer.
+            // Hence, we need to check the previous condition again.
+            if (currentScorer.executingThread == Thread.currentThread()) {
+                return currentScorer;
+            } else {
+                currentScorer.reinitialize();
+            }
+        }
     }

+    private PartialLeafReaderContext nextPartialLeaf() {
+        if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
+            sliceIndex = 0;
+            currentSlice = sliceQueue.nextSlice(currentSlice);
+            if (currentSlice == null) {
+                doneCollecting = true;
+                return null;
+            }
+            processedSlices++;
+            int shardId = currentSlice.shardContext().index();
+            if (currentScorerShardRefCounted == null || currentScorerShardRefCounted.index() != shardId) {
+                currentScorerShardRefCounted = new ShardRefCounted.Single(shardId, shardContextCounters.get(shardId));
+            }
+            processedShards.add(currentSlice.shardContext().shardIdentifier());
+        }
+        return currentSlice.getLeaf(sliceIndex++);
+    }
+
+    private void loadScorerForNewPartialLeaf(PartialLeafReaderContext partialLeaf) {
+        final LeafReaderContext leaf = partialLeaf.leafReaderContext();
+        if (currentScorer != null
+            && currentScorer.query() == currentSlice.query()
+            && currentScorer.shardContext == currentSlice.shardContext()) {
+            if (currentScorer.leafReaderContext != leaf) {
+                currentScorer = new LuceneScorer(currentSlice.shardContext(), currentScorer.weight, currentSlice.queryAndTags(), leaf);
+            }
+        } else {
+            final var weight = currentSlice.createWeight();
+            currentScorer = new LuceneScorer(currentSlice.shardContext(), weight, currentSlice.queryAndTags(), leaf);
+            processedQueries.add(currentScorer.query());
+        }
+        assert currentScorer.maxPosition <= partialLeaf.maxDoc() : currentScorer.maxPosition + ">" + partialLeaf.maxDoc();
+        currentScorer.maxPosition = partialLeaf.maxDoc();
+        currentScorer.position = Math.max(currentScorer.position, partialLeaf.minDoc());
+    }
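
The comment in getCurrentOrLoadNextScorer leans on documented Lucene behavior: Weight#bulkScorer is built on demand and may return null when the query matches nothing in a segment. As a rough sketch (not the method body from this PR), a reinitialize() consistent with that comment would look like:

    private void reinitialize() {
        this.executingThread = Thread.currentThread();
        try {
            // Rebuild the per-leaf BulkScorer on the current thread. The result can
            // legitimately be null, which is why the caller re-checks isDone() after this.
            this.bulkScorer = weight.bulkScorer(leafReaderContext);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }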

@@ -214,18 +235,23 @@ ShardRefCounted currentScorerShardRefCounted() {
     static final class LuceneScorer {
         private final ShardContext shardContext;
         private final Weight weight;
+        private final LuceneSliceQueue.QueryAndTags queryAndTags;
         private final LeafReaderContext leafReaderContext;
-        private final List<Object> tags;

         private BulkScorer bulkScorer;
         private int position;
         private int maxPosition;
         private Thread executingThread;

-        LuceneScorer(ShardContext shardContext, Weight weight, List<Object> tags, LeafReaderContext leafReaderContext) {
+        LuceneScorer(
+            ShardContext shardContext,
+            Weight weight,
+            LuceneSliceQueue.QueryAndTags queryAndTags,
+            LeafReaderContext leafReaderContext
+        ) {
             this.shardContext = shardContext;
             this.weight = weight;
-            this.tags = tags;
+            this.queryAndTags = queryAndTags;
             this.leafReaderContext = leafReaderContext;
             reinitialize();
         }
@@ -275,7 +301,11 @@ int position() {
          * Tags to add to the data returned by this query.
          */
         List<Object> tags() {
-            return tags;
+            return queryAndTags.tags();
         }
+
+        Query query() {
+            return queryAndTags.query();
+        }
     }
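
For reference, the queryAndTags.query() and queryAndTags.tags() accessors used above imply a carrier record shaped roughly like the following; the actual declaration lives in LuceneSliceQueue and may differ in detail:

    // Pairs a rewritten per-shard query with the tag values attached to its output.
    public record QueryAndTags(Query query, List<Object> tags) {}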

@@ -7,8 +7,12 @@

 package org.elasticsearch.compute.lucene;

+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Weight;

+import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.List;

@@ -19,14 +23,32 @@ public record LuceneSlice(
     boolean queryHead,
     ShardContext shardContext,
     List<PartialLeafReaderContext> leaves,
-    Weight weight,
-    List<Object> tags
+    ScoreMode scoreMode,
+    LuceneSliceQueue.QueryAndTags queryAndTags
 ) {

     int numLeaves() {
         return leaves.size();
     }

     PartialLeafReaderContext getLeaf(int index) {
         return leaves.get(index);
     }
+
+    Query query() {
+        return queryAndTags.query();
+    }
+
+    List<Object> tags() {
+        return queryAndTags.tags();
+    }
+
+    Weight createWeight() {
+        var searcher = shardContext.searcher();
+        try {
+            return searcher.createWeight(queryAndTags.query(), scoreMode, 1);
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }
 }
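
LuceneSlice now carries the raw ingredients (the Query inside QueryAndTags, plus the ScoreMode) instead of a prebuilt Weight, and createWeight() materializes a Weight on whichever thread calls it; the UncheckedIOException wrapper keeps the record's accessor-style API free of checked exceptions. A hedged construction sketch, mirroring the updated tests further down (shardContext and leaves are assumed to be in scope):

    var qt = new LuceneSliceQueue.QueryAndTags(new MatchAllDocsQuery(), List.of("q1"));
    var slice = new LuceneSlice(0, true, shardContext, leaves, ScoreMode.COMPLETE_NO_SCORES, qt);
    Weight w = slice.createWeight(); // deferred until a Driver actually needs it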
@@ -12,7 +12,6 @@
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
-import org.apache.lucene.search.Weight;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -209,12 +208,11 @@ public static LuceneSliceQueue create(
             PartitioningStrategy partitioning = PartitioningStrategy.pick(dataPartitioning, autoStrategy, ctx, query);
             partitioningStrategies.put(ctx.shardIdentifier(), partitioning);
             List<List<PartialLeafReaderContext>> groups = partitioning.groups(ctx.searcher(), taskConcurrency);
-            Weight weight = weight(ctx, query, scoreMode);
             boolean queryHead = true;
             for (List<PartialLeafReaderContext> group : groups) {
                 if (group.isEmpty() == false) {
                     final int slicePosition = nextSliceId++;
-                    slices.add(new LuceneSlice(slicePosition, queryHead, ctx, group, weight, queryAndExtra.tags));
+                    slices.add(new LuceneSlice(slicePosition, queryHead, ctx, group, scoreMode, queryAndExtra));
                     queryHead = false;
                 }
             }
@@ -316,16 +314,6 @@ private static PartitioningStrategy forAuto(Function<Query, PartitioningStrategy
         }
     }

-    static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
-        var searcher = ctx.searcher();
-        try {
-            Query actualQuery = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
-            return searcher.createWeight(actualQuery, scoreMode, 1);
-        } catch (IOException e) {
-            throw new UncheckedIOException(e);
-        }
-    }
-
     static final class AdaptivePartitioner {
         final int desiredDocsPerSlice;
         final int maxDocsPerSlice;
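
One detail worth flagging: the deleted static weight(...) helper wrapped non-scoring queries in a ConstantScoreQuery before creating the Weight, the usual idiom for skipping score computation. Whether this PR re-applies that wrapping elsewhere is not visible in these hunks; the snippet below only restates the removed idiom for context:

    // Idiom used by the removed helper.
    Query actualQuery = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
    Weight weight = searcher.createWeight(actualQuery, scoreMode, 1);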
@@ -154,7 +154,7 @@ protected boolean lessThan(LeafIterator a, LeafIterator b) {
                 return a.timeSeriesHash.compareTo(b.timeSeriesHash) < 0;
             }
         };
-        Weight weight = luceneSlice.weight();
+        Weight weight = luceneSlice.createWeight();
         processedQueries.add(weight.getQuery());
         int maxSegmentOrd = 0;
         for (var leafReaderContext : luceneSlice.leaves()) {
@@ -24,6 +24,8 @@
 import org.apache.lucene.index.TermVectors;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.util.Bits;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.test.ESTestCase;
@@ -50,27 +52,28 @@ public void testBasics() {
         LeafReaderContext leaf2 = new MockLeafReader(1000).getContext();
         LeafReaderContext leaf3 = new MockLeafReader(1000).getContext();
         LeafReaderContext leaf4 = new MockLeafReader(1000).getContext();
-        List<Object> query1 = List.of("1");
-        List<Object> query2 = List.of("q2");
+        LuceneSliceQueue.QueryAndTags t1 = new LuceneSliceQueue.QueryAndTags(new MatchAllDocsQuery(), List.of("q1"));
+        LuceneSliceQueue.QueryAndTags t2 = new LuceneSliceQueue.QueryAndTags(new MatchAllDocsQuery(), List.of("q2"));
+        var scoreMode = ScoreMode.COMPLETE_NO_SCORES;
         List<LuceneSlice> sliceList = List.of(
             // query1: new segment
-            new LuceneSlice(0, true, null, List.of(new PartialLeafReaderContext(leaf1, 0, 10)), null, query1),
-            new LuceneSlice(1, false, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, query1),
-            new LuceneSlice(2, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, query1),
+            new LuceneSlice(0, true, null, List.of(new PartialLeafReaderContext(leaf1, 0, 10)), scoreMode, t1),
+            new LuceneSlice(1, false, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), scoreMode, t1),
+            new LuceneSlice(2, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), scoreMode, t1),
             // query1: new segment
-            new LuceneSlice(3, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, query1),
-            new LuceneSlice(4, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, query1),
-            new LuceneSlice(5, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, query1),
+            new LuceneSlice(3, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), scoreMode, t1),
+            new LuceneSlice(4, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), scoreMode, t1),
+            new LuceneSlice(5, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), scoreMode, t1),
             // query1: new segment
-            new LuceneSlice(6, false, null, List.of(new PartialLeafReaderContext(leaf4, 0, 10)), null, query1),
-            new LuceneSlice(7, false, null, List.of(new PartialLeafReaderContext(leaf4, 10, 20)), null, query1),
+            new LuceneSlice(6, false, null, List.of(new PartialLeafReaderContext(leaf4, 0, 10)), scoreMode, t1),
+            new LuceneSlice(7, false, null, List.of(new PartialLeafReaderContext(leaf4, 10, 20)), scoreMode, t1),
             // query2: new segment
-            new LuceneSlice(8, true, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, query2),
-            new LuceneSlice(9, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, query2),
+            new LuceneSlice(8, true, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), scoreMode, t2),
+            new LuceneSlice(9, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), scoreMode, t2),
             // query2: new segment
-            new LuceneSlice(10, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, query2),
-            new LuceneSlice(11, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, query2),
-            new LuceneSlice(12, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, query2)
+            new LuceneSlice(10, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), scoreMode, t2),
+            new LuceneSlice(11, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), scoreMode, t2),
+            new LuceneSlice(12, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), scoreMode, t2)
         );
         // single driver
         {
@@ -139,7 +142,7 @@ public void testRandom() throws Exception {
                 false,
                 mock(ShardContext.class),
                 List.of(new PartialLeafReaderContext(leafContext, minDoc, maxDoc)),
-                null,
+                ScoreMode.COMPLETE_NO_SCORES,
                 null
             );
             sliceList.add(slice);