Improve cpu utilization with dynamic slice size in doc partitioning #132774
@@ -165,7 +165,7 @@ LuceneScorer getCurrentOrLoadNextScorer() {
         while (currentScorer == null || currentScorer.isDone()) {
             if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
                 sliceIndex = 0;
-                currentSlice = sliceQueue.nextSlice();
+                currentSlice = sliceQueue.nextSlice(currentSlice);
                 if (currentSlice == null) {
                     doneCollecting = true;
                     return null;
@@ -16,18 +16,21 @@
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.core.Nullable;

 import java.io.IOException;
 import java.io.UncheckedIOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
-import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.Function;

 /**
@@ -77,18 +80,62 @@ public record QueryAndTags(Query query, List<Object> tags) {}
     public static final int MAX_SEGMENTS_PER_SLICE = 5; // copied from IndexSearcher

     private final int totalSlices;
-    private final Queue<LuceneSlice> slices;
+    private final AtomicReferenceArray<LuceneSlice> slices;
+    private final Queue<Integer> startedPositions;
+    private final Queue<Integer> followedPositions;
     private final Map<String, PartitioningStrategy> partitioningStrategies;

-    private LuceneSliceQueue(List<LuceneSlice> slices, Map<String, PartitioningStrategy> partitioningStrategies) {
-        this.totalSlices = slices.size();
-        this.slices = new ConcurrentLinkedQueue<>(slices);
+    LuceneSliceQueue(List<LuceneSlice> sliceList, Map<String, PartitioningStrategy> partitioningStrategies) {
+        this.totalSlices = sliceList.size();
+        this.slices = new AtomicReferenceArray<>(sliceList.size());
+        for (int i = 0; i < sliceList.size(); i++) {
+            slices.set(i, sliceList.get(i));
+        }
         this.partitioningStrategies = partitioningStrategies;
+        this.startedPositions = ConcurrentCollections.newQueue();
+        this.followedPositions = ConcurrentCollections.newQueue();
+        for (LuceneSlice slice : sliceList) {
+            if (slice.getLeaf(0).minDoc() == 0) {
+                startedPositions.add(slice.slicePosition());
+            } else {
+                followedPositions.add(slice.slicePosition());
+            }
+        }
     }

+    /**
+     * Retrieves the next available {@link LuceneSlice} for processing.
+     * If a previous slice is provided, this method first attempts to return the next sequential slice to maintain segment affinity
+     * and minimize the cost of switching between segments.
+     * <p>
+     * If no sequential slice is available, it returns the next slice from the {@code startedPositions} queue, which starts a new
+     * group of segments. If all started positions are exhausted, it steals a slice from the {@code followedPositions} queue,
+     * enabling work stealing.
+     *
+     * @param prev the previously returned {@link LuceneSlice}, or {@code null} if starting
+     * @return the next available {@link LuceneSlice}, or {@code null} if exhausted
+     */
     @Nullable
-    public LuceneSlice nextSlice() {
-        return slices.poll();
+    public LuceneSlice nextSlice(LuceneSlice prev) {
+        if (prev != null) {
+            final int nextId = prev.slicePosition() + 1;
+            if (nextId < totalSlices) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        for (var ids : List.of(startedPositions, followedPositions)) {
+            Integer nextId;
+            while ((nextId = ids.poll()) != null) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        return null;
     }

     public int totalSlices() {
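The acquisition order is easier to see outside the surrounding machinery. Below is a minimal, hypothetical model of the same three-tier algorithm — `MiniSliceQueue`, plain `Integer` slices, and `ConcurrentLinkedQueue` standing in for `LuceneSliceQueue`, `LuceneSlice`, and `ConcurrentCollections.newQueue()` respectively: a driver first tries the sequential successor of the slice it just finished (segment affinity), then a slice that starts a new segment group, and finally steals a "followed" slice.

```java
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicReferenceArray;

public class MiniSliceQueueDemo {
    // Hypothetical stand-in for LuceneSliceQueue: a slice is just its position.
    static final class MiniSliceQueue {
        private final AtomicReferenceArray<Integer> slices;
        private final Queue<Integer> started = new ConcurrentLinkedQueue<>();
        private final Queue<Integer> followed = new ConcurrentLinkedQueue<>();

        MiniSliceQueue(int totalSlices, List<Integer> groupStarts) {
            slices = new AtomicReferenceArray<>(totalSlices);
            for (int i = 0; i < totalSlices; i++) {
                slices.set(i, i);
                if (groupStarts.contains(i)) {
                    started.add(i);  // slice whose first leaf begins at minDoc == 0
                } else {
                    followed.add(i); // continuation of a segment another driver may be reading
                }
            }
        }

        Integer nextSlice(Integer prev) {
            // 1. Prefer the sequential successor, keeping the driver on the same segment.
            if (prev != null && prev + 1 < slices.length()) {
                Integer slice = slices.getAndSet(prev + 1, null);
                if (slice != null) {
                    return slice;
                }
            }
            // 2. Start a fresh segment group; 3. failing that, steal a follower.
            for (Queue<Integer> ids : List.of(started, followed)) {
                Integer id;
                while ((id = ids.poll()) != null) {
                    Integer slice = slices.getAndSet(id, null);
                    if (slice != null) {
                        return slice; // not yet claimed by a sequential reader
                    }
                }
            }
            return null; // queue exhausted
        }
    }

    public static void main(String[] args) {
        // Five slices; slices 0 and 3 start new segment groups.
        MiniSliceQueue queue = new MiniSliceQueue(5, List.of(0, 3));
        Integer slice = null;
        while ((slice = queue.nextSlice(slice)) != null) {
            System.out.println("claimed slice " + slice); // prints 0, 1, 2, 3, 4 in order
        }
    }
}
```

The `getAndSet` acts as an atomic claim: a slice handed out through the sequential path can never also be handed out through the queues. The queues may still hold its position, but the null check skips it.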
@@ -103,7 +150,14 @@ public Map<String, PartitioningStrategy> partitioningStrategies() {
     }

     public Collection<String> remainingShardsIdentifiers() {
-        return slices.stream().map(slice -> slice.shardContext().shardIdentifier()).toList();
+        List<String> remaining = new ArrayList<>(slices.length());
+        for (int i = 0; i < slices.length(); i++) {
+            LuceneSlice slice = slices.get(i);
+            if (slice != null) {
+                remaining.add(slice.shardContext().shardIdentifier());
+            }
+        }
+        return remaining;
     }

     public static LuceneSliceQueue create(
@@ -117,6 +171,7 @@ public static LuceneSliceQueue create(
         List<LuceneSlice> slices = new ArrayList<>();
         Map<String, PartitioningStrategy> partitioningStrategies = new HashMap<>(contexts.size());

+        int nextSliceId = 0;
         for (ShardContext ctx : contexts) {
             for (QueryAndTags queryAndExtra : queryFunction.apply(ctx)) {
                 var scoreMode = scoreModeFunction.apply(ctx);
@@ -140,7 +195,7 @@ public static LuceneSliceQueue create(
                 Weight weight = weight(ctx, query, scoreMode);
                 for (List<PartialLeafReaderContext> group : groups) {
                     if (group.isEmpty() == false) {
-                        slices.add(new LuceneSlice(ctx, group, weight, queryAndExtra.tags));
+                        slices.add(new LuceneSlice(nextSliceId++, ctx, group, weight, queryAndExtra.tags));
                     }
                 }
             }
@@ -184,50 +239,9 @@ List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
         @Override
         List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
             final int totalDocCount = searcher.getIndexReader().maxDoc();
-            final int normalMaxDocsPerSlice = totalDocCount / requestedNumSlices;
-            final int extraDocsInFirstSlice = totalDocCount % requestedNumSlices;
-            final List<List<PartialLeafReaderContext>> slices = new ArrayList<>();
-            int docsAllocatedInCurrentSlice = 0;
-            List<PartialLeafReaderContext> currentSlice = null;
-            int maxDocsPerSlice = normalMaxDocsPerSlice + extraDocsInFirstSlice;
-            for (LeafReaderContext ctx : searcher.getLeafContexts()) {
-                final int numDocsInLeaf = ctx.reader().maxDoc();
-                int minDoc = 0;
-                while (minDoc < numDocsInLeaf) {
-                    int numDocsToUse = Math.min(maxDocsPerSlice - docsAllocatedInCurrentSlice, numDocsInLeaf - minDoc);
-                    if (numDocsToUse <= 0) {
-                        break;
-                    }
-                    if (currentSlice == null) {
-                        currentSlice = new ArrayList<>();
-                    }
-                    currentSlice.add(new PartialLeafReaderContext(ctx, minDoc, minDoc + numDocsToUse));
-                    minDoc += numDocsToUse;
-                    docsAllocatedInCurrentSlice += numDocsToUse;
-                    if (docsAllocatedInCurrentSlice == maxDocsPerSlice) {
-                        slices.add(currentSlice);
-                        // once the first slice with the extra docs is added, no need for extra docs
-                        maxDocsPerSlice = normalMaxDocsPerSlice;
-                        currentSlice = null;
-                        docsAllocatedInCurrentSlice = 0;
-                    }
-                }
-            }
-            if (currentSlice != null) {
-                slices.add(currentSlice);
-            }
-            if (requestedNumSlices < totalDocCount && slices.size() != requestedNumSlices) {
-                throw new IllegalStateException("wrong number of slices, expected " + requestedNumSlices + " but got " + slices.size());
-            }
-            if (slices.stream()
-                .flatMapToInt(
-                    l -> l.stream()
-                        .mapToInt(partialLeafReaderContext -> partialLeafReaderContext.maxDoc() - partialLeafReaderContext.minDoc())
-                )
-                .sum() != totalDocCount) {
-                throw new IllegalStateException("wrong doc count");
-            }
-            return slices;
+            // Cap the desired slice to prevent CPU underutilization when matching documents are concentrated in one segment region.
+            int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, requestedNumSlices), 1, MAX_DOCS_PER_SLICE);
+
+            return new AdaptivePartitioner(Math.max(1, desiredSliceSize), MAX_SEGMENTS_PER_SLICE).partition(searcher.getLeafContexts());
         }
     };
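To make the new cap concrete, here is a hypothetical computation. The `SliceSizeCap` class and all numbers are illustrative, and `MAX_DOCS_PER_SLICE` is assumed to be 250,000 (its value is not visible in this diff; Lucene's `IndexSearcher` uses 250,000 for the constant this class says it copies from):

```java
public class SliceSizeCap {
    public static void main(String[] args) {
        int totalDocCount = 10_000_000;    // hypothetical shard size
        int requestedNumSlices = 8;        // e.g. one slice per driver
        int maxDocsPerSlice = 250_000;     // assumed value of MAX_DOCS_PER_SLICE

        // Old behavior: totalDocCount / requestedNumSlices, i.e. 8 coarse
        // slices of 1_250_000 docs each for the whole shard.
        int uncapped = Math.ceilDiv(totalDocCount, requestedNumSlices);

        // New behavior: the clamp caps slices at 250_000 docs, yielding ~40
        // finer slices, so idle drivers can pick up work even when matching
        // documents are concentrated in one doc range.
        int desiredSliceSize = Math.clamp(uncapped, 1, maxDocsPerSlice);
        System.out.println(uncapped + " -> " + desiredSliceSize); // 1250000 -> 250000
    }
}
```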
@@ -291,4 +305,67 @@ static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
             throw new UncheckedIOException(e);
         }
     }
+
+    static final class AdaptivePartitioner {
+        final int desiredDocsPerSlice;
+        final int maxDocsPerSlice;
+        final int maxSegmentsPerSlice;
+
+        AdaptivePartitioner(int desiredDocsPerSlice, int maxSegmentsPerSlice) {
+            this.desiredDocsPerSlice = desiredDocsPerSlice;
+            this.maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;
+            this.maxSegmentsPerSlice = maxSegmentsPerSlice;
+        }
+
+        List<List<PartialLeafReaderContext>> partition(List<LeafReaderContext> leaves) {
+            List<LeafReaderContext> smallSegments = new ArrayList<>();
+            List<LeafReaderContext> largeSegments = new ArrayList<>();
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            for (LeafReaderContext leaf : leaves) {
+                if (leaf.reader().maxDoc() >= 5 * desiredDocsPerSlice) {
+                    largeSegments.add(leaf);
+                } else {
+                    smallSegments.add(leaf);
+                }
+            }
+            largeSegments.sort(Collections.reverseOrder(Comparator.comparingInt(l -> l.reader().maxDoc())));
+            for (LeafReaderContext segment : largeSegments) {
+                results.addAll(partitionOneLargeSegment(segment));
+            }
+            results.addAll(partitionSmallSegments(smallSegments));
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionOneLargeSegment(LeafReaderContext leaf) {
+            int numDocsInLeaf = leaf.reader().maxDoc();
+            int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice);
+            while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
+                numSlices++;
+            }
+            int docPerSlice = numDocsInLeaf / numSlices;
+            int leftoverDocs = numDocsInLeaf % numSlices;
+            int minDoc = 0;
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            while (minDoc < numDocsInLeaf) {
+                int docsToUse = docPerSlice;
+                if (leftoverDocs > 0) {
+                    --leftoverDocs;
+                    docsToUse++;
+                }
+                int maxDoc = Math.min(minDoc + docsToUse, numDocsInLeaf);
+                results.add(List.of(new PartialLeafReaderContext(leaf, minDoc, maxDoc)));
+                minDoc = maxDoc;
+            }
+            assert leftoverDocs == 0 : leftoverDocs;
+            assert results.stream().allMatch(s -> s.size() == 1) : "must have one partial leaf per slice";
+            assert results.stream().flatMapToInt(ss -> ss.stream().mapToInt(s -> s.maxDoc() - s.minDoc())).sum() == numDocsInLeaf;
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionSmallSegments(List<LeafReaderContext> leaves) {
+            var slices = IndexSearcher.slices(leaves, maxDocsPerSlice, maxSegmentsPerSlice, true);
+            return Arrays.stream(slices).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
+        }
+    }
 }
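A worked example may help with the slice-count loop in `partitionOneLargeSegment`. The numbers below are hypothetical (`desiredDocsPerSlice = 1_000`, hence `maxDocsPerSlice = 1_250`), and the `AdaptiveSplitWalkthrough` class simply replays the method's arithmetic without the Lucene types:

```java
public class AdaptiveSplitWalkthrough {
    public static void main(String[] args) {
        int desiredDocsPerSlice = 1_000;                    // hypothetical
        int maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;  // 1_250, as in the constructor
        int numDocsInLeaf = 2_600;                          // one "large" segment

        // Start near the desired size, then add slices until none exceeds the max:
        // 2600 / 1000 = 2 slices of ceilDiv(2600, 2) = 1300 docs > 1250, so bump to 3.
        int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice);
        while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
            numSlices++;
        }
        int docPerSlice = numDocsInLeaf / numSlices;   // 866
        int leftoverDocs = numDocsInLeaf % numSlices;  // 2, spread over the first slices

        int minDoc = 0;
        while (minDoc < numDocsInLeaf) {
            int docsToUse = docPerSlice;
            if (leftoverDocs > 0) {
                leftoverDocs--;
                docsToUse++;
            }
            int maxDoc = Math.min(minDoc + docsToUse, numDocsInLeaf);
            System.out.println("[" + minDoc + ", " + maxDoc + ")"); // [0,867) [867,1734) [1734,2600)
            minDoc = maxDoc;
        }
    }
}
```

The 25% slack between the desired and maximum slice size (`desiredDocsPerSlice * 5 / 4`) presumably lets a segment that is just over a multiple of the desired size settle for slightly larger slices instead of producing an extra tiny one.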