 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.core.Nullable;

 import java.io.IOException;
 import java.io.UncheckedIOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
-import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.Function;

 /**
@@ -77,18 +80,62 @@ public record QueryAndTags(Query query, List<Object> tags) {}
     public static final int MAX_SEGMENTS_PER_SLICE = 5; // copied from IndexSearcher

     private final int totalSlices;
-    private final Queue<LuceneSlice> slices;
+    private final AtomicReferenceArray<LuceneSlice> slices;
+    private final Queue<Integer> startedPositions;
+    private final Queue<Integer> followedPositions;
     private final Map<String, PartitioningStrategy> partitioningStrategies;

-    private LuceneSliceQueue(List<LuceneSlice> slices, Map<String, PartitioningStrategy> partitioningStrategies) {
-        this.totalSlices = slices.size();
-        this.slices = new ConcurrentLinkedQueue<>(slices);
+    LuceneSliceQueue(List<LuceneSlice> sliceList, Map<String, PartitioningStrategy> partitioningStrategies) {
+        this.totalSlices = sliceList.size();
+        this.slices = new AtomicReferenceArray<>(sliceList.size());
+        for (int i = 0; i < sliceList.size(); i++) {
+            slices.set(i, sliceList.get(i));
+        }
         this.partitioningStrategies = partitioningStrategies;
+        this.startedPositions = ConcurrentCollections.newQueue();
+        this.followedPositions = ConcurrentCollections.newQueue();
+        for (LuceneSlice slice : sliceList) {
+            if (slice.getLeaf(0).minDoc() == 0) {
+                startedPositions.add(slice.slicePosition());
+            } else {
+                followedPositions.add(slice.slicePosition());
+            }
+        }
     }

+    /**
+     * Retrieves the next available {@link LuceneSlice} for processing.
+     * If a previous slice is provided, this method first attempts to return the next sequential slice to maintain segment affinity
+     * and minimize the cost of switching between segments.
+     * <p>
+     * If no sequential slice is available, it returns the next slice from the {@code startedPositions} queue, which starts a new
+     * group of segments. If all started positions are exhausted, it steals a slice from the {@code followedPositions} queue,
+     * enabling work stealing.
+     *
+     * @param prev the previously returned {@link LuceneSlice}, or {@code null} if starting
+     * @return the next available {@link LuceneSlice}, or {@code null} if exhausted
+     */
     @Nullable
-    public LuceneSlice nextSlice() {
-        return slices.poll();
+    public LuceneSlice nextSlice(LuceneSlice prev) {
+        if (prev != null) {
+            final int nextId = prev.slicePosition() + 1;
+            if (nextId < totalSlices) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        for (var ids : List.of(startedPositions, followedPositions)) {
+            Integer nextId;
+            while ((nextId = ids.poll()) != null) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        return null;
     }

     public int totalSlices() {
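
A minimal sketch of how a driver loop might consume the queue with the new signature, passing the previously
returned slice back in so that sequential partitions of the same segment stay on the same thread (sliceQueue and
processSlice are placeholder names used only for illustration, not identifiers introduced by this change):

    LuceneSlice slice = sliceQueue.nextSlice(null);   // first call: nothing to chain from yet
    while (slice != null) {
        processSlice(slice);                          // placeholder for the per-slice work
        slice = sliceQueue.nextSlice(slice);          // prefer the sequential follow-up, else start or steal
    }
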
@@ -103,7 +150,14 @@ public Map<String, PartitioningStrategy> partitioningStrategies() {
     }

     public Collection<String> remainingShardsIdentifiers() {
-        return slices.stream().map(slice -> slice.shardContext().shardIdentifier()).toList();
+        List<String> remaining = new ArrayList<>(slices.length());
+        for (int i = 0; i < slices.length(); i++) {
+            LuceneSlice slice = slices.get(i);
+            if (slice != null) {
+                remaining.add(slice.shardContext().shardIdentifier());
+            }
+        }
+        return remaining;
     }

     public static LuceneSliceQueue create(
@@ -117,6 +171,7 @@ public static LuceneSliceQueue create(
         List<LuceneSlice> slices = new ArrayList<>();
         Map<String, PartitioningStrategy> partitioningStrategies = new HashMap<>(contexts.size());

+        int nextSliceId = 0;
         for (ShardContext ctx : contexts) {
             for (QueryAndTags queryAndExtra : queryFunction.apply(ctx)) {
                 var scoreMode = scoreModeFunction.apply(ctx);
@@ -140,7 +195,7 @@ public static LuceneSliceQueue create(
                 Weight weight = weight(ctx, query, scoreMode);
                 for (List<PartialLeafReaderContext> group : groups) {
                     if (group.isEmpty() == false) {
-                        slices.add(new LuceneSlice(nextSliceId++, ctx, group, weight, queryAndExtra.tags));
+                        slices.add(new LuceneSlice(nextSliceId++, ctx, group, weight, queryAndExtra.tags));
                     }
                 }
             }
@@ -184,50 +239,9 @@ List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
             @Override
             List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
                 final int totalDocCount = searcher.getIndexReader().maxDoc();
-                final int normalMaxDocsPerSlice = totalDocCount / requestedNumSlices;
-                final int extraDocsInFirstSlice = totalDocCount % requestedNumSlices;
-                final List<List<PartialLeafReaderContext>> slices = new ArrayList<>();
-                int docsAllocatedInCurrentSlice = 0;
-                List<PartialLeafReaderContext> currentSlice = null;
-                int maxDocsPerSlice = normalMaxDocsPerSlice + extraDocsInFirstSlice;
-                for (LeafReaderContext ctx : searcher.getLeafContexts()) {
-                    final int numDocsInLeaf = ctx.reader().maxDoc();
-                    int minDoc = 0;
-                    while (minDoc < numDocsInLeaf) {
-                        int numDocsToUse = Math.min(maxDocsPerSlice - docsAllocatedInCurrentSlice, numDocsInLeaf - minDoc);
-                        if (numDocsToUse <= 0) {
-                            break;
-                        }
-                        if (currentSlice == null) {
-                            currentSlice = new ArrayList<>();
-                        }
-                        currentSlice.add(new PartialLeafReaderContext(ctx, minDoc, minDoc + numDocsToUse));
-                        minDoc += numDocsToUse;
-                        docsAllocatedInCurrentSlice += numDocsToUse;
-                        if (docsAllocatedInCurrentSlice == maxDocsPerSlice) {
-                            slices.add(currentSlice);
-                            // once the first slice with the extra docs is added, no need for extra docs
-                            maxDocsPerSlice = normalMaxDocsPerSlice;
-                            currentSlice = null;
-                            docsAllocatedInCurrentSlice = 0;
-                        }
-                    }
-                }
-                if (currentSlice != null) {
-                    slices.add(currentSlice);
-                }
-                if (requestedNumSlices < totalDocCount && slices.size() != requestedNumSlices) {
-                    throw new IllegalStateException("wrong number of slices, expected " + requestedNumSlices + " but got " + slices.size());
-                }
-                if (slices.stream()
-                    .flatMapToInt(
-                        l -> l.stream()
-                            .mapToInt(partialLeafReaderContext -> partialLeafReaderContext.maxDoc() - partialLeafReaderContext.minDoc())
-                    )
-                    .sum() != totalDocCount) {
-                    throw new IllegalStateException("wrong doc count");
-                }
-                return slices;
+                // Cap the desired slice size to prevent CPU underutilization when matching documents are concentrated in one segment region.
+                int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, requestedNumSlices), 1, MAX_DOCS_PER_SLICE);
+                return new AdaptivePartitioner(Math.max(1, desiredSliceSize), MAX_SEGMENTS_PER_SLICE).partition(searcher.getLeafContexts());
             }
         };

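
For illustration only, assuming MAX_DOCS_PER_SLICE is 250_000 (its value is outside this excerpt): with
totalDocCount = 1_200_000 and requestedNumSlices = 8, Math.ceilDiv gives 150_000, which already lies inside the
clamp bounds, so the partitioner is built with desiredDocsPerSlice = 150_000 and therefore
maxDocsPerSlice = 150_000 * 5 / 4 = 187_500. A single 1_200_000-doc segment then clears the "large" threshold
(1_200_000 >= 5 * 150_000) and is split by partitionOneLargeSegment in the new class below.
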
@@ -291,4 +305,67 @@ static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
             throw new UncheckedIOException(e);
         }
     }
+
+    static final class AdaptivePartitioner {
+        final int desiredDocsPerSlice;
+        final int maxDocsPerSlice;
+        final int maxSegmentsPerSlice;
+
+        AdaptivePartitioner(int desiredDocsPerSlice, int maxSegmentsPerSlice) {
+            this.desiredDocsPerSlice = desiredDocsPerSlice;
+            this.maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;
+            this.maxSegmentsPerSlice = maxSegmentsPerSlice;
+        }
+
+        List<List<PartialLeafReaderContext>> partition(List<LeafReaderContext> leaves) {
+            List<LeafReaderContext> smallSegments = new ArrayList<>();
+            List<LeafReaderContext> largeSegments = new ArrayList<>();
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            for (LeafReaderContext leaf : leaves) {
+                if (leaf.reader().maxDoc() >= 5 * desiredDocsPerSlice) {
+                    largeSegments.add(leaf);
+                } else {
+                    smallSegments.add(leaf);
+                }
+            }
+            largeSegments.sort(Collections.reverseOrder(Comparator.comparingInt(l -> l.reader().maxDoc())));
+            for (LeafReaderContext segment : largeSegments) {
+                results.addAll(partitionOneLargeSegment(segment));
+            }
+            results.addAll(partitionSmallSegments(smallSegments));
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionOneLargeSegment(LeafReaderContext leaf) {
+            int numDocsInLeaf = leaf.reader().maxDoc();
+            int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice);
+            while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
+                numSlices++;
+            }
+            int docPerSlice = numDocsInLeaf / numSlices;
+            int leftoverDocs = numDocsInLeaf % numSlices;
+            int minDoc = 0;
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            while (minDoc < numDocsInLeaf) {
+                int docsToUse = docPerSlice;
+                if (leftoverDocs > 0) {
+                    --leftoverDocs;
+                    docsToUse++;
+                }
+                int maxDoc = Math.min(minDoc + docsToUse, numDocsInLeaf);
+                results.add(List.of(new PartialLeafReaderContext(leaf, minDoc, maxDoc)));
+                minDoc = maxDoc;
+            }
+            assert leftoverDocs == 0 : leftoverDocs;
+            assert results.stream().allMatch(s -> s.size() == 1) : "must have one partial leaf per slice";
+            assert results.stream().flatMapToInt(ss -> ss.stream().mapToInt(s -> s.maxDoc() - s.minDoc())).sum() == numDocsInLeaf;
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionSmallSegments(List<LeafReaderContext> leaves) {
+            var slices = IndexSearcher.slices(leaves, maxDocsPerSlice, maxSegmentsPerSlice, true);
+            return Arrays.stream(slices).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
+        }
+    }
+
 }