@@ -91,8 +91,6 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     public static Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
         "indices.inference.batch_size",
         DEFAULT_BATCH_SIZE,
-        ByteSizeValue.ONE,
-        ByteSizeValue.ofBytes(100),
         Setting.Property.NodeScope,
         Setting.Property.OperatorDynamic
     );
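Note on this hunk: dropping the ByteSizeValue.ONE / ByteSizeValue.ofBytes(100) arguments moves the setting to the byteSizeSetting overload without explicit bounds, so values outside the old 1..100-byte range are no longer rejected at registration time. As a rough sketch (not part of this diff), a consumer of this OperatorDynamic setting could react to runtime updates via the standard ClusterSettings hook; the batchSizeInBytes field named here is an assumption:

    // Sketch, assuming a volatile long batchSizeInBytes field on the filter.
    clusterService.getClusterSettings()
        .addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, v -> batchSizeInBytes = v.getBytes());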
@@ -172,7 +170,6 @@ private record InferenceProvider(InferenceService service, Model model) {}
      * @param offsetAdjustment The adjustment to apply to the chunk text offsets.
      */
     private record FieldInferenceRequest(
-        String inferenceId,
         int bulkItemIndex,
         String field,
         String sourceField,
@@ -252,32 +249,21 @@ private void executeNext(int itemOffset) {
         }
 
         var items = bulkShardRequest.items();
-        Map<String, List<FieldInferenceRequest>> requestsMap = new HashMap<>();
+        Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new HashMap<>();
         long totalInputLength = 0;
         int itemIndex = itemOffset;
-
-        while (itemIndex < items.length && totalInputLength < batchSizeInBytes) {
+        for (; itemIndex < bulkShardRequest.items().length; itemIndex++) {
             var item = items[itemIndex];
-            var requests = createFieldInferenceRequests(item, itemIndex);
-
-            totalInputLength += requests.stream().mapToLong(r -> r.input.length()).sum();
-            if (requestsMap.size() > 0 && totalInputLength >= batchSizeInBytes) {
-                /**
-                 * Exits early because the new requests exceed the allowable size.
-                 * These requests will be processed in the next iteration.
-                 */
+            totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap);
+            if (totalInputLength >= batchSizeInBytes) {
                 break;
             }
-
-            for (var request : requests) {
-                requestsMap.computeIfAbsent(request.inferenceId, k -> new ArrayList<>()).add(request);
-            }
-            itemIndex++;
         }
-        int nextItemOffset = itemIndex;
+        int nextItemIndex = itemIndex + 1;
         Runnable onInferenceCompletion = () -> {
             try {
-                for (int i = itemOffset; i < nextItemOffset; i++) {
+                int limit = Math.min(nextItemIndex, items.length);
+                for (int i = itemOffset; i < limit; i++) {
                     var result = inferenceResults.get(i);
                     if (result == null) {
                         continue;
@@ -292,12 +278,12 @@ private void executeNext(int itemOffset) {
                     inferenceResults.set(i, null);
                 }
             } finally {
-                executeNext(nextItemOffset);
+                executeNext(nextItemIndex);
             }
         };
 
         try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) {
-            for (var entry : requestsMap.entrySet()) {
+            for (var entry : fieldRequestsMap.entrySet()) {
                 executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
             }
         }
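Semantics worth noting in these two hunks: the old while-loop stopped before the item whose requests crossed batchSizeInBytes and reprocessed it in the next batch, while the new for-loop keeps the crossing item in the current batch, hence nextItemIndex = itemIndex + 1 and the Math.min clamp for the case where the loop exhausts the array without breaking. A self-contained sketch with hypothetical sizes:

    // BatchBoundaryDemo.java - standalone illustration of the new boundary rule.
    public class BatchBoundaryDemo {
        public static void main(String[] args) {
            long batchSizeInBytes = 10;        // hypothetical byte budget
            long[] itemSizes = { 4, 4, 4, 4 }; // hypothetical per-item input lengths
            int itemIndex = 0;
            long totalInputLength = 0;
            for (; itemIndex < itemSizes.length; itemIndex++) {
                totalInputLength += itemSizes[itemIndex];
                if (totalInputLength >= batchSizeInBytes) {
                    break; // item 2 crosses the budget but stays in this batch
                }
            }
            int nextItemIndex = itemIndex + 1;
            // If the loop runs off the end instead of breaking, itemIndex equals
            // itemSizes.length and nextItemIndex overshoots by one, which is exactly
            // why the completion handler clamps with Math.min(nextItemIndex, items.length).
            System.out.println("batch covers items [0, " + Math.min(nextItemIndex, itemSizes.length) + ")"); // [0, 3)
        }
    }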
@@ -425,16 +411,18 @@ public void onFailure(Exception exc) {
     }
 
     /**
-     * Returns all inference requests from the provided {@link BulkItemRequest}.
+     * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
+     * for the specified {@code item}.
      *
      * @param item The bulk request item to process.
      * @param itemIndex The position of the item within the original bulk request.
-     * @return The list of {@link FieldInferenceRequest} associated with the item.
+     * @param requestsMap A map storing inference requests, where each key is an inference ID,
+     *                    and the value is a list of associated {@link FieldInferenceRequest} objects.
+     * @return The total content length of all newly added requests, or {@code 0} if no requests were added.
      */
-    private List<FieldInferenceRequest> createFieldInferenceRequests(BulkItemRequest item, int itemIndex) {
+    private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
         boolean isUpdateRequest = false;
         final IndexRequest indexRequest;
-
         if (item.request() instanceof IndexRequest ir) {
             indexRequest = ir;
         } else if (item.request() instanceof UpdateRequest updateRequest) {
@@ -448,16 +436,16 @@ private List<FieldInferenceRequest> createFieldInferenceRequests(BulkItemRequest
                         SemanticTextFieldMapper.CONTENT_TYPE
                     )
                 );
-                return List.of();
+                return 0;
             }
             indexRequest = updateRequest.doc();
         } else {
             // ignore delete request
-            return List.of();
+            return 0;
         }
 
         final Map<String, Object> docMap = indexRequest.sourceAsMap();
-        List<FieldInferenceRequest> requests = new ArrayList<>();
+        long inputLength = 0;
         for (var entry : fieldInferenceMap.values()) {
             String field = entry.getName();
             String inferenceId = entry.getInferenceId();
@@ -526,9 +514,12 @@ private List<FieldInferenceRequest> createFieldInferenceRequests(BulkItemRequest
                     break;
                 }
 
+                inputLength += values.stream().mapToLong(String::length).sum();
+
+                List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                 int offsetAdjustment = 0;
                 for (String v : values) {
-                    requests.add(new FieldInferenceRequest(inferenceId, itemIndex, field, sourceField, v, order++, offsetAdjustment));
+                    requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
 
                     // When using the inference metadata fields format, all the input values are concatenated so that the
                     // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
@@ -537,7 +528,7 @@ private List<FieldInferenceRequest> createFieldInferenceRequests(BulkItemRequest
                 }
             }
         }
-        return requests;
+        return inputLength;
     }
 
     private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
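The truncated chunk-offset comment above is easiest to see with a concrete concatenation; a hypothetical illustration (the real separator width comes from the inference metadata fields format, which this diff does not show, so the +1 below is an assumption):

    // Hypothetical: start offset of each value inside the concatenated string.
    String[] values = { "first value", "second value" };
    int offsetAdjustment = 0;
    for (String v : values) {
        System.out.println("\"" + v + "\" starts at offset " + offsetAdjustment);
        offsetAdjustment += v.length() + 1; // +1 for an assumed one-character separator
    }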