Skip to content

Commit e4c0d9a

Browse files
committed
Bigger batches
1 parent c5b8699 commit e4c0d9a

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,10 +1016,30 @@ class BatchIterator {
10161016
}
10171017

10181018
void run() {
1019-
inferenceExecutor.execute(() -> inferBatchAndRunAfter(requestAndListeners.get(index.get())));
1019+
inferenceExecutor.execute(this::inferBatchAndRunAfter);
10201020
}
10211021

1022-
private void inferBatchAndRunAfter(EmbeddingRequestChunker.BatchRequestAndListener batch) {
1022+
private void inferBatchAndRunAfter() {
1023+
int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200
1024+
int requestCount = 0;
1025+
// loop does not include the final request
1026+
while (requestCount < NUM_REQUESTS_INFLIGHT - 1 && index.get() < requestAndListeners.size() - 1) {
1027+
1028+
var batch = requestAndListeners.get(index.get());
1029+
executeRequest(batch);
1030+
requestCount++;
1031+
index.incrementAndGet();
1032+
}
1033+
1034+
var batch = requestAndListeners.get(index.get());
1035+
executeRequest(batch, () -> {
1036+
if (index.incrementAndGet() < requestAndListeners.size()) {
1037+
run(); // start the next batch
1038+
}
1039+
});
1040+
}
1041+
1042+
private void executeRequest(EmbeddingRequestChunker.BatchRequestAndListener batch) {
10231043
var inferenceRequest = buildInferenceRequest(
10241044
esModel.mlNodeDeploymentId(),
10251045
EmptyConfigUpdate.INSTANCE,
@@ -1033,18 +1053,30 @@ private void inferBatchAndRunAfter(EmbeddingRequestChunker.BatchRequestAndListen
10331053
(l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
10341054
);
10351055

1036-
// schedule the next request once the results have been processed
1037-
var runNextListener = ActionListener.runAfter(mlResultsListener, () -> {
1038-
if (index.incrementAndGet() < requestAndListeners.size()) {
1039-
run();
1040-
}
1041-
});
1042-
1043-
var maybeDeployListener = runNextListener.delegateResponse(
1056+
var maybeDeployListener = mlResultsListener.delegateResponse(
10441057
(l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l)
10451058
);
10461059

10471060
client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
10481061
}
1062+
1063+
private void executeRequest(EmbeddingRequestChunker.BatchRequestAndListener batch, Runnable runAfter) {
1064+
var inferenceRequest = buildInferenceRequest(
1065+
esModel.mlNodeDeploymentId(),
1066+
EmptyConfigUpdate.INSTANCE,
1067+
batch.batch().inputs(),
1068+
inputType,
1069+
timeout
1070+
);
1071+
1072+
ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
1073+
.delegateFailureAndWrap(
1074+
(l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
1075+
);
1076+
1077+
// schedule the next request once the results have been processed
1078+
var runNextListener = ActionListener.runAfter(mlResultsListener, runAfter);
1079+
client.execute(InferModelAction.INSTANCE, inferenceRequest, runNextListener);
1080+
}
10491081
}
10501082
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,7 +1275,8 @@ public void testChunkingLargeDocument() throws InterruptedException {
12751275
assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
12761276

12771277
int wordsPerChunk = 10;
1278-
int numBatches = randomIntBetween(3, 6);
1278+
int numBatches = 3;
1279+
randomIntBetween(3, 6);
12791280
int numChunks = randomIntBetween(
12801281
((numBatches - 1) * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE) + 1,
12811282
numBatches * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE
@@ -1291,6 +1292,9 @@ public void testChunkingLargeDocument() throws InterruptedException {
12911292
numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
12921293
}
12931294
numResponsesPerBatch[numBatches - 1] = numChunks % ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
1295+
if (numResponsesPerBatch[numBatches - 1] == 0) {
1296+
numResponsesPerBatch[numBatches - 1] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
1297+
}
12941298

12951299
var batchIndex = new AtomicInteger();
12961300
Client client = mock(Client.class);
@@ -1347,7 +1351,7 @@ public void testChunkingLargeDocument() throws InterruptedException {
13471351
);
13481352

13491353
latch.await();
1350-
assertTrue("Listener not called", gotResults.get());
1354+
assertTrue("Listener not called with results", gotResults.get());
13511355
}
13521356

13531357
public void testParsePersistedConfig_Rerank() {

0 commit comments

Comments (0)