@@ -9,6 +9,8 @@

import org.elasticsearch.compute.operator.Driver;

import java.util.List;

/**
* How we partition the data across {@link Driver}s. Each request forks into
* {@code min(1.5 * cpus, partition_count)} threads on the data node. More partitions
@@ -37,9 +39,20 @@ public enum DataPartitioning {
*/
SEGMENT,
/**
* Partition each shard into {@code task_concurrency} partitions, splitting
* larger segments into slices. This allows bringing the most CPUs to bear on
* the problem but adds extra overhead, especially in query preparation.
* Partitions each shard into dynamically sized slices to improve CPU utilization while keeping overhead low.
* This approach is more flexible than {@link #SEGMENT} and works as follows:
*
* <p>1. The slice size starts from a desired size based on {@code task_concurrency} but is capped
Member commented: <ol>
* at around {@link LuceneSliceQueue#MAX_DOCS_PER_SLICE}. This prevents poor CPU usage when
* matching documents are clustered together.
*
* <p>2. For small and medium segments (smaller than five times the desired slice size), it uses a variant
* of the {@link #SEGMENT} strategy that also splits segments larger than the desired slice size.
* See {@link org.apache.lucene.search.IndexSearcher#slices(List, int, int, boolean)}.
*
* <p>3. Very large segments are never combined with other segments; each slice covers part of exactly one
* such segment. This lets one driver process an entire large segment on its own, unless other drivers
* finish their own work first and steal some of its slices. See {@link LuceneSliceQueue#nextSlice(LuceneSlice)}.
*/
DOC,
DOC
}
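To make step 1 concrete, here is a minimal, self-contained sketch (not part of this change) that runs the same clamp arithmetic on made-up numbers. The 250_000 value stands in for LuceneSliceQueue.MAX_DOCS_PER_SLICE and is an assumption for illustration only.

class DesiredSliceSizeExample {
    // Assumed stand-in for LuceneSliceQueue.MAX_DOCS_PER_SLICE, for illustration only.
    static final int MAX_DOCS_PER_SLICE = 250_000;

    public static void main(String[] args) {
        int totalDocCount = 10_000_000; // docs in the shard (made up)
        int taskConcurrency = 8;        // drivers available on the data node (made up)
        // Start from an even split across the drivers, then cap the slice size.
        int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, taskConcurrency), 1, MAX_DOCS_PER_SLICE);
        System.out.println(desiredSliceSize);                  // 250000, not 1250000
        System.out.println(totalDocCount / desiredSliceSize);  // 40 slices instead of 8
    }
}

Capping the slice size is what creates enough slices for idle drivers to steal when the matching documents all live in one region of the shard.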
@@ -165,7 +165,7 @@ LuceneScorer getCurrentOrLoadNextScorer() {
while (currentScorer == null || currentScorer.isDone()) {
if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
sliceIndex = 0;
currentSlice = sliceQueue.nextSlice();
currentSlice = sliceQueue.nextSlice(currentSlice);
Member commented: 👍
if (currentSlice == null) {
doneCollecting = true;
return null;
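For context on the new nextSlice(prev) contract used above, here is a hedged sketch of how a driver might drain the queue; it is illustrative only, not the operator's actual loop, and the processing step is left as a placeholder.

// Sketch: draining a LuceneSliceQueue while keeping segment affinity.
static void drain(LuceneSliceQueue sliceQueue) {
    LuceneSlice slice = sliceQueue.nextSlice(null); // no previous slice when starting out
    while (slice != null) {
        // ... build a scorer and process every PartialLeafReaderContext in slice.leaves() ...
        // Passing the slice we just finished lets the queue hand back the next sequential
        // position of the same group first; only if that is already taken does it fall back
        // to the startedPositions queue and then steal from followedPositions.
        slice = sliceQueue.nextSlice(slice);
    }
}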
@@ -14,7 +14,13 @@
/**
* Holds a list of multiple partial Lucene segments
*/
public record LuceneSlice(ShardContext shardContext, List<PartialLeafReaderContext> leaves, Weight weight, List<Object> tags) {
public record LuceneSlice(
int slicePosition,
ShardContext shardContext,
List<PartialLeafReaderContext> leaves,
Weight weight,
List<Object> tags
) {
int numLeaves() {
return leaves.size();
}
@@ -16,18 +16,21 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.core.Nullable;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.Function;

/**
@@ -77,18 +80,62 @@ public record QueryAndTags(Query query, List<Object> tags) {}
public static final int MAX_SEGMENTS_PER_SLICE = 5; // copied from IndexSearcher

private final int totalSlices;
private final Queue<LuceneSlice> slices;
private final AtomicReferenceArray<LuceneSlice> slices;
private final Queue<Integer> startedPositions;
private final Queue<Integer> followedPositions;
private final Map<String, PartitioningStrategy> partitioningStrategies;

private LuceneSliceQueue(List<LuceneSlice> slices, Map<String, PartitioningStrategy> partitioningStrategies) {
this.totalSlices = slices.size();
this.slices = new ConcurrentLinkedQueue<>(slices);
LuceneSliceQueue(List<LuceneSlice> sliceList, Map<String, PartitioningStrategy> partitioningStrategies) {
this.totalSlices = sliceList.size();
this.slices = new AtomicReferenceArray<>(sliceList.size());
for (int i = 0; i < sliceList.size(); i++) {
slices.set(i, sliceList.get(i));
}
this.partitioningStrategies = partitioningStrategies;
this.startedPositions = ConcurrentCollections.newQueue();
this.followedPositions = ConcurrentCollections.newQueue();
for (LuceneSlice slice : sliceList) {
if (slice.getLeaf(0).minDoc() == 0) {
startedPositions.add(slice.slicePosition());
} else {
followedPositions.add(slice.slicePosition());
}
}
}

/**
* Retrieves the next available {@link LuceneSlice} for processing.
* If a previous slice is provided, this method first attempts to return the next sequential slice to maintain segment affinity
* and minimize the cost of switching between segments.
* <p>
* If no sequential slice is available, it returns the next slice from the {@code startedPositions} queue, which starts a new
* group of segments. If all started positions are exhausted, it steals a slice from the {@code followedPositions} queue,
* enabling work stealing.
*
* @param prev the previously returned {@link LuceneSlice}, or {@code null} if starting
* @return the next available {@link LuceneSlice}, or {@code null} if exhausted
*/
@Nullable
public LuceneSlice nextSlice() {
return slices.poll();
public LuceneSlice nextSlice(LuceneSlice prev) {
if (prev != null) {
final int nextId = prev.slicePosition() + 1;
if (nextId < totalSlices) {
var slice = slices.getAndSet(nextId, null);
if (slice != null) {
return slice;
}
}
}
for (var ids : List.of(startedPositions, followedPositions)) {
Integer nextId;
while ((nextId = ids.poll()) != null) {
var slice = slices.getAndSet(nextId, null);
if (slice != null) {
return slice;
}
}
}
return null;
}

public int totalSlices() {
@@ -103,7 +150,14 @@ public Map<String, PartitioningStrategy> partitioningStrategies() {
}

public Collection<String> remainingShardsIdentifiers() {
return slices.stream().map(slice -> slice.shardContext().shardIdentifier()).toList();
List<String> remaining = new ArrayList<>(slices.length());
for (int i = 0; i < slices.length(); i++) {
LuceneSlice slice = slices.get(i);
if (slice != null) {
remaining.add(slice.shardContext().shardIdentifier());
}
}
return remaining;
}

public static LuceneSliceQueue create(
@@ -117,6 +171,7 @@ public static LuceneSliceQueue create(
List<LuceneSlice> slices = new ArrayList<>();
Map<String, PartitioningStrategy> partitioningStrategies = new HashMap<>(contexts.size());

int nextSliceId = 0;
for (ShardContext ctx : contexts) {
for (QueryAndTags queryAndExtra : queryFunction.apply(ctx)) {
var scoreMode = scoreModeFunction.apply(ctx);
@@ -140,7 +195,7 @@ public static LuceneSliceQueue create(
Weight weight = weight(ctx, query, scoreMode);
for (List<PartialLeafReaderContext> group : groups) {
if (group.isEmpty() == false) {
slices.add(new LuceneSlice(ctx, group, weight, queryAndExtra.tags));
slices.add(new LuceneSlice(nextSliceId++, ctx, group, weight, queryAndExtra.tags));
}
}
}
@@ -184,50 +239,9 @@ List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
@Override
List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
final int totalDocCount = searcher.getIndexReader().maxDoc();
final int normalMaxDocsPerSlice = totalDocCount / requestedNumSlices;
final int extraDocsInFirstSlice = totalDocCount % requestedNumSlices;
final List<List<PartialLeafReaderContext>> slices = new ArrayList<>();
int docsAllocatedInCurrentSlice = 0;
List<PartialLeafReaderContext> currentSlice = null;
int maxDocsPerSlice = normalMaxDocsPerSlice + extraDocsInFirstSlice;
for (LeafReaderContext ctx : searcher.getLeafContexts()) {
final int numDocsInLeaf = ctx.reader().maxDoc();
int minDoc = 0;
while (minDoc < numDocsInLeaf) {
int numDocsToUse = Math.min(maxDocsPerSlice - docsAllocatedInCurrentSlice, numDocsInLeaf - minDoc);
if (numDocsToUse <= 0) {
break;
}
if (currentSlice == null) {
currentSlice = new ArrayList<>();
}
currentSlice.add(new PartialLeafReaderContext(ctx, minDoc, minDoc + numDocsToUse));
minDoc += numDocsToUse;
docsAllocatedInCurrentSlice += numDocsToUse;
if (docsAllocatedInCurrentSlice == maxDocsPerSlice) {
slices.add(currentSlice);
// once the first slice with the extra docs is added, no need for extra docs
maxDocsPerSlice = normalMaxDocsPerSlice;
currentSlice = null;
docsAllocatedInCurrentSlice = 0;
}
}
}
if (currentSlice != null) {
slices.add(currentSlice);
}
if (requestedNumSlices < totalDocCount && slices.size() != requestedNumSlices) {
throw new IllegalStateException("wrong number of slices, expected " + requestedNumSlices + " but got " + slices.size());
}
if (slices.stream()
.flatMapToInt(
l -> l.stream()
.mapToInt(partialLeafReaderContext -> partialLeafReaderContext.maxDoc() - partialLeafReaderContext.minDoc())
)
.sum() != totalDocCount) {
throw new IllegalStateException("wrong doc count");
}
return slices;
// Cap the desired slice size to prevent CPU underutilization when matching documents are concentrated in one region of a segment.
int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, requestedNumSlices), 1, MAX_DOCS_PER_SLICE);
Member commented: I think we shouldn't call this requestedNumSlices any more - it's just taskConcurrency here. At least, we're not respecting the request for the number of slices - we absolutely got above it via MAX_DOCS_PER_SLICE.

Member Author replied: ++ updated in 50ebdd3
return new AdaptivePartitioner(Math.max(1, desiredSliceSize), MAX_SEGMENTS_PER_SLICE).partition(searcher.getLeafContexts());
}
};

@@ -291,4 +305,67 @@ static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
throw new UncheckedIOException(e);
}
}

static final class AdaptivePartitioner {
final int desiredDocsPerSlice;
final int maxDocsPerSlice;
final int maxSegmentsPerSlice;

AdaptivePartitioner(int desiredDocsPerSlice, int maxSegmentsPerSlice) {
this.desiredDocsPerSlice = desiredDocsPerSlice;
this.maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;
this.maxSegmentsPerSlice = maxSegmentsPerSlice;
}

List<List<PartialLeafReaderContext>> partition(List<LeafReaderContext> leaves) {
List<LeafReaderContext> smallSegments = new ArrayList<>();
List<LeafReaderContext> largeSegments = new ArrayList<>();
List<List<PartialLeafReaderContext>> results = new ArrayList<>();
for (LeafReaderContext leaf : leaves) {
if (leaf.reader().maxDoc() >= 5 * desiredDocsPerSlice) {
largeSegments.add(leaf);
} else {
smallSegments.add(leaf);
}
}
largeSegments.sort(Collections.reverseOrder(Comparator.comparingInt(l -> l.reader().maxDoc())));
for (LeafReaderContext segment : largeSegments) {
results.addAll(partitionOneLargeSegment(segment));
}
results.addAll(partitionSmallSegments(smallSegments));
return results;
}

List<List<PartialLeafReaderContext>> partitionOneLargeSegment(LeafReaderContext leaf) {
int numDocsInLeaf = leaf.reader().maxDoc();
int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice);
while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
numSlices++;
}
int docPerSlice = numDocsInLeaf / numSlices;
int leftoverDocs = numDocsInLeaf % numSlices;
int minDoc = 0;
List<List<PartialLeafReaderContext>> results = new ArrayList<>();
while (minDoc < numDocsInLeaf) {
int docsToUse = docPerSlice;
if (leftoverDocs > 0) {
--leftoverDocs;
docsToUse++;
}
int maxDoc = Math.min(minDoc + docsToUse, numDocsInLeaf);
results.add(List.of(new PartialLeafReaderContext(leaf, minDoc, maxDoc)));
minDoc = maxDoc;
}
assert leftoverDocs == 0 : leftoverDocs;
assert results.stream().allMatch(s -> s.size() == 1) : "must have one partial leaf per slice";
assert results.stream().flatMapToInt(ss -> ss.stream().mapToInt(s -> s.maxDoc() - s.minDoc())).sum() == numDocsInLeaf;
return results;
}

List<List<PartialLeafReaderContext>> partitionSmallSegments(List<LeafReaderContext> leaves) {
var slices = IndexSearcher.slices(leaves, maxDocsPerSlice, maxSegmentsPerSlice, true);
return Arrays.stream(slices).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
}
}

}
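As a rough worked example of partitionOneLargeSegment's arithmetic (all numbers made up, not from this change), the sketch below mirrors the code above with a 1,050,000-doc segment and a desired slice size of 100,000.

class AdaptivePartitionerArithmetic {
    public static void main(String[] args) {
        int desiredDocsPerSlice = 100_000;                  // assumed for illustration
        int maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;  // 125_000
        int numDocsInLeaf = 1_050_000;                      // "large": at least 5 * desiredDocsPerSlice
        int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice); // 10
        while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
            numSlices++; // not triggered here: ceilDiv(1_050_000, 10) = 105_000 <= 125_000
        }
        // The segment becomes 10 single-leaf slices of 105_000 docs each, all over the same
        // segment, so one driver can walk them in order while idle drivers steal what is left.
        System.out.println(numSlices + " slices of " + (numDocsInLeaf / numSlices) + " docs");
    }
}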
@@ -97,7 +97,7 @@ public Page getCheckedOutput() throws IOException {
long startInNanos = System.nanoTime();
try {
if (iterator == null) {
var slice = sliceQueue.nextSlice();
var slice = sliceQueue.nextSlice(null);
if (slice == null) {
doneCollecting = true;
return null;