 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.core.Nullable;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
-import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.Function;
 
 /**
@@ -77,18 +80,78 @@ public record QueryAndTags(Query query, List<Object> tags) {}
     public static final int MAX_SEGMENTS_PER_SLICE = 5; // copied from IndexSearcher
 
     private final int totalSlices;
-    private final Queue<LuceneSlice> slices;
     private final Map<String, PartitioningStrategy> partitioningStrategies;
 
-    private LuceneSliceQueue(List<LuceneSlice> slices, Map<String, PartitioningStrategy> partitioningStrategies) {
-        this.totalSlices = slices.size();
-        this.slices = new ConcurrentLinkedQueue<>(slices);
+    private final AtomicReferenceArray<LuceneSlice> slices;
+    /**
+     * Queue of slice IDs that are the primary entry point for a new group of segments.
+     * A driver should prioritize polling from this queue after failing to get a sequential
+     * slice (segment affinity). This ensures that threads start work on fresh,
+     * independent segment groups before resorting to work stealing.
+     */
+    private final Queue<Integer> sliceHeads;
+
+    /**
+     * Queue of slice IDs that are not the primary entry point for a segment group.
+     * This queue serves as a fallback pool for work stealing. When a thread has no more independent work,
+     * it will "steal" a slice from this queue to keep itself utilized. A driver should pull tasks from
+     * this queue only when {@code sliceHeads} has been exhausted.
+     */
+    private final Queue<Integer> stealableSlices;
+
+    LuceneSliceQueue(List<LuceneSlice> sliceList, Map<String, PartitioningStrategy> partitioningStrategies) {
+        this.totalSlices = sliceList.size();
+        this.slices = new AtomicReferenceArray<>(sliceList.size());
+        for (int i = 0; i < sliceList.size(); i++) {
+            slices.set(i, sliceList.get(i));
+        }
         this.partitioningStrategies = partitioningStrategies;
+        this.sliceHeads = ConcurrentCollections.newQueue();
+        this.stealableSlices = ConcurrentCollections.newQueue();
+        for (LuceneSlice slice : sliceList) {
+            if (slice.getLeaf(0).minDoc() == 0) {
+                sliceHeads.add(slice.slicePosition());
+            } else {
+                stealableSlices.add(slice.slicePosition());
+            }
+        }
     }
 
+    /**
+     * Retrieves the next available {@link LuceneSlice} for processing.
+     * <p>
+     * This method implements a three-tiered strategy to minimize the overhead of switching between segments:
+     * 1. If a previous slice is provided, it first attempts to return the next sequential slice.
+     *    This keeps a thread working on the same segments, minimizing the overhead of segment switching.
+     * 2. If affinity fails, it returns a slice from the {@link #sliceHeads} queue, which is an entry point for
+     *    a new, independent group of segments, allowing the calling Driver to work on a fresh set of segments.
+     * 3. If the {@link #sliceHeads} queue is exhausted, it "steals" a slice
+     *    from the {@link #stealableSlices} queue. This fallback ensures all threads remain utilized.
+     *
+     * @param prev the previously returned {@link LuceneSlice}, or {@code null} if starting
+     * @return the next available {@link LuceneSlice}, or {@code null} if exhausted
+     */
     @Nullable
-    public LuceneSlice nextSlice() {
-        return slices.poll();
+    public LuceneSlice nextSlice(LuceneSlice prev) {
+        if (prev != null) {
+            final int nextId = prev.slicePosition() + 1;
+            if (nextId < totalSlices) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        for (var ids : List.of(sliceHeads, stealableSlices)) {
+            Integer nextId;
+            while ((nextId = ids.poll()) != null) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        return null;
     }
 
     public int totalSlices() {
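A minimal sketch of how a caller might drain the queue with the new affinity-aware API; `processSlice` is a hypothetical stand-in for the per-slice work, not something defined in this PR:

// Hypothetical consumer loop: pass the previously returned slice back in so nextSlice
// can try slicePosition() + 1 first, then fall back to sliceHeads, then stealableSlices.
void drain(LuceneSliceQueue queue) {
    LuceneSlice slice = queue.nextSlice(null); // first call: no affinity yet
    while (slice != null) {
        processSlice(slice);                   // hypothetical per-slice work
        slice = queue.nextSlice(slice);
    }
}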
@@ -103,7 +166,14 @@ public Map<String, PartitioningStrategy> partitioningStrategies() {
     }
 
     public Collection<String> remainingShardsIdentifiers() {
-        return slices.stream().map(slice -> slice.shardContext().shardIdentifier()).toList();
+        List<String> remaining = new ArrayList<>(slices.length());
+        for (int i = 0; i < slices.length(); i++) {
+            LuceneSlice slice = slices.get(i);
+            if (slice != null) {
+                remaining.add(slice.shardContext().shardIdentifier());
+            }
+        }
+        return remaining;
     }
 
     public static LuceneSliceQueue create(
@@ -117,6 +187,7 @@ public static LuceneSliceQueue create(
         List<LuceneSlice> slices = new ArrayList<>();
         Map<String, PartitioningStrategy> partitioningStrategies = new HashMap<>(contexts.size());
 
+        int nextSliceId = 0;
         for (ShardContext ctx : contexts) {
             for (QueryAndTags queryAndExtra : queryFunction.apply(ctx)) {
                 var scoreMode = scoreModeFunction.apply(ctx);
@@ -140,7 +211,7 @@ public static LuceneSliceQueue create(
                 Weight weight = weight(ctx, query, scoreMode);
                 for (List<PartialLeafReaderContext> group : groups) {
                     if (group.isEmpty() == false) {
-                        slices.add(new LuceneSlice(ctx, group, weight, queryAndExtra.tags));
+                        slices.add(new LuceneSlice(nextSliceId++, ctx, group, weight, queryAndExtra.tags));
                     }
                 }
             }
@@ -158,7 +229,7 @@ public enum PartitioningStrategy implements Writeable {
          */
         SHARD(0) {
             @Override
-            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
+            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
                 return List.of(searcher.getLeafContexts().stream().map(PartialLeafReaderContext::new).toList());
             }
         },
@@ -167,7 +238,7 @@ List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
          */
         SEGMENT(1) {
             @Override
-            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
+            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
                 IndexSearcher.LeafSlice[] gs = IndexSearcher.slices(
                     searcher.getLeafContexts(),
                     MAX_DOCS_PER_SLICE,
@@ -182,52 +253,11 @@ List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
          */
         DOC(2) {
             @Override
-            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
+            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
                 final int totalDocCount = searcher.getIndexReader().maxDoc();
-                final int normalMaxDocsPerSlice = totalDocCount / requestedNumSlices;
-                final int extraDocsInFirstSlice = totalDocCount % requestedNumSlices;
-                final List<List<PartialLeafReaderContext>> slices = new ArrayList<>();
-                int docsAllocatedInCurrentSlice = 0;
-                List<PartialLeafReaderContext> currentSlice = null;
-                int maxDocsPerSlice = normalMaxDocsPerSlice + extraDocsInFirstSlice;
-                for (LeafReaderContext ctx : searcher.getLeafContexts()) {
-                    final int numDocsInLeaf = ctx.reader().maxDoc();
-                    int minDoc = 0;
-                    while (minDoc < numDocsInLeaf) {
-                        int numDocsToUse = Math.min(maxDocsPerSlice - docsAllocatedInCurrentSlice, numDocsInLeaf - minDoc);
-                        if (numDocsToUse <= 0) {
-                            break;
-                        }
-                        if (currentSlice == null) {
-                            currentSlice = new ArrayList<>();
-                        }
-                        currentSlice.add(new PartialLeafReaderContext(ctx, minDoc, minDoc + numDocsToUse));
-                        minDoc += numDocsToUse;
-                        docsAllocatedInCurrentSlice += numDocsToUse;
-                        if (docsAllocatedInCurrentSlice == maxDocsPerSlice) {
-                            slices.add(currentSlice);
-                            // once the first slice with the extra docs is added, no need for extra docs
-                            maxDocsPerSlice = normalMaxDocsPerSlice;
-                            currentSlice = null;
-                            docsAllocatedInCurrentSlice = 0;
-                        }
-                    }
-                }
-                if (currentSlice != null) {
-                    slices.add(currentSlice);
-                }
-                if (requestedNumSlices < totalDocCount && slices.size() != requestedNumSlices) {
-                    throw new IllegalStateException("wrong number of slices, expected " + requestedNumSlices + " but got " + slices.size());
-                }
-                if (slices.stream()
-                    .flatMapToInt(
-                        l -> l.stream()
-                            .mapToInt(partialLeafReaderContext -> partialLeafReaderContext.maxDoc() - partialLeafReaderContext.minDoc())
-                    )
-                    .sum() != totalDocCount) {
-                    throw new IllegalStateException("wrong doc count");
-                }
-                return slices;
+                // Cap the desired slice size to prevent CPU underutilization when matching documents are concentrated in one segment region.
+                int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, taskConcurrency), 1, MAX_DOCS_PER_SLICE);
+                return new AdaptivePartitioner(Math.max(1, desiredSliceSize), MAX_SEGMENTS_PER_SLICE).partition(searcher.getLeafContexts());
             }
         };
 
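For intuition, a small sketch of the clamp above under assumed numbers; the 250_000 cap is only a hypothetical stand-in for MAX_DOCS_PER_SLICE, whose actual value is defined earlier in the file and not shown in this diff:

// Assumed example (not from the PR): 2,000,000 docs on the shard, taskConcurrency = 8.
int totalDocCount = 2_000_000;
int taskConcurrency = 8;
int maxDocsPerSliceCap = 250_000; // hypothetical stand-in for MAX_DOCS_PER_SLICE
// ceilDiv(2_000_000, 8) = 250_000, already within [1, cap], so each desired slice
// targets 250_000 docs; a tiny shard would instead be floored at 1 doc per slice.
int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, taskConcurrency), 1, maxDocsPerSliceCap);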
@@ -252,7 +282,7 @@ public void writeTo(StreamOutput out) throws IOException {
             out.writeByte(id);
         }
 
-        abstract List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices);
+        abstract List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency);
 
         private static PartitioningStrategy pick(
             DataPartitioning dataPartitioning,
@@ -291,4 +321,67 @@ static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
             throw new UncheckedIOException(e);
         }
     }
+
+    static final class AdaptivePartitioner {
+        final int desiredDocsPerSlice;
+        final int maxDocsPerSlice;
+        final int maxSegmentsPerSlice;
+
+        AdaptivePartitioner(int desiredDocsPerSlice, int maxSegmentsPerSlice) {
+            this.desiredDocsPerSlice = desiredDocsPerSlice;
+            this.maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;
+            this.maxSegmentsPerSlice = maxSegmentsPerSlice;
+        }
+
+        List<List<PartialLeafReaderContext>> partition(List<LeafReaderContext> leaves) {
+            List<LeafReaderContext> smallSegments = new ArrayList<>();
+            List<LeafReaderContext> largeSegments = new ArrayList<>();
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            for (LeafReaderContext leaf : leaves) {
+                if (leaf.reader().maxDoc() >= 5 * desiredDocsPerSlice) {
+                    largeSegments.add(leaf);
+                } else {
+                    smallSegments.add(leaf);
+                }
+            }
+            largeSegments.sort(Collections.reverseOrder(Comparator.comparingInt(l -> l.reader().maxDoc())));
+            for (LeafReaderContext segment : largeSegments) {
+                results.addAll(partitionOneLargeSegment(segment));
+            }
+            results.addAll(partitionSmallSegments(smallSegments));
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionOneLargeSegment(LeafReaderContext leaf) {
+            int numDocsInLeaf = leaf.reader().maxDoc();
+            int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice);
+            while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
+                numSlices++;
+            }
+            int docPerSlice = numDocsInLeaf / numSlices;
+            int leftoverDocs = numDocsInLeaf % numSlices;
+            int minDoc = 0;
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            while (minDoc < numDocsInLeaf) {
+                int docsToUse = docPerSlice;
+                if (leftoverDocs > 0) {
+                    --leftoverDocs;
+                    docsToUse++;
+                }
+                int maxDoc = Math.min(minDoc + docsToUse, numDocsInLeaf);
+                results.add(List.of(new PartialLeafReaderContext(leaf, minDoc, maxDoc)));
+                minDoc = maxDoc;
+            }
+            assert leftoverDocs == 0 : leftoverDocs;
+            assert results.stream().allMatch(s -> s.size() == 1) : "must have one partial leaf per slice";
+            assert results.stream().flatMapToInt(ss -> ss.stream().mapToInt(s -> s.maxDoc() - s.minDoc())).sum() == numDocsInLeaf;
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionSmallSegments(List<LeafReaderContext> leaves) {
+            var slices = IndexSearcher.slices(leaves, maxDocsPerSlice, maxSegmentsPerSlice, true);
+            return Arrays.stream(slices).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
+        }
+    }
+
 }
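As a rough worked example of the large-segment split above, under assumed numbers that are not taken from the PR:

// Assumed inputs: a single 1,600,000-doc segment and desiredDocsPerSlice = 250_000,
// mirroring the arithmetic in AdaptivePartitioner.partitionOneLargeSegment.
int desiredDocsPerSlice = 250_000;
int maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;                 // 312_500
int numDocsInLeaf = 1_600_000;                                     // >= 5 * desired, so treated as "large"
int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice);  // 6
while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) { // 266_667 <= 312_500, loop does not run
    numSlices++;
}
int docPerSlice = numDocsInLeaf / numSlices;                       // 266_666
int leftoverDocs = numDocsInLeaf % numSlices;                      // 4, so four slices get one extra doc
// Result: six single-segment slices, e.g. [0, 266_667), [266_667, 533_334), ..., [1_333_334, 1_600_000).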