Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/132774.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 132774
summary: Improve cpu utilization with dynamic slice size in doc partitioning
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import org.elasticsearch.compute.operator.Driver;

import java.util.List;

/**
* How we partition the data across {@link Driver}s. Each request forks into
* {@code min(1.5 * cpus, partition_count)} threads on the data node. More partitions
Expand Down Expand Up @@ -37,9 +39,20 @@ public enum DataPartitioning {
*/
SEGMENT,
/**
* Partition each shard into {@code task_concurrency} partitions, splitting
* larger segments into slices. This allows bringing the most CPUs to bear on
* the problem but adds extra overhead, especially in query preparation.
* Partitions into dynamic-sized slices to improve CPU utilization while keeping overhead low.
* This approach is more flexible than {@link #SEGMENT} and works as follows:
*
* <ol>
* <li>The slice size starts from a desired size based on {@code task_concurrency} but is capped
* at around {@link LuceneSliceQueue#MAX_DOCS_PER_SLICE}. This prevents poor CPU usage when
* matching documents are clustered together.</li>
* <li>For small and medium segments (less than five times the desired slice size), it uses a
* slightly different {@link #SEGMENT} strategy, which also splits segments that are larger
* than the desired size. See {@link org.apache.lucene.search.IndexSearcher#slices(List, int, int, boolean)}.</li>
* <li>For very large segments, multiple segments are not combined into a single slice. This allows
* one driver to process an entire large segment until other drivers steal the work after finishing
* their own tasks. See {@link LuceneSliceQueue#nextSlice(LuceneSlice)}.</li>
* </ol>
*/
DOC,
DOC
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ LuceneScorer getCurrentOrLoadNextScorer() {
while (currentScorer == null || currentScorer.isDone()) {
if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
sliceIndex = 0;
currentSlice = sliceQueue.nextSlice();
currentSlice = sliceQueue.nextSlice(currentSlice);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

if (currentSlice == null) {
doneCollecting = true;
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
/**
* Holds a list of multiple partial Lucene segments
*/
public record LuceneSlice(ShardContext shardContext, List<PartialLeafReaderContext> leaves, Weight weight, List<Object> tags) {
public record LuceneSlice(
int slicePosition,
ShardContext shardContext,
List<PartialLeafReaderContext> leaves,
Weight weight,
List<Object> tags
) {
int numLeaves() {
return leaves.size();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.core.Nullable;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -77,18 +80,78 @@ public record QueryAndTags(Query query, List<Object> tags) {}
public static final int MAX_SEGMENTS_PER_SLICE = 5; // copied from IndexSearcher

private final int totalSlices;
private final Queue<LuceneSlice> slices;
private final Map<String, PartitioningStrategy> partitioningStrategies;

private LuceneSliceQueue(List<LuceneSlice> slices, Map<String, PartitioningStrategy> partitioningStrategies) {
this.totalSlices = slices.size();
this.slices = new ConcurrentLinkedQueue<>(slices);
private final AtomicReferenceArray<LuceneSlice> slices;
/**
* Queue of slice IDs that are the primary entry point for a new group of segments.
* A driver should prioritize polling from this queue after failing to get a sequential
* slice (the segment affinity). This ensures that threads start work on fresh,
* independent segment groups before resorting to work stealing.
*/
private final Queue<Integer> sliceHeads;

/**
* Queue of slice IDs that are not the primary entry point for a segment group.
* This queue serves as a fallback pool for work stealing. When a thread has no more independent work,
* it will "steal" a slice from this queue to keep itself utilized. A driver should pull tasks from
* this queue only when {@code sliceHeads} has been exhausted.
*/
private final Queue<Integer> stealableSlices;

LuceneSliceQueue(List<LuceneSlice> sliceList, Map<String, PartitioningStrategy> partitioningStrategies) {
this.totalSlices = sliceList.size();
this.slices = new AtomicReferenceArray<>(sliceList.size());
for (int i = 0; i < sliceList.size(); i++) {
slices.set(i, sliceList.get(i));
}
this.partitioningStrategies = partitioningStrategies;
this.sliceHeads = ConcurrentCollections.newQueue();
this.stealableSlices = ConcurrentCollections.newQueue();
for (LuceneSlice slice : sliceList) {
if (slice.getLeaf(0).minDoc() == 0) {
sliceHeads.add(slice.slicePosition());
} else {
stealableSlices.add(slice.slicePosition());
}
}
}

/**
* Retrieves the next available {@link LuceneSlice} for processing.
* <p>
* This method implements a three-tiered strategy to minimize the overhead of switching between segments:
* 1. If a previous slice is provided, it first attempts to return the next sequential slice.
* This keeps a thread working on the same segments, minimizing the overhead of segment switching.
* 2. If affinity fails, it returns a slice from the {@link #sliceHeads} queue, which is an entry point for
* a new, independent group of segments, allowing the calling Driver to work on a fresh set of segments.
* 3. If the {@link #sliceHeads} queue is exhausted, it "steals" a slice
* from the {@link #stealableSlices} queue. This fallback ensures all threads remain utilized.
*
* @param prev the previously returned {@link LuceneSlice}, or {@code null} if starting
* @return the next available {@link LuceneSlice}, or {@code null} if exhausted
*/
@Nullable
public LuceneSlice nextSlice() {
return slices.poll();
public LuceneSlice nextSlice(LuceneSlice prev) {
if (prev != null) {
final int nextId = prev.slicePosition() + 1;
if (nextId < totalSlices) {
var slice = slices.getAndSet(nextId, null);
if (slice != null) {
return slice;
}
}
}
for (var ids : List.of(sliceHeads, stealableSlices)) {
Integer nextId;
while ((nextId = ids.poll()) != null) {
var slice = slices.getAndSet(nextId, null);
if (slice != null) {
return slice;
}
}
}
return null;
}

public int totalSlices() {
Expand All @@ -103,7 +166,14 @@ public Map<String, PartitioningStrategy> partitioningStrategies() {
}

public Collection<String> remainingShardsIdentifiers() {
return slices.stream().map(slice -> slice.shardContext().shardIdentifier()).toList();
List<String> remaining = new ArrayList<>(slices.length());
for (int i = 0; i < slices.length(); i++) {
LuceneSlice slice = slices.get(i);
if (slice != null) {
remaining.add(slice.shardContext().shardIdentifier());
}
}
return remaining;
}

public static LuceneSliceQueue create(
Expand All @@ -117,6 +187,7 @@ public static LuceneSliceQueue create(
List<LuceneSlice> slices = new ArrayList<>();
Map<String, PartitioningStrategy> partitioningStrategies = new HashMap<>(contexts.size());

int nextSliceId = 0;
for (ShardContext ctx : contexts) {
for (QueryAndTags queryAndExtra : queryFunction.apply(ctx)) {
var scoreMode = scoreModeFunction.apply(ctx);
Expand All @@ -140,7 +211,7 @@ public static LuceneSliceQueue create(
Weight weight = weight(ctx, query, scoreMode);
for (List<PartialLeafReaderContext> group : groups) {
if (group.isEmpty() == false) {
slices.add(new LuceneSlice(ctx, group, weight, queryAndExtra.tags));
slices.add(new LuceneSlice(nextSliceId++, ctx, group, weight, queryAndExtra.tags));
}
}
}
Expand All @@ -158,7 +229,7 @@ public enum PartitioningStrategy implements Writeable {
*/
SHARD(0) {
@Override
List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
return List.of(searcher.getLeafContexts().stream().map(PartialLeafReaderContext::new).toList());
}
},
Expand All @@ -167,7 +238,7 @@ List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requeste
*/
SEGMENT(1) {
@Override
List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
IndexSearcher.LeafSlice[] gs = IndexSearcher.slices(
searcher.getLeafContexts(),
MAX_DOCS_PER_SLICE,
Expand All @@ -182,52 +253,11 @@ List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requeste
*/
DOC(2) {
@Override
List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
final int totalDocCount = searcher.getIndexReader().maxDoc();
final int normalMaxDocsPerSlice = totalDocCount / requestedNumSlices;
final int extraDocsInFirstSlice = totalDocCount % requestedNumSlices;
final List<List<PartialLeafReaderContext>> slices = new ArrayList<>();
int docsAllocatedInCurrentSlice = 0;
List<PartialLeafReaderContext> currentSlice = null;
int maxDocsPerSlice = normalMaxDocsPerSlice + extraDocsInFirstSlice;
for (LeafReaderContext ctx : searcher.getLeafContexts()) {
final int numDocsInLeaf = ctx.reader().maxDoc();
int minDoc = 0;
while (minDoc < numDocsInLeaf) {
int numDocsToUse = Math.min(maxDocsPerSlice - docsAllocatedInCurrentSlice, numDocsInLeaf - minDoc);
if (numDocsToUse <= 0) {
break;
}
if (currentSlice == null) {
currentSlice = new ArrayList<>();
}
currentSlice.add(new PartialLeafReaderContext(ctx, minDoc, minDoc + numDocsToUse));
minDoc += numDocsToUse;
docsAllocatedInCurrentSlice += numDocsToUse;
if (docsAllocatedInCurrentSlice == maxDocsPerSlice) {
slices.add(currentSlice);
// once the first slice with the extra docs is added, no need for extra docs
maxDocsPerSlice = normalMaxDocsPerSlice;
currentSlice = null;
docsAllocatedInCurrentSlice = 0;
}
}
}
if (currentSlice != null) {
slices.add(currentSlice);
}
if (requestedNumSlices < totalDocCount && slices.size() != requestedNumSlices) {
throw new IllegalStateException("wrong number of slices, expected " + requestedNumSlices + " but got " + slices.size());
}
if (slices.stream()
.flatMapToInt(
l -> l.stream()
.mapToInt(partialLeafReaderContext -> partialLeafReaderContext.maxDoc() - partialLeafReaderContext.minDoc())
)
.sum() != totalDocCount) {
throw new IllegalStateException("wrong doc count");
}
return slices;
// Cap the desired slice to prevent CPU underutilization when matching documents are concentrated in one segment region.
int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, taskConcurrency), 1, MAX_DOCS_PER_SLICE);
return new AdaptivePartitioner(Math.max(1, desiredSliceSize), MAX_SEGMENTS_PER_SLICE).partition(searcher.getLeafContexts());
}
};

Expand All @@ -252,7 +282,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeByte(id);
}

abstract List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices);
abstract List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency);

private static PartitioningStrategy pick(
DataPartitioning dataPartitioning,
Expand Down Expand Up @@ -291,4 +321,67 @@ static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
throw new UncheckedIOException(e);
}
}

static final class AdaptivePartitioner {
final int desiredDocsPerSlice;
final int maxDocsPerSlice;
final int maxSegmentsPerSlice;

AdaptivePartitioner(int desiredDocsPerSlice, int maxSegmentsPerSlice) {
this.desiredDocsPerSlice = desiredDocsPerSlice;
this.maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;
this.maxSegmentsPerSlice = maxSegmentsPerSlice;
}

List<List<PartialLeafReaderContext>> partition(List<LeafReaderContext> leaves) {
List<LeafReaderContext> smallSegments = new ArrayList<>();
List<LeafReaderContext> largeSegments = new ArrayList<>();
List<List<PartialLeafReaderContext>> results = new ArrayList<>();
for (LeafReaderContext leaf : leaves) {
if (leaf.reader().maxDoc() >= 5 * desiredDocsPerSlice) {
largeSegments.add(leaf);
} else {
smallSegments.add(leaf);
}
}
largeSegments.sort(Collections.reverseOrder(Comparator.comparingInt(l -> l.reader().maxDoc())));
for (LeafReaderContext segment : largeSegments) {
results.addAll(partitionOneLargeSegment(segment));
}
results.addAll(partitionSmallSegments(smallSegments));
return results;
}

List<List<PartialLeafReaderContext>> partitionOneLargeSegment(LeafReaderContext leaf) {
int numDocsInLeaf = leaf.reader().maxDoc();
int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice);
while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
numSlices++;
}
int docPerSlice = numDocsInLeaf / numSlices;
int leftoverDocs = numDocsInLeaf % numSlices;
int minDoc = 0;
List<List<PartialLeafReaderContext>> results = new ArrayList<>();
while (minDoc < numDocsInLeaf) {
int docsToUse = docPerSlice;
if (leftoverDocs > 0) {
--leftoverDocs;
docsToUse++;
}
int maxDoc = Math.min(minDoc + docsToUse, numDocsInLeaf);
results.add(List.of(new PartialLeafReaderContext(leaf, minDoc, maxDoc)));
minDoc = maxDoc;
}
assert leftoverDocs == 0 : leftoverDocs;
assert results.stream().allMatch(s -> s.size() == 1) : "must have one partial leaf per slice";
assert results.stream().flatMapToInt(ss -> ss.stream().mapToInt(s -> s.maxDoc() - s.minDoc())).sum() == numDocsInLeaf;
return results;
}

List<List<PartialLeafReaderContext>> partitionSmallSegments(List<LeafReaderContext> leaves) {
var slices = IndexSearcher.slices(leaves, maxDocsPerSlice, maxSegmentsPerSlice, true);
return Arrays.stream(slices).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public Page getCheckedOutput() throws IOException {
long startInNanos = System.nanoTime();
try {
if (iterator == null) {
var slice = sliceQueue.nextSlice();
var slice = sliceQueue.nextSlice(null);
if (slice == null) {
doneCollecting = true;
return null;
Expand Down
Loading