
Commit ee117e1

Allow smaller slices in doc partitioning
1 parent 522a46f commit ee117e1

6 files changed, +502 −60 lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/DataPartitioning.java

Lines changed: 10 additions & 4 deletions
@@ -37,9 +37,15 @@ public enum DataPartitioning {
      */
     SEGMENT,
     /**
-     * Partition each shard into {@code task_concurrency} partitions, splitting
-     * larger segments into slices. This allows bringing the most CPUs to bear on
-     * the problem but adds extra overhead, especially in query preparation.
+     * Partition each shard into dynamic-sized slices, providing more flexible partition sizes than {@link #SEGMENT}.
+     * This strategy optimizes CPU utilization and minimizes overhead by:
+     * 1. Starting with the desired partition size based on {@code task_concurrency}, then capping the slice size around 1M to avoid
+     * CPU underutilization when matching documents are concentrated in one segment region.
+     * 2. For small and medium segments (smaller than 5 times the desired slice size), using a variant of {@link #SEGMENT}
+     * that also splits segments larger than the desired size as needed.
+     * 3. For very large segments, avoiding combining multiple segments into a single slice. This enables a single driver
+     * to process the same segment until other drivers attempt to steal work after completing their own segments.
      */
-    DOC,
+    DOC
 }
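To make the rules above concrete, here is an editorial walk-through (not part of the commit) of how a 20M-doc shard with task_concurrency = 8 might be sliced, assuming the ~1M cap is MAX_DOCS_PER_SLICE * 4 with MAX_DOCS_PER_SLICE = 250_000, the value copied from IndexSearcher:

    // Editorial sketch; the names below mirror the javadoc, not real fields.
    int totalDocCount = 20_000_000;   // docs in one shard
    int taskConcurrency = 8;          // requested number of slices
    int desired = Math.ceilDiv(totalDocCount, taskConcurrency); // 2_500_000
    int capped = Math.min(desired, 1_000_000);                  // rule 1: cap near 1M
    // rule 2: a 3M-doc segment is < 5 * capped (5M), so it takes the SEGMENT-style
    //         path and is split wherever it exceeds the desired size
    // rule 3: a 12M-doc segment is >= 5 * capped, so it is cut into single-segment
    //         slices and never combined with other segments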

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ LuceneScorer getCurrentOrLoadNextScorer() {
         while (currentScorer == null || currentScorer.isDone()) {
             if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
                 sliceIndex = 0;
-                currentSlice = sliceQueue.nextSlice();
+                currentSlice = sliceQueue.nextSlice(currentSlice);
                 if (currentSlice == null) {
                     doneCollecting = true;
                     return null;

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSlice.java

Lines changed: 7 additions & 1 deletion
@@ -14,7 +14,13 @@
 /**
  * Holds a list of multiple partial Lucene segments
  */
-public record LuceneSlice(ShardContext shardContext, List<PartialLeafReaderContext> leaves, Weight weight, List<Object> tags) {
+public record LuceneSlice(
+    int slicePosition,
+    ShardContext shardContext,
+    List<PartialLeafReaderContext> leaves,
+    Weight weight,
+    List<Object> tags
+) {
     int numLeaves() {
         return leaves.size();
     }

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSliceQueue.java

Lines changed: 135 additions & 53 deletions
@@ -16,18 +16,21 @@
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.core.Nullable;

 import java.io.IOException;
 import java.io.UncheckedIOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
-import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.Function;

 /**
@@ -77,18 +80,62 @@ public record QueryAndTags(Query query, List<Object> tags) {}
     public static final int MAX_SEGMENTS_PER_SLICE = 5; // copied from IndexSearcher

     private final int totalSlices;
-    private final Queue<LuceneSlice> slices;
+    private final AtomicReferenceArray<LuceneSlice> slices;
+    private final Queue<Integer> startedPositions;
+    private final Queue<Integer> followedPositions;
     private final Map<String, PartitioningStrategy> partitioningStrategies;

-    private LuceneSliceQueue(List<LuceneSlice> slices, Map<String, PartitioningStrategy> partitioningStrategies) {
-        this.totalSlices = slices.size();
-        this.slices = new ConcurrentLinkedQueue<>(slices);
+    LuceneSliceQueue(List<LuceneSlice> sliceList, Map<String, PartitioningStrategy> partitioningStrategies) {
+        this.totalSlices = sliceList.size();
+        this.slices = new AtomicReferenceArray<>(sliceList.size());
+        for (int i = 0; i < sliceList.size(); i++) {
+            slices.set(i, sliceList.get(i));
+        }
         this.partitioningStrategies = partitioningStrategies;
+        this.startedPositions = ConcurrentCollections.newQueue();
+        this.followedPositions = ConcurrentCollections.newQueue();
+        for (LuceneSlice slice : sliceList) {
+            if (slice.getLeaf(0).minDoc() == 0) {
+                startedPositions.add(slice.slicePosition());
+            } else {
+                followedPositions.add(slice.slicePosition());
+            }
+        }
     }

+    /**
+     * Retrieves the next available {@link LuceneSlice} for processing.
+     * If a previous slice is provided, this method first attempts to return the next sequential slice to maintain segment affinity
+     * and minimize the cost of switching between segments.
+     * <p>
+     * If no sequential slice is available, it returns the next slice from the {@code startedPositions} queue, which starts a new
+     * group of segments. If all started positions are exhausted, it retrieves a slice from the {@code followedPositions} queue,
+     * enabling work stealing.
+     *
+     * @param prev the previously returned {@link LuceneSlice}, or {@code null} if starting
+     * @return the next available {@link LuceneSlice}, or {@code null} if exhausted
+     */
     @Nullable
-    public LuceneSlice nextSlice() {
-        return slices.poll();
+    public LuceneSlice nextSlice(LuceneSlice prev) {
+        if (prev != null) {
+            final int nextId = prev.slicePosition() + 1;
+            if (nextId < totalSlices) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        for (var preferredIndices : List.of(startedPositions, followedPositions)) {
+            Integer nextId;
+            while ((nextId = preferredIndices.poll()) != null) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        return null;
     }

     public int totalSlices() {
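As an editorial aside (not part of the commit), the calling pattern nextSlice(prev) is designed for looks roughly like this, where process is a hypothetical per-slice worker:

    // Minimal consumer sketch against the API in this diff.
    static void drain(LuceneSliceQueue queue) {
        LuceneSlice slice = queue.nextSlice(null);   // first call: no previous slice
        while (slice != null) {
            process(slice);                          // hypothetical work
            // Passing the finished slice lets the queue prefer slicePosition() + 1
            // (the same segment) before polling startedPositions and, last,
            // stealing from followedPositions.
            slice = queue.nextSlice(slice);
        }
    }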
@@ -103,7 +150,14 @@ public Map<String, PartitioningStrategy> partitioningStrategies() {
     }

     public Collection<String> remainingShardsIdentifiers() {
-        return slices.stream().map(slice -> slice.shardContext().shardIdentifier()).toList();
+        List<String> remaining = new ArrayList<>(slices.length());
+        for (int i = 0; i < slices.length(); i++) {
+            LuceneSlice slice = slices.get(i);
+            if (slice != null) {
+                remaining.add(slice.shardContext().shardIdentifier());
+            }
+        }
+        return remaining;
     }

     public static LuceneSliceQueue create(
@@ -117,6 +171,7 @@ public static LuceneSliceQueue create(
         List<LuceneSlice> slices = new ArrayList<>();
         Map<String, PartitioningStrategy> partitioningStrategies = new HashMap<>(contexts.size());

+        int nextSliceId = 0;
         for (ShardContext ctx : contexts) {
             for (QueryAndTags queryAndExtra : queryFunction.apply(ctx)) {
                 var scoreMode = scoreModeFunction.apply(ctx);
@@ -140,7 +195,7 @@ public static LuceneSliceQueue create(
                 Weight weight = weight(ctx, query, scoreMode);
                 for (List<PartialLeafReaderContext> group : groups) {
                     if (group.isEmpty() == false) {
-                        slices.add(new LuceneSlice(ctx, group, weight, queryAndExtra.tags));
+                        slices.add(new LuceneSlice(nextSliceId++, ctx, group, weight, queryAndExtra.tags));
                     }
                 }
             }
@@ -184,50 +239,10 @@ List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices
         @Override
         List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
             final int totalDocCount = searcher.getIndexReader().maxDoc();
-            final int normalMaxDocsPerSlice = totalDocCount / requestedNumSlices;
-            final int extraDocsInFirstSlice = totalDocCount % requestedNumSlices;
-            final List<List<PartialLeafReaderContext>> slices = new ArrayList<>();
-            int docsAllocatedInCurrentSlice = 0;
-            List<PartialLeafReaderContext> currentSlice = null;
-            int maxDocsPerSlice = normalMaxDocsPerSlice + extraDocsInFirstSlice;
-            for (LeafReaderContext ctx : searcher.getLeafContexts()) {
-                final int numDocsInLeaf = ctx.reader().maxDoc();
-                int minDoc = 0;
-                while (minDoc < numDocsInLeaf) {
-                    int numDocsToUse = Math.min(maxDocsPerSlice - docsAllocatedInCurrentSlice, numDocsInLeaf - minDoc);
-                    if (numDocsToUse <= 0) {
-                        break;
-                    }
-                    if (currentSlice == null) {
-                        currentSlice = new ArrayList<>();
-                    }
-                    currentSlice.add(new PartialLeafReaderContext(ctx, minDoc, minDoc + numDocsToUse));
-                    minDoc += numDocsToUse;
-                    docsAllocatedInCurrentSlice += numDocsToUse;
-                    if (docsAllocatedInCurrentSlice == maxDocsPerSlice) {
-                        slices.add(currentSlice);
-                        // once the first slice with the extra docs is added, no need for extra docs
-                        maxDocsPerSlice = normalMaxDocsPerSlice;
-                        currentSlice = null;
-                        docsAllocatedInCurrentSlice = 0;
-                    }
-                }
-            }
-            if (currentSlice != null) {
-                slices.add(currentSlice);
-            }
-            if (requestedNumSlices < totalDocCount && slices.size() != requestedNumSlices) {
-                throw new IllegalStateException("wrong number of slices, expected " + requestedNumSlices + " but got " + slices.size());
-            }
-            if (slices.stream()
-                .flatMapToInt(
-                    l -> l.stream()
-                        .mapToInt(partialLeafReaderContext -> partialLeafReaderContext.maxDoc() - partialLeafReaderContext.minDoc())
-                )
-                .sum() != totalDocCount) {
-                throw new IllegalStateException("wrong doc count");
-            }
-            return slices;
+            // Cap the desired slice at 1M to prevent CPU underutilization when matching documents are concentrated in one segment
+            // region.
+            int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, requestedNumSlices), 1, MAX_DOCS_PER_SLICE * 4);
+            return new AdaptivePartitioner(Math.max(1, desiredSliceSize), MAX_SEGMENTS_PER_SLICE).partition(searcher.getLeafContexts());
         }
     };
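Worked example (editorial, not from the commit): assuming MAX_DOCS_PER_SLICE = 250_000, the clamp caps slices at 1_000_000 docs, so a large shard with a small task concurrency still produces many stealable slices:

    int totalDocCount = 50_000_000;
    int requestedNumSlices = 4;
    // ceilDiv(50_000_000, 4) = 12_500_000, clamped to 1_000_000:
    int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, requestedNumSlices), 1, 1_000_000);
    // ~50 one-million-doc slices instead of 4 huge ones.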

@@ -278,6 +293,10 @@ private static PartitioningStrategy forAuto(Function<Query, PartitioningStrategy
             if (ctx.searcher().getIndexReader().maxDoc() < SMALL_INDEX_BOUNDARY) {
                 return PartitioningStrategy.SHARD;
             }
+            if (ctx.searcher().getLeafContexts().size() <= 1) {
+                // if there's only one segment, we can use segment partitioning
+                return PartitioningStrategy.SEGMENT;
+            }
             return autoStrategy.apply(query);
         }
     }
@@ -291,4 +310,67 @@ static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
             throw new UncheckedIOException(e);
         }
     }
+
+    static final class AdaptivePartitioner {
+        final int desiredDocsPerSlice;
+        final int maxDocsPerSlice;
+        final int maxSegmentsPerSlice;
+
+        AdaptivePartitioner(int desiredDocsPerSlice, int maxSegmentsPerSlice) {
+            this.desiredDocsPerSlice = desiredDocsPerSlice;
+            this.maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;
+            this.maxSegmentsPerSlice = maxSegmentsPerSlice;
+        }
+
+        List<List<PartialLeafReaderContext>> partition(List<LeafReaderContext> leaves) {
+            List<LeafReaderContext> smallSegments = new ArrayList<>();
+            List<LeafReaderContext> largeSegments = new ArrayList<>();
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            for (LeafReaderContext leaf : leaves) {
+                if (leaf.reader().maxDoc() >= 5 * desiredDocsPerSlice) {
+                    largeSegments.add(leaf);
+                } else {
+                    smallSegments.add(leaf);
+                }
+            }
+            largeSegments.sort(Collections.reverseOrder(Comparator.comparingInt(l -> l.reader().maxDoc())));
+            for (LeafReaderContext segment : largeSegments) {
+                results.addAll(partitionOneLargeSegment(segment));
+            }
+            results.addAll(partitionSmallSegments(smallSegments));
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionOneLargeSegment(LeafReaderContext leaf) {
+            int numDocsInLeaf = leaf.reader().maxDoc();
+            int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice);
+            while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
+                numSlices++;
+            }
+            int docPerSlice = numDocsInLeaf / numSlices;
+            int leftoverDocs = numDocsInLeaf % numSlices;
+            int minDoc = 0;
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            while (minDoc < numDocsInLeaf) {
+                int docsToUse = docPerSlice;
+                if (leftoverDocs > 0) {
+                    --leftoverDocs;
+                    docsToUse++;
+                }
+                int maxDoc = Math.min(minDoc + docsToUse, numDocsInLeaf);
+                results.add(List.of(new PartialLeafReaderContext(leaf, minDoc, maxDoc)));
+                minDoc = maxDoc;
+            }
+            assert leftoverDocs == 0 : leftoverDocs;
+            assert results.stream().allMatch(s -> s.size() == 1) : "must have one partial leaf per slice";
+            assert results.stream().flatMapToInt(ss -> ss.stream().mapToInt(s -> s.maxDoc() - s.minDoc())).sum() == numDocsInLeaf;
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionSmallSegments(List<LeafReaderContext> leaves) {
+            var slices = IndexSearcher.slices(leaves, maxDocsPerSlice, maxSegmentsPerSlice, true);
+            return Arrays.stream(slices).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
+        }
+    }
+
 }
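And an editorial trace (not from the commit) of partitionOneLargeSegment on a 1.3M-doc segment with desiredDocsPerSlice = 250_000:

    int numDocsInLeaf = 1_300_000;
    int desiredDocsPerSlice = 250_000;
    int maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;                // 312_500
    int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice); // 5
    // ceilDiv(1_300_000, 5) = 260_000 <= 312_500, so the while loop never runs
    int docPerSlice = numDocsInLeaf / numSlices;                      // 260_000
    int leftoverDocs = numDocsInLeaf % numSlices;                     // 0
    // -> five single-leaf slices: [0, 260_000), [260_000, 520_000), ..., [1_040_000, 1_300_000)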

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ public Page getCheckedOutput() throws IOException {
         long startInNanos = System.nanoTime();
         try {
             if (iterator == null) {
-                var slice = sliceQueue.nextSlice();
+                var slice = sliceQueue.nextSlice(null);
                 if (slice == null) {
                     doneCollecting = true;
                     return null;
