Skip to content

Commit 381a85a

Browse files
committed
Revert "handle batch size as a real maximum"
This reverts commit 6cb165b.
1 parent 76e2e1a commit 381a85a

File tree

1 file changed

+23
-32
lines changed

1 file changed

+23
-32
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,6 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
9191
public static Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
9292
"indices.inference.batch_size",
9393
DEFAULT_BATCH_SIZE,
94-
ByteSizeValue.ONE,
95-
ByteSizeValue.ofBytes(100),
9694
Setting.Property.NodeScope,
9795
Setting.Property.OperatorDynamic
9896
);
@@ -172,7 +170,6 @@ private record InferenceProvider(InferenceService service, Model model) {}
172170
* @param offsetAdjustment The adjustment to apply to the chunk text offsets.
173171
*/
174172
private record FieldInferenceRequest(
175-
String inferenceId,
176173
int bulkItemIndex,
177174
String field,
178175
String sourceField,
@@ -252,32 +249,21 @@ private void executeNext(int itemOffset) {
252249
}
253250

254251
var items = bulkShardRequest.items();
255-
Map<String, List<FieldInferenceRequest>> requestsMap = new HashMap<>();
252+
Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new HashMap<>();
256253
long totalInputLength = 0;
257254
int itemIndex = itemOffset;
258-
259-
while (itemIndex < items.length && totalInputLength < batchSizeInBytes) {
255+
for (; itemIndex < bulkShardRequest.items().length; itemIndex++) {
260256
var item = items[itemIndex];
261-
var requests = createFieldInferenceRequests(item, itemIndex);
262-
263-
totalInputLength += requests.stream().mapToLong(r -> r.input.length()).sum();
264-
if (requestsMap.size() > 0 && totalInputLength >= batchSizeInBytes) {
265-
/**
266-
* Exits early because the new requests exceed the allowable size.
267-
* These requests will be processed in the next iteration.
268-
*/
257+
totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap);
258+
if (totalInputLength >= batchSizeInBytes) {
269259
break;
270260
}
271-
272-
for (var request : requests) {
273-
requestsMap.computeIfAbsent(request.inferenceId, k -> new ArrayList<>()).add(request);
274-
}
275-
itemIndex++;
276261
}
277-
int nextItemOffset = itemIndex;
262+
int nextItemIndex = itemIndex + 1;
278263
Runnable onInferenceCompletion = () -> {
279264
try {
280-
for (int i = itemOffset; i < nextItemOffset; i++) {
265+
int limit = Math.min(nextItemIndex, items.length);
266+
for (int i = itemOffset; i < limit; i++) {
281267
var result = inferenceResults.get(i);
282268
if (result == null) {
283269
continue;
@@ -292,12 +278,12 @@ private void executeNext(int itemOffset) {
292278
inferenceResults.set(i, null);
293279
}
294280
} finally {
295-
executeNext(nextItemOffset);
281+
executeNext(nextItemIndex);
296282
}
297283
};
298284

299285
try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) {
300-
for (var entry : requestsMap.entrySet()) {
286+
for (var entry : fieldRequestsMap.entrySet()) {
301287
executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
302288
}
303289
}
@@ -425,16 +411,18 @@ public void onFailure(Exception exc) {
425411
}
426412

427413
/**
428-
* Returns all inference requests from the provided {@link BulkItemRequest}.
414+
* Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
415+
* for the specified {@code item}.
429416
*
430417
* @param item The bulk request item to process.
431418
* @param itemIndex The position of the item within the original bulk request.
432-
* @return The list of {@link FieldInferenceRequest} associated with the item.
419+
* @param requestsMap A map storing inference requests, where each key is an inference ID,
420+
* and the value is a list of associated {@link FieldInferenceRequest} objects.
421+
* @return The total content length of all newly added requests, or {@code 0} if no requests were added.
433422
*/
434-
private List<FieldInferenceRequest> createFieldInferenceRequests(BulkItemRequest item, int itemIndex) {
423+
private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
435424
boolean isUpdateRequest = false;
436425
final IndexRequest indexRequest;
437-
438426
if (item.request() instanceof IndexRequest ir) {
439427
indexRequest = ir;
440428
} else if (item.request() instanceof UpdateRequest updateRequest) {
@@ -448,16 +436,16 @@ private List<FieldInferenceRequest> createFieldInferenceRequests(BulkItemRequest
448436
SemanticTextFieldMapper.CONTENT_TYPE
449437
)
450438
);
451-
return List.of();
439+
return 0;
452440
}
453441
indexRequest = updateRequest.doc();
454442
} else {
455443
// ignore delete request
456-
return List.of();
444+
return 0;
457445
}
458446

459447
final Map<String, Object> docMap = indexRequest.sourceAsMap();
460-
List<FieldInferenceRequest> requests = new ArrayList<>();
448+
long inputLength = 0;
461449
for (var entry : fieldInferenceMap.values()) {
462450
String field = entry.getName();
463451
String inferenceId = entry.getInferenceId();
@@ -526,9 +514,12 @@ private List<FieldInferenceRequest> createFieldInferenceRequests(BulkItemRequest
526514
break;
527515
}
528516

517+
inputLength += values.stream().mapToLong(String::length).sum();
518+
519+
List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
529520
int offsetAdjustment = 0;
530521
for (String v : values) {
531-
requests.add(new FieldInferenceRequest(inferenceId, itemIndex, field, sourceField, v, order++, offsetAdjustment));
522+
requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
532523

533524
// When using the inference metadata fields format, all the input values are concatenated so that the
534525
// chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
@@ -537,7 +528,7 @@ private List<FieldInferenceRequest> createFieldInferenceRequests(BulkItemRequest
537528
}
538529
}
539530
}
540-
return requests;
531+
return inputLength;
541532
}
542533

543534
private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {

0 commit comments

Comments
 (0)