@@ -91,6 +91,8 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     public static Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
         "indices.inference.batch_size",
         DEFAULT_BATCH_SIZE,
+        ByteSizeValue.ONE,
+        ByteSizeValue.ofBytes(100),
         Setting.Property.NodeScope,
         Setting.Property.OperatorDynamic
     );
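The two new arguments select the bounded `Setting.byteSizeSetting` overload, so out-of-range values are rejected when the setting is parsed rather than silently accepted. A minimal sketch of that behavior, with a hypothetical key and default (only the bounds mirror the diff):

```java
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;

class BatchSizeBoundsDemo {
    // "demo.batch_size" and the 50-byte default are made up for illustration.
    static final Setting<ByteSizeValue> BATCH_SIZE = Setting.byteSizeSetting(
        "demo.batch_size",
        ByteSizeValue.ofBytes(50),   // default, inside the bounds
        ByteSizeValue.ONE,           // minimum: 1 byte
        ByteSizeValue.ofBytes(100),  // maximum: 100 bytes
        Setting.Property.NodeScope
    );

    public static void main(String[] args) {
        var ok = Settings.builder().put("demo.batch_size", "64b").build();
        System.out.println(BATCH_SIZE.get(ok)); // prints 64b

        var tooBig = Settings.builder().put("demo.batch_size", "1kb").build();
        BATCH_SIZE.get(tooBig); // throws IllegalArgumentException: exceeds the maximum
    }
}
```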
@@ -170,6 +172,7 @@ private record InferenceProvider(InferenceService service, Model model) {}
      * @param offsetAdjustment The adjustment to apply to the chunk text offsets.
      */
     private record FieldInferenceRequest(
+        String inferenceId,
         int bulkItemIndex,
         String field,
         String sourceField,
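Carrying `inferenceId` on each request lets the caller group requests by endpoint instead of having the producer write into a shared map. A hedged sketch of that grouping idiom, with the record's component list reconstructed from the constructor call later in this diff:

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

class GroupingSketch {
    // Simplified stand-in for the real record.
    record FieldInferenceRequest(String inferenceId, int bulkItemIndex, String field,
                                 String sourceField, String input, int order, int offsetAdjustment) {}

    // Same grouping the new while-loop performs with computeIfAbsent.
    static Map<String, List<FieldInferenceRequest>> groupByEndpoint(List<FieldInferenceRequest> requests) {
        return requests.stream().collect(Collectors.groupingBy(FieldInferenceRequest::inferenceId));
    }
}
```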
@@ -249,21 +252,32 @@ private void executeNext(int itemOffset) {
         }
 
         var items = bulkShardRequest.items();
-        Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new HashMap<>();
+        Map<String, List<FieldInferenceRequest>> requestsMap = new HashMap<>();
         long totalInputLength = 0;
         int itemIndex = itemOffset;
-        for (; itemIndex < bulkShardRequest.items().length; itemIndex++) {
+
+        while (itemIndex < items.length && totalInputLength < batchSizeInBytes) {
             var item = items[itemIndex];
-            totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap);
-            if (totalInputLength >= batchSizeInBytes) {
+            var requests = createFieldInferenceRequests(item, itemIndex);
+
+            totalInputLength += requests.stream().mapToLong(r -> r.input.length()).sum();
+            if (requestsMap.size() > 0 && totalInputLength >= batchSizeInBytes) {
+                /**
+                 * Exits early because the new requests exceed the allowable size.
+                 * These requests will be processed in the next iteration.
+                 */
                 break;
             }
+
+            for (var request : requests) {
+                requestsMap.computeIfAbsent(request.inferenceId, k -> new ArrayList<>()).add(request);
+            }
+            itemIndex++;
         }
-        int nextItemIndex = itemIndex + 1;
+        int nextItemOffset = itemIndex;
         Runnable onInferenceCompletion = () -> {
             try {
-                int limit = Math.min(nextItemIndex, items.length);
-                for (int i = itemOffset; i < limit; i++) {
+                for (int i = itemOffset; i < nextItemOffset; i++) {
                     var result = inferenceResults.get(i);
                     if (result == null) {
                         continue;
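Taken together, the new loop is a greedy byte-budget batcher: the first item is always admitted (so a single oversized document still makes progress), and the non-first item whose inputs cross `batchSizeInBytes` is deferred to the next `executeNext` round. The same pattern in isolation, with hypothetical names:

```java
import java.util.ArrayList;
import java.util.List;

class ByteBudgetBatcher {
    /**
     * Fills {@code batch} starting at {@code offset} and returns the offset where
     * the next batch should begin. Mirrors the loop above: at least one item is
     * always taken, and the item that crosses the budget is left for the next call.
     */
    static int fillBatch(List<String> inputs, int offset, long budgetBytes, List<String> batch) {
        long total = 0;
        int i = offset;
        while (i < inputs.size() && total < budgetBytes) {
            total += inputs.get(i).length();
            if (batch.isEmpty() == false && total >= budgetBytes) {
                break; // over budget: defer this item to the next batch
            }
            batch.add(inputs.get(i));
            i++;
        }
        return i;
    }

    public static void main(String[] args) {
        List<String> inputs = List.of("aaaa", "bbbb", "cccc");
        int offset = 0;
        while (offset < inputs.size()) {
            List<String> batch = new ArrayList<>();
            offset = fillBatch(inputs, offset, 10, batch);
            System.out.println(batch); // prints [aaaa, bbbb] then [cccc]
        }
    }
}
```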
@@ -278,12 +292,12 @@ private void executeNext(int itemOffset) {
                     inferenceResults.set(i, null);
                 }
             } finally {
-                executeNext(nextItemIndex);
+                executeNext(nextItemOffset);
             }
         };
 
         try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) {
-            for (var entry : fieldRequestsMap.entrySet()) {
+            for (var entry : requestsMap.entrySet()) {
                 executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
             }
         }
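`RefCountingRunnable` runs `onInferenceCompletion` once every acquired reference has been released, i.e. after all per-endpoint inference calls finish; the try-with-resources releases the dispatch loop's own reference. A rough JDK-only stand-in, assuming those acquire/release semantics:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

final class RefCountedCompletion {
    private final AtomicInteger refs = new AtomicInteger(1); // the dispatch loop's own ref
    private final Runnable onCompletion;

    RefCountedCompletion(Runnable onCompletion) {
        this.onCompletion = onCompletion;
    }

    Runnable acquire() {
        refs.incrementAndGet();
        return this::release;
    }

    void release() {
        if (refs.decrementAndGet() == 0) {
            onCompletion.run();
        }
    }

    public static void main(String[] args) {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        var completion = new RefCountedCompletion(() -> System.out.println("all batches done"));
        for (int i = 0; i < 3; i++) {
            Runnable ref = completion.acquire();
            pool.execute(() -> {
                try {
                    // ... asynchronous inference call would go here ...
                } finally {
                    ref.run(); // release, even on failure
                }
            });
        }
        completion.release(); // drop the loop's own reference
        pool.shutdown();
    }
}
```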
@@ -411,18 +425,16 @@ public void onFailure(Exception exc) {
     }
 
     /**
-     * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
-     * for the specified {@code item}.
+     * Returns all inference requests from the provided {@link BulkItemRequest}.
      *
      * @param item The bulk request item to process.
      * @param itemIndex The position of the item within the original bulk request.
-     * @param requestsMap A map storing inference requests, where each key is an inference ID,
-     *                    and the value is a list of associated {@link FieldInferenceRequest} objects.
-     * @return The total content length of all newly added requests, or {@code 0} if no requests were added.
+     * @return The list of {@link FieldInferenceRequest} associated with the item.
      */
-    private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
+    private List<FieldInferenceRequest> createFieldInferenceRequests(BulkItemRequest item, int itemIndex) {
         boolean isUpdateRequest = false;
         final IndexRequest indexRequest;
+
         if (item.request() instanceof IndexRequest ir) {
             indexRequest = ir;
         } else if (item.request() instanceof UpdateRequest updateRequest) {
@@ -436,16 +448,16 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                         SemanticTextFieldMapper.CONTENT_TYPE
                     )
                 );
-                return 0;
+                return List.of();
             }
             indexRequest = updateRequest.doc();
         } else {
             // ignore delete request
-            return 0;
+            return List.of();
         }
 
         final Map<String, Object> docMap = indexRequest.sourceAsMap();
-        long inputLength = 0;
+        List<FieldInferenceRequest> requests = new ArrayList<>();
         for (var entry : fieldInferenceMap.values()) {
             String field = entry.getName();
             String inferenceId = entry.getInferenceId();
@@ -514,12 +526,9 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     break;
                 }
 
-                inputLength += values.stream().mapToLong(String::length).sum();
-
-                List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                 int offsetAdjustment = 0;
                 for (String v : values) {
-                    requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
+                    requests.add(new FieldInferenceRequest(inferenceId, itemIndex, field, sourceField, v, order++, offsetAdjustment));
 
                     // When using the inference metadata fields format, all the input values are concatenated so that the
                     // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
@@ -528,7 +537,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                 }
             }
         }
-        return inputLength;
+        return requests;
     }
 
     private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
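Net effect of the refactor: `addFieldInferenceRequests`, which wrote into a shared map and returned a byte count, becomes the side-effect-free `createFieldInferenceRequests`, and the caller now owns both the length accounting and the grouping. A condensed, self-contained sketch of that shape change, with toy types standing in for the real ones:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class PureProducerSketch {
    record Request(String endpointId, String input) {}

    // After the refactor: a pure producer, trivially unit-testable.
    static List<Request> createRequests(String doc) {
        if (doc.isEmpty()) {
            return List.of(); // mirrors the early `return List.of()` exits above
        }
        List<Request> requests = new ArrayList<>();
        requests.add(new Request("endpoint-1", doc));
        return requests;
    }

    public static void main(String[] args) {
        Map<String, List<Request>> byEndpoint = new HashMap<>();
        List<Request> requests = createRequests("some text");
        long bytes = requests.stream().mapToLong(r -> r.input().length()).sum();
        if (bytes < 100) { // the caller decides whether the batch still fits
            for (var r : requests) {
                byEndpoint.computeIfAbsent(r.endpointId(), k -> new ArrayList<>()).add(r);
            }
        }
        System.out.println(byEndpoint);
    }
}
```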