1616import org .elasticsearch .common .io .stream .StreamInput ;
1717import org .elasticsearch .common .io .stream .StreamOutput ;
1818import org .elasticsearch .common .io .stream .Writeable ;
19+ import org .elasticsearch .common .util .concurrent .ConcurrentCollections ;
1920import org .elasticsearch .core .Nullable ;
2021
2122import java .io .IOException ;
2728import java .util .List ;
2829import java .util .Map ;
2930import java .util .Queue ;
30- import java .util .concurrent .ConcurrentLinkedQueue ;
31+ import java .util .concurrent .atomic . AtomicReferenceArray ;
3132import java .util .function .Function ;
3233
3334/**
@@ -77,18 +78,62 @@ public record QueryAndTags(Query query, List<Object> tags) {}
7778 public static final int MAX_SEGMENTS_PER_SLICE = 5 ; // copied from IndexSearcher
7879
7980 private final int totalSlices ;
80- private final Queue <LuceneSlice > slices ;
81+ private final AtomicReferenceArray <LuceneSlice > slices ;
82+ private final Queue <Integer > startedPositions ;
83+ private final Queue <Integer > followedPositions ;
8184 private final Map <String , PartitioningStrategy > partitioningStrategies ;
8285
83- private LuceneSliceQueue (List <LuceneSlice > slices , Map <String , PartitioningStrategy > partitioningStrategies ) {
84- this .totalSlices = slices .size ();
85- this .slices = new ConcurrentLinkedQueue <>(slices );
86+ LuceneSliceQueue (List <LuceneSlice > sliceList , Map <String , PartitioningStrategy > partitioningStrategies ) {
87+ this .totalSlices = sliceList .size ();
88+ this .slices = new AtomicReferenceArray <>(sliceList .size ());
89+ for (int i = 0 ; i < sliceList .size (); i ++) {
90+ slices .set (i , sliceList .get (i ));
91+ }
8692 this .partitioningStrategies = partitioningStrategies ;
93+ this .startedPositions = ConcurrentCollections .newQueue ();
94+ this .followedPositions = ConcurrentCollections .newQueue ();
95+ for (LuceneSlice slice : sliceList ) {
96+ if (slice .leaves ().stream ().anyMatch (s -> s .minDoc () == 0 )) {
97+ startedPositions .add (slice .slicePosition ());
98+ } else {
99+ followedPositions .add (slice .slicePosition ());
100+ }
101+ }
87102 }
88103
104+ /**
105+ * Retrieves the next available {@link LuceneSlice} for processing.
106+ * If a previous slice is provided, this method first attempts to return the next sequential slice to maintain segment affinity
107+ * and minimize the cost of switching between segments.
108+ * <p>
109+ * If no sequential slice is available, it returns the next slice from the {@code startedPositions} queue, which starts a new
110+ * group of segments. If all started positions are exhausted, it retrieves a slice from the {@code followedPositions} queue,
111+ * enabling work stealing.
112+ *
113+ * @param prev the previously returned {@link LuceneSlice}, or {@code null} if starting
114+ * @return the next available {@link LuceneSlice}, or {@code null} if exhausted
115+ */
89116 @ Nullable
90- public LuceneSlice nextSlice () {
91- return slices .poll ();
117+ public LuceneSlice nextSlice (LuceneSlice prev ) {
118+ if (prev != null ) {
119+ final int nextId = prev .slicePosition () + 1 ;
120+ if (nextId < totalSlices ) {
121+ var slice = slices .getAndSet (nextId , null );
122+ if (slice != null ) {
123+ return slice ;
124+ }
125+ }
126+ }
127+ for (var preferredIndices : List .of (startedPositions , followedPositions )) {
128+ Integer nextId ;
129+ while ((nextId = preferredIndices .poll ()) != null ) {
130+ var slice = slices .getAndSet (nextId , null );
131+ if (slice != null ) {
132+ return slice ;
133+ }
134+ }
135+ }
136+ return null ;
92137 }
93138
94139 public int totalSlices () {
@@ -103,7 +148,14 @@ public Map<String, PartitioningStrategy> partitioningStrategies() {
103148 }
104149
105150 public Collection <String > remainingShardsIdentifiers () {
106- return slices .stream ().map (slice -> slice .shardContext ().shardIdentifier ()).toList ();
151+ List <String > remaining = new ArrayList <>(slices .length ());
152+ for (int i = 0 ; i < slices .length (); i ++) {
153+ LuceneSlice slice = slices .get (i );
154+ if (slice != null ) {
155+ remaining .add (slice .shardContext ().shardIdentifier ());
156+ }
157+ }
158+ return remaining ;
107159 }
108160
109161 public static LuceneSliceQueue create (
@@ -117,6 +169,7 @@ public static LuceneSliceQueue create(
117169 List <LuceneSlice > slices = new ArrayList <>();
118170 Map <String , PartitioningStrategy > partitioningStrategies = new HashMap <>(contexts .size ());
119171
172+ int nextSliceId = 0 ;
120173 for (ShardContext ctx : contexts ) {
121174 for (QueryAndTags queryAndExtra : queryFunction .apply (ctx )) {
122175 var scoreMode = scoreModeFunction .apply (ctx );
@@ -140,7 +193,7 @@ public static LuceneSliceQueue create(
140193 Weight weight = weight (ctx , query , scoreMode );
141194 for (List <PartialLeafReaderContext > group : groups ) {
142195 if (group .isEmpty () == false ) {
143- slices .add (new LuceneSlice (ctx , group , weight , queryAndExtra .tags ));
196+ slices .add (new LuceneSlice (nextSliceId ++, ctx , group , weight , queryAndExtra .tags ));
144197 }
145198 }
146199 }
@@ -184,6 +237,7 @@ List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requeste
184237 @ Override
185238 List <List <PartialLeafReaderContext >> groups (IndexSearcher searcher , int requestedNumSlices ) {
186239 final int totalDocCount = searcher .getIndexReader ().maxDoc ();
240+ requestedNumSlices = Math .max (1 , totalDocCount / Math .clamp (totalDocCount / requestedNumSlices , 1 , MAX_DOCS_PER_SLICE ));
187241 final int normalMaxDocsPerSlice = totalDocCount / requestedNumSlices ;
188242 final int extraDocsInFirstSlice = totalDocCount % requestedNumSlices ;
189243 final List <List <PartialLeafReaderContext >> slices = new ArrayList <>();
0 commit comments