Skip to content

Commit 6cb165b

Browse files
committed
handle batch size as a real maximum
1 parent 8b47105 commit 6cb165b

File tree

1 file changed

+32
-23
lines changed

1 file changed

+32
-23
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 32 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,8 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
9191
public static Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
9292
"indices.inference.batch_size",
9393
DEFAULT_BATCH_SIZE,
94+
ByteSizeValue.ONE,
95+
ByteSizeValue.ofBytes(100),
9496
Setting.Property.NodeScope,
9597
Setting.Property.OperatorDynamic
9698
);
@@ -170,6 +172,7 @@ private record InferenceProvider(InferenceService service, Model model) {}
170172
* @param offsetAdjustment The adjustment to apply to the chunk text offsets.
171173
*/
172174
private record FieldInferenceRequest(
175+
String inferenceId,
173176
int bulkItemIndex,
174177
String field,
175178
String sourceField,
@@ -249,21 +252,32 @@ private void executeNext(int itemOffset) {
249252
}
250253

251254
var items = bulkShardRequest.items();
252-
Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new HashMap<>();
255+
Map<String, List<FieldInferenceRequest>> requestsMap = new HashMap<>();
253256
long totalInputLength = 0;
254257
int itemIndex = itemOffset;
255-
for (; itemIndex < bulkShardRequest.items().length; itemIndex++) {
258+
259+
while (itemIndex < items.length && totalInputLength < batchSizeInBytes) {
256260
var item = items[itemIndex];
257-
totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap);
258-
if (totalInputLength >= batchSizeInBytes) {
261+
var requests = createFieldInferenceRequests(item, itemIndex);
262+
263+
totalInputLength += requests.stream().mapToLong(r -> r.input.length()).sum();
264+
if (requestsMap.size() > 0 && totalInputLength >= batchSizeInBytes) {
265+
/**
266+
* Exits early because the new requests exceed the allowable size.
267+
* These requests will be processed in the next iteration.
268+
*/
259269
break;
260270
}
271+
272+
for (var request : requests) {
273+
requestsMap.computeIfAbsent(request.inferenceId, k -> new ArrayList<>()).add(request);
274+
}
275+
itemIndex++;
261276
}
262-
int nextItemIndex = itemIndex + 1;
277+
int nextItemOffset = itemIndex;
263278
Runnable onInferenceCompletion = () -> {
264279
try {
265-
int limit = Math.min(nextItemIndex, items.length);
266-
for (int i = itemOffset; i < limit; i++) {
280+
for (int i = itemOffset; i < nextItemOffset; i++) {
267281
var result = inferenceResults.get(i);
268282
if (result == null) {
269283
continue;
@@ -278,12 +292,12 @@ private void executeNext(int itemOffset) {
278292
inferenceResults.set(i, null);
279293
}
280294
} finally {
281-
executeNext(nextItemIndex);
295+
executeNext(nextItemOffset);
282296
}
283297
};
284298

285299
try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) {
286-
for (var entry : fieldRequestsMap.entrySet()) {
300+
for (var entry : requestsMap.entrySet()) {
287301
executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
288302
}
289303
}
@@ -411,18 +425,16 @@ public void onFailure(Exception exc) {
411425
}
412426

413427
/**
414-
* Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
415-
* for the specified {@code item}.
428+
* Returns all inference requests from the provided {@link BulkItemRequest}.
416429
*
417430
* @param item The bulk request item to process.
418431
* @param itemIndex The position of the item within the original bulk request.
419-
* @param requestsMap A map storing inference requests, where each key is an inference ID,
420-
* and the value is a list of associated {@link FieldInferenceRequest} objects.
421-
* @return The total content length of all newly added requests, or {@code 0} if no requests were added.
432+
* @return The list of {@link FieldInferenceRequest} associated with the item.
422433
*/
423-
private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
434+
private List<FieldInferenceRequest> createFieldInferenceRequests(BulkItemRequest item, int itemIndex) {
424435
boolean isUpdateRequest = false;
425436
final IndexRequest indexRequest;
437+
426438
if (item.request() instanceof IndexRequest ir) {
427439
indexRequest = ir;
428440
} else if (item.request() instanceof UpdateRequest updateRequest) {
@@ -436,16 +448,16 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
436448
SemanticTextFieldMapper.CONTENT_TYPE
437449
)
438450
);
439-
return 0;
451+
return List.of();
440452
}
441453
indexRequest = updateRequest.doc();
442454
} else {
443455
// ignore delete request
444-
return 0;
456+
return List.of();
445457
}
446458

447459
final Map<String, Object> docMap = indexRequest.sourceAsMap();
448-
long inputLength = 0;
460+
List<FieldInferenceRequest> requests = new ArrayList<>();
449461
for (var entry : fieldInferenceMap.values()) {
450462
String field = entry.getName();
451463
String inferenceId = entry.getInferenceId();
@@ -514,12 +526,9 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
514526
break;
515527
}
516528

517-
inputLength += values.stream().mapToLong(String::length).sum();
518-
519-
List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
520529
int offsetAdjustment = 0;
521530
for (String v : values) {
522-
requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
531+
requests.add(new FieldInferenceRequest(inferenceId, itemIndex, field, sourceField, v, order++, offsetAdjustment));
523532

524533
// When using the inference metadata fields format, all the input values are concatenated so that the
525534
// chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
@@ -528,7 +537,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
528537
}
529538
}
530539
}
531-
return inputLength;
540+
return requests;
532541
}
533542

534543
private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {

0 commit comments

Comments (0)