
Commit f52e21f

Improve cpu utilization with dynamic slice size in doc partitioning

1 parent 2864dd8 commit f52e21f

6 files changed: +504 -60 lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/DataPartitioning.java

Lines changed: 17 additions & 4 deletions

@@ -9,6 +9,8 @@
 
 import org.elasticsearch.compute.operator.Driver;
 
+import java.util.List;
+
 /**
  * How we partition the data across {@link Driver}s. Each request forks into
  * {@code min(1.5 * cpus, partition_count)} threads on the data node. More partitions
@@ -37,9 +39,20 @@ public enum DataPartitioning {
      */
     SEGMENT,
     /**
-     * Partition each shard into {@code task_concurrency} partitions, splitting
-     * larger segments into slices. This allows bringing the most CPUs to bear on
-     * the problem but adds extra overhead, especially in query preparation.
+     * Partitions into dynamic-sized slices to improve CPU utilization while keeping overhead low.
+     * This approach is more flexible than {@link #SEGMENT} and works as follows:
+     *
+     * <p>1. The slice size starts from a desired size based on {@code task_concurrency} but is capped
+     * at around {@link LuceneSliceQueue#MAX_DOCS_PER_SLICE}. This prevents poor CPU usage when
+     * matching documents are clustered together.
+     *
+     * <p>2. For small and medium segments (less than five times the desired slice size), it uses a
+     * slightly different {@link #SEGMENT} strategy, which also splits segments that are larger
+     * than the desired size. See {@link org.apache.lucene.search.IndexSearcher#slices(List, int, int, boolean)}.
+     *
+     * <p>3. For very large segments, multiple segments are not combined into a single slice. This allows
+     * one driver to process an entire large segment until other drivers steal the work after finishing
+     * their own tasks. See {@link LuceneSliceQueue#nextSlice(LuceneSlice)}.
      */
-    DOC,
+    DOC
 }
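
The slice sizing described in point 1 is concrete enough to sketch. Below is a minimal, stand-alone rendering of the math, assuming Java 21+ for Math.ceilDiv and Math.clamp; SliceSizeSketch and the MAX_DOCS_PER_SLICE value of 250_000 are illustrative stand-ins chosen here, not values taken from this commit (the real constant lives in LuceneSliceQueue).

// A minimal sketch of the DOC slice sizing, assuming Java 21+.
// MAX_DOCS_PER_SLICE here is illustrative; see LuceneSliceQueue for the real constant.
class SliceSizeSketch {
    static final int MAX_DOCS_PER_SLICE = 250_000;

    static int desiredSliceSize(int totalDocCount, int taskConcurrency) {
        // Start from an even split across task_concurrency slices, then cap it so
        // that clustered matches still spread across enough slices to keep CPUs busy.
        return Math.clamp(Math.ceilDiv(totalDocCount, taskConcurrency), 1, MAX_DOCS_PER_SLICE);
    }

    public static void main(String[] args) {
        // 10M docs across 8 tasks: an even split would be 1.25M docs per slice;
        // the cap shrinks that to 250k, yielding 40 slices instead of 8.
        System.out.println(desiredSliceSize(10_000_000, 8)); // 250000
    }
}

Having more slices than drivers is the point: a driver that finishes a sparse region can pick up another slice instead of idling while a sibling grinds through a dense one.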

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java

Lines changed: 1 addition & 1 deletion

@@ -165,7 +165,7 @@ LuceneScorer getCurrentOrLoadNextScorer() {
         while (currentScorer == null || currentScorer.isDone()) {
             if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
                 sliceIndex = 0;
-                currentSlice = sliceQueue.nextSlice();
+                currentSlice = sliceQueue.nextSlice(currentSlice);
                 if (currentSlice == null) {
                     doneCollecting = true;
                     return null;

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSlice.java

Lines changed: 7 additions & 1 deletion

@@ -14,7 +14,13 @@
 /**
  * Holds a list of multiple partial Lucene segments
  */
-public record LuceneSlice(ShardContext shardContext, List<PartialLeafReaderContext> leaves, Weight weight, List<Object> tags) {
+public record LuceneSlice(
+    int slicePosition,
+    ShardContext shardContext,
+    List<PartialLeafReaderContext> leaves,
+    Weight weight,
+    List<Object> tags
+) {
     int numLeaves() {
         return leaves.size();
     }

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSliceQueue.java

Lines changed: 130 additions & 53 deletions

@@ -16,18 +16,21 @@
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.core.Nullable;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
-import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.Function;
 
 /**
@@ -77,18 +80,62 @@ public record QueryAndTags(Query query, List<Object> tags) {}
     public static final int MAX_SEGMENTS_PER_SLICE = 5; // copied from IndexSearcher
 
     private final int totalSlices;
-    private final Queue<LuceneSlice> slices;
+    private final AtomicReferenceArray<LuceneSlice> slices;
+    private final Queue<Integer> startedPositions;
+    private final Queue<Integer> followedPositions;
     private final Map<String, PartitioningStrategy> partitioningStrategies;
 
-    private LuceneSliceQueue(List<LuceneSlice> slices, Map<String, PartitioningStrategy> partitioningStrategies) {
-        this.totalSlices = slices.size();
-        this.slices = new ConcurrentLinkedQueue<>(slices);
+    LuceneSliceQueue(List<LuceneSlice> sliceList, Map<String, PartitioningStrategy> partitioningStrategies) {
+        this.totalSlices = sliceList.size();
+        this.slices = new AtomicReferenceArray<>(sliceList.size());
+        for (int i = 0; i < sliceList.size(); i++) {
+            slices.set(i, sliceList.get(i));
+        }
         this.partitioningStrategies = partitioningStrategies;
+        this.startedPositions = ConcurrentCollections.newQueue();
+        this.followedPositions = ConcurrentCollections.newQueue();
+        for (LuceneSlice slice : sliceList) {
+            if (slice.getLeaf(0).minDoc() == 0) {
+                startedPositions.add(slice.slicePosition());
+            } else {
+                followedPositions.add(slice.slicePosition());
+            }
+        }
     }
 
+    /**
+     * Retrieves the next available {@link LuceneSlice} for processing.
+     * If a previous slice is provided, this method first attempts to return the next sequential slice to maintain segment affinity
+     * and minimize the cost of switching between segments.
+     * <p>
+     * If no sequential slice is available, it returns the next slice from the {@code startedPositions} queue, which starts a new
+     * group of segments. If all started positions are exhausted, it steals a slice from the {@code followedPositions} queue,
+     * enabling work stealing.
+     *
+     * @param prev the previously returned {@link LuceneSlice}, or {@code null} if starting
+     * @return the next available {@link LuceneSlice}, or {@code null} if exhausted
+     */
     @Nullable
-    public LuceneSlice nextSlice() {
-        return slices.poll();
+    public LuceneSlice nextSlice(LuceneSlice prev) {
+        if (prev != null) {
+            final int nextId = prev.slicePosition() + 1;
+            if (nextId < totalSlices) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        for (var ids : List.of(startedPositions, followedPositions)) {
+            Integer nextId;
+            while ((nextId = ids.poll()) != null) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        return null;
     }
 
     public int totalSlices() {
@@ -103,7 +150,14 @@ public Map<String, PartitioningStrategy> partitioningStrategies() {
     }
 
     public Collection<String> remainingShardsIdentifiers() {
-        return slices.stream().map(slice -> slice.shardContext().shardIdentifier()).toList();
+        List<String> remaining = new ArrayList<>(slices.length());
+        for (int i = 0; i < slices.length(); i++) {
+            LuceneSlice slice = slices.get(i);
+            if (slice != null) {
+                remaining.add(slice.shardContext().shardIdentifier());
+            }
+        }
+        return remaining;
     }
 
     public static LuceneSliceQueue create(
@@ -117,6 +171,7 @@ public static LuceneSliceQueue create(
         List<LuceneSlice> slices = new ArrayList<>();
         Map<String, PartitioningStrategy> partitioningStrategies = new HashMap<>(contexts.size());
 
+        int nextSliceId = 0;
        for (ShardContext ctx : contexts) {
             for (QueryAndTags queryAndExtra : queryFunction.apply(ctx)) {
                 var scoreMode = scoreModeFunction.apply(ctx);
@@ -140,7 +195,7 @@ public static LuceneSliceQueue create(
                 Weight weight = weight(ctx, query, scoreMode);
                 for (List<PartialLeafReaderContext> group : groups) {
                     if (group.isEmpty() == false) {
-                        slices.add(new LuceneSlice(ctx, group, weight, queryAndExtra.tags));
+                        slices.add(new LuceneSlice(nextSliceId++, ctx, group, weight, queryAndExtra.tags));
                     }
                 }
             }
@@ -184,50 +239,9 @@ List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices)
         @Override
         List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
             final int totalDocCount = searcher.getIndexReader().maxDoc();
-            final int normalMaxDocsPerSlice = totalDocCount / requestedNumSlices;
-            final int extraDocsInFirstSlice = totalDocCount % requestedNumSlices;
-            final List<List<PartialLeafReaderContext>> slices = new ArrayList<>();
-            int docsAllocatedInCurrentSlice = 0;
-            List<PartialLeafReaderContext> currentSlice = null;
-            int maxDocsPerSlice = normalMaxDocsPerSlice + extraDocsInFirstSlice;
-            for (LeafReaderContext ctx : searcher.getLeafContexts()) {
-                final int numDocsInLeaf = ctx.reader().maxDoc();
-                int minDoc = 0;
-                while (minDoc < numDocsInLeaf) {
-                    int numDocsToUse = Math.min(maxDocsPerSlice - docsAllocatedInCurrentSlice, numDocsInLeaf - minDoc);
-                    if (numDocsToUse <= 0) {
-                        break;
-                    }
-                    if (currentSlice == null) {
-                        currentSlice = new ArrayList<>();
-                    }
-                    currentSlice.add(new PartialLeafReaderContext(ctx, minDoc, minDoc + numDocsToUse));
-                    minDoc += numDocsToUse;
-                    docsAllocatedInCurrentSlice += numDocsToUse;
-                    if (docsAllocatedInCurrentSlice == maxDocsPerSlice) {
-                        slices.add(currentSlice);
-                        // once the first slice with the extra docs is added, no need for extra docs
-                        maxDocsPerSlice = normalMaxDocsPerSlice;
-                        currentSlice = null;
-                        docsAllocatedInCurrentSlice = 0;
-                    }
-                }
-            }
-            if (currentSlice != null) {
-                slices.add(currentSlice);
-            }
-            if (requestedNumSlices < totalDocCount && slices.size() != requestedNumSlices) {
-                throw new IllegalStateException("wrong number of slices, expected " + requestedNumSlices + " but got " + slices.size());
-            }
-            if (slices.stream()
-                .flatMapToInt(
-                    l -> l.stream()
-                        .mapToInt(partialLeafReaderContext -> partialLeafReaderContext.maxDoc() - partialLeafReaderContext.minDoc())
-                )
-                .sum() != totalDocCount) {
-                throw new IllegalStateException("wrong doc count");
-            }
-            return slices;
+            // Cap the desired slice to prevent CPU underutilization when matching documents are concentrated in one segment region.
+            int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, requestedNumSlices), 1, MAX_DOCS_PER_SLICE);
+            return new AdaptivePartitioner(Math.max(1, desiredSliceSize), MAX_SEGMENTS_PER_SLICE).partition(searcher.getLeafContexts());
         }
     };
 
@@ -291,4 +305,67 @@ static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
             throw new UncheckedIOException(e);
         }
     }
+
+    static final class AdaptivePartitioner {
+        final int desiredDocsPerSlice;
+        final int maxDocsPerSlice;
+        final int maxSegmentsPerSlice;
+
+        AdaptivePartitioner(int desiredDocsPerSlice, int maxSegmentsPerSlice) {
+            this.desiredDocsPerSlice = desiredDocsPerSlice;
+            this.maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;
+            this.maxSegmentsPerSlice = maxSegmentsPerSlice;
+        }
+
+        List<List<PartialLeafReaderContext>> partition(List<LeafReaderContext> leaves) {
+            List<LeafReaderContext> smallSegments = new ArrayList<>();
+            List<LeafReaderContext> largeSegments = new ArrayList<>();
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            for (LeafReaderContext leaf : leaves) {
+                if (leaf.reader().maxDoc() >= 5 * desiredDocsPerSlice) {
+                    largeSegments.add(leaf);
+                } else {
+                    smallSegments.add(leaf);
+                }
+            }
+            largeSegments.sort(Collections.reverseOrder(Comparator.comparingInt(l -> l.reader().maxDoc())));
+            for (LeafReaderContext segment : largeSegments) {
+                results.addAll(partitionOneLargeSegment(segment));
+            }
+            results.addAll(partitionSmallSegments(smallSegments));
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionOneLargeSegment(LeafReaderContext leaf) {
+            int numDocsInLeaf = leaf.reader().maxDoc();
+            int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice);
+            while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
+                numSlices++;
+            }
+            int docPerSlice = numDocsInLeaf / numSlices;
+            int leftoverDocs = numDocsInLeaf % numSlices;
+            int minDoc = 0;
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            while (minDoc < numDocsInLeaf) {
+                int docsToUse = docPerSlice;
+                if (leftoverDocs > 0) {
+                    --leftoverDocs;
+                    docsToUse++;
+                }
+                int maxDoc = Math.min(minDoc + docsToUse, numDocsInLeaf);
+                results.add(List.of(new PartialLeafReaderContext(leaf, minDoc, maxDoc)));
+                minDoc = maxDoc;
+            }
+            assert leftoverDocs == 0 : leftoverDocs;
+            assert results.stream().allMatch(s -> s.size() == 1) : "must have one partial leaf per slice";
+            assert results.stream().flatMapToInt(ss -> ss.stream().mapToInt(s -> s.maxDoc() - s.minDoc())).sum() == numDocsInLeaf;
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionSmallSegments(List<LeafReaderContext> leaves) {
+            var slices = IndexSearcher.slices(leaves, maxDocsPerSlice, maxSegmentsPerSlice, true);
+            return Arrays.stream(slices).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
+        }
+    }
+
 }
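
The claim/steal discipline behind nextSlice(LuceneSlice) is easiest to see in isolation. The following is a stand-alone model, a sketch only: SliceQueueModel and its Slice record are invented stand-ins for LuceneSliceQueue and LuceneSlice, and ConcurrentLinkedQueue stands in for ConcurrentCollections.newQueue().

import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicReferenceArray;

// A stand-alone model of the claim/steal discipline in LuceneSliceQueue.
// "Slice" is a stand-in for LuceneSlice: a position plus whether its first
// leaf starts at doc 0 (i.e. it begins a new segment group).
final class SliceQueueModel {
    record Slice(int position, boolean startsSegmentGroup) {}

    private final AtomicReferenceArray<Slice> slices;
    private final Queue<Integer> started = new ConcurrentLinkedQueue<>();
    private final Queue<Integer> followed = new ConcurrentLinkedQueue<>();

    SliceQueueModel(List<Slice> all) {
        slices = new AtomicReferenceArray<>(all.size());
        for (Slice s : all) { // positions are assumed to be 0..n-1
            slices.set(s.position(), s);
            (s.startsSegmentGroup() ? started : followed).add(s.position());
        }
    }

    Slice nextSlice(Slice prev) {
        // 1. Affinity: prefer the slice right after the one just finished,
        //    so a driver keeps walking the same segment.
        if (prev != null && prev.position() + 1 < slices.length()) {
            Slice next = slices.getAndSet(prev.position() + 1, null);
            if (next != null) {
                return next;
            }
        }
        // 2. Start a fresh segment group; 3. failing that, steal a follow-up
        //    slice another driver would otherwise have reached by affinity.
        for (Queue<Integer> ids : List.of(started, followed)) {
            Integer id;
            while ((id = ids.poll()) != null) {
                Slice next = slices.getAndSet(id, null);
                if (next != null) {
                    return next;
                }
            }
        }
        return null; // exhausted
    }
}

The getAndSet(position, null) is what keeps the three paths consistent: a position can be reachable both sequentially and through a queue, but only one caller can swap out the non-null slice. Callers that do not track a previous slice, like TimeSeriesSourceOperator below, pass null and always take the queue path.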

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java

Lines changed: 1 addition & 1 deletion

@@ -97,7 +97,7 @@ public Page getCheckedOutput() throws IOException {
         long startInNanos = System.nanoTime();
         try {
             if (iterator == null) {
-                var slice = sliceQueue.nextSlice();
+                var slice = sliceQueue.nextSlice(null);
                 if (slice == null) {
                     doneCollecting = true;
                     return null;
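
To close, the AdaptivePartitioner arithmetic from LuceneSliceQueue.java above is easy to check by hand. This is a hedged re-derivation with the Lucene types stripped away; the document counts are example inputs chosen here, not values from the commit.

// A re-derivation of partitionOneLargeSegment's arithmetic, stand-alone.
// The inputs are examples; only the formulas mirror the patch.
class LargeSegmentMath {
    public static void main(String[] args) {
        int desiredDocsPerSlice = 100_000;
        int maxDocsPerSlice = desiredDocsPerSlice * 5 / 4; // 125_000, as in the patch
        int numDocsInLeaf = 590_000; // >= 5 * desired, so it takes the "large" path

        int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice); // 5
        while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
            numSlices++; // safety net; never fires when numDocsInLeaf >= 5 * desired
        }
        int docPerSlice = numDocsInLeaf / numSlices;   // 118_000
        int leftoverDocs = numDocsInLeaf % numSlices;  // 0
        System.out.println(numSlices + " slices of ~" + docPerSlice + " docs");
    }
}

Because each large-segment slice holds exactly one partial leaf and slice positions are assigned consecutively, one driver can sweep the segment front to back via the affinity path while idle drivers steal from its tail.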
