Commit 0b2ebf6

Always process batches in order
1 parent 233defd commit 0b2ebf6

File tree

1 file changed: +9 -7 lines changed


x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 9 additions & 7 deletions
@@ -317,18 +317,20 @@ public void onFailure(Exception exc) {
             modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
             return;
         }
+        // TODO More efficiently batch requests
         int currentBatchSize = Math.min(requests.size(), batchSize);
-
         final ChunkingSettings chunkingSettings = requests.getFirst().chunkingSettings;
-        final List<FieldInferenceRequest> nextBatch = new ArrayList<>();
-        final List<String> inputs = new ArrayList<>();
+        final List<FieldInferenceRequest> currentBatch = new ArrayList<>();
         for (FieldInferenceRequest request : requests) {
-            if (Objects.equals(chunkingSettings, request.chunkingSettings) && inputs.size() < currentBatchSize) {
-                inputs.add(request.input);
-            } else {
-                nextBatch.add(request);
+            if (Objects.equals(request.chunkingSettings, chunkingSettings) == false || currentBatch.size() >= currentBatchSize) {
+                break;
             }
+            currentBatch.add(request);
         }
+
+        final List<FieldInferenceRequest> nextBatch = requests.subList(currentBatch.size(), requests.size());
+        final List<String> inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
+
         ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
             @Override
             public void onResponse(List<ChunkedInference> results) {
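
Why this preserves ordering: the previous loop routed every request whose chunking settings did not match the first request's into nextBatch, so such a request could be skipped over and processed in a later batch than requests that came after it. The new loop stops at the first mismatch (or once the batch-size cap is reached), and the untouched suffix of the list becomes nextBatch via subList, so batches are always dispatched in the original request order. Below is a minimal, self-contained sketch of that idea; the Request record and its String chunkingSettings are simplified stand-ins for the real FieldInferenceRequest and ChunkingSettings types, not the actual Elasticsearch classes, and it assumes Java 21+ for List.getFirst() as in the patch.

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

// Illustrative sketch only: simplified types, not the Elasticsearch implementation.
public class InOrderBatchingSketch {

    record Request(String input, String chunkingSettings) {}

    // Take the leading run of requests that share the first request's chunking
    // settings, capped at batchSize. Everything after that run is left, in its
    // original order, for the next round.
    static List<Request> takeCurrentBatch(List<Request> requests, int batchSize) {
        int currentBatchSize = Math.min(requests.size(), batchSize);
        String chunkingSettings = requests.getFirst().chunkingSettings();
        List<Request> currentBatch = new ArrayList<>();
        for (Request request : requests) {
            if (Objects.equals(request.chunkingSettings(), chunkingSettings) == false
                || currentBatch.size() >= currentBatchSize) {
                break;
            }
            currentBatch.add(request);
        }
        return currentBatch;
    }

    public static void main(String[] args) {
        List<Request> requests = List.of(
            new Request("a1", "sentence"),
            new Request("b1", "word"),
            new Request("a2", "sentence")
        );
        List<Request> currentBatch = takeCurrentBatch(requests, 10);
        List<Request> nextBatch = requests.subList(currentBatch.size(), requests.size());
        System.out.println("current batch: " + currentBatch); // only a1
        System.out.println("next batch:    " + nextBatch);    // b1 and a2, still in order
    }
}

With the input [a1 (sentence), b1 (word), a2 (sentence)], the old partitioning would have grouped a1 and a2 into the first batch ahead of b1; the prefix-based version emits [a1] first and leaves [b1, a2] for the next round, keeping the original order.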
