
Commit b97a90c

Redo inference batching
1 parent e4c0d9a commit b97a90c

File tree

1 file changed (+35, -45 lines)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 35 additions & 45 deletions
@@ -994,9 +994,19 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
         }
     }
 
-    // Iterates over the batch sending 1 request at a time to avoid
-    // filling the ml node inference queue.
+    /**
+     * Iterates over the batch, executing a limited number of requests at a time to avoid
+     * filling the ML node inference queue.
+     *
+     * First, a single request is executed, which can also trigger deploying a model
+     * if necessary. When this request is successfully executed, a callback executes
+     * N requests in parallel next. Each of these requests also has a callback that
+     * executes one more request, so that at all times N requests are in flight. This
+     * continues until all requests are executed.
+     */
     class BatchIterator {
+        private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200
+
         private final AtomicInteger index = new AtomicInteger();
         private final ElasticsearchInternalModel esModel;
         private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
@@ -1016,67 +1026,47 @@ class BatchIterator {
         }
 
         void run() {
-            inferenceExecutor.execute(this::inferBatchAndRunAfter);
+            // The first request may deploy the model and, upon completion, runs
+            // NUM_REQUESTS_INFLIGHT requests in parallel.
+            inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true));
         }
 
-        private void inferBatchAndRunAfter() {
-            int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200
-            int requestCount = 0;
-            // loop does not include the final request
-            while (requestCount < NUM_REQUESTS_INFLIGHT - 1 && index.get() < requestAndListeners.size() - 1) {
-
-                var batch = requestAndListeners.get(index.get());
-                executeRequest(batch);
-                requestCount++;
-                index.incrementAndGet();
+        private void inferBatch(int runAfterCount, boolean maybeDeploy) {
+            int batchIndex = index.getAndIncrement();
+            if (batchIndex >= requestAndListeners.size()) {
+                return;
             }
-
-            var batch = requestAndListeners.get(index.get());
-            executeRequest(batch, () -> {
-                if (index.incrementAndGet() < requestAndListeners.size()) {
-                    run(); // start the next batch
+            executeRequest(batchIndex, maybeDeploy, () -> {
+                for (int i = 0; i < runAfterCount; i++) {
+                    // Subsequent requests need not deploy the model, because the first request
+                    // already did so. Upon completion, each runs one more request.
+                    inferenceExecutor.execute(() -> inferBatch(1, false));
                 }
             });
         }
 
-        private void executeRequest(EmbeddingRequestChunker.BatchRequestAndListener batch) {
+        private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) {
+            EmbeddingRequestChunker.BatchRequestAndListener batch = requestAndListeners.get(batchIndex);
             var inferenceRequest = buildInferenceRequest(
                 esModel.mlNodeDeploymentId(),
                 EmptyConfigUpdate.INSTANCE,
                 batch.batch().inputs(),
                 inputType,
                 timeout
             );
+            logger.trace("Executing batch index={}", batchIndex);
 
-            ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
+            ActionListener<InferModelAction.Response> listener = batch.listener()
                 .delegateFailureAndWrap(
                     (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
                 );
-
-            var maybeDeployListener = mlResultsListener.delegateResponse(
-                (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l)
-            );
-
-            client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
-        }
-
-        private void executeRequest(EmbeddingRequestChunker.BatchRequestAndListener batch, Runnable runAfter) {
-            var inferenceRequest = buildInferenceRequest(
-                esModel.mlNodeDeploymentId(),
-                EmptyConfigUpdate.INSTANCE,
-                batch.batch().inputs(),
-                inputType,
-                timeout
-            );
-
-            ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
-                .delegateFailureAndWrap(
-                    (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
-                );
-
-            // schedule the next request once the results have been processed
-            var runNextListener = ActionListener.runAfter(mlResultsListener, runAfter);
-            client.execute(InferModelAction.INSTANCE, inferenceRequest, runNextListener);
+            if (runAfter != null) {
+                listener = ActionListener.runAfter(listener, runAfter);
+            }
+            if (maybeDeploy) {
+                listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l));
+            }
+            client.execute(InferModelAction.INSTANCE, inferenceRequest, listener);
        }
    }
 }
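For readers skimming the diff, the scheduling pattern may be easier to see outside the service. Below is a minimal, self-contained sketch of the same idea, with hypothetical names: a plain ExecutorService and synchronous Runnables stand in for inferenceExecutor and the async inference calls. One primer task runs first, its completion fans out N parallel tasks, and every finished task schedules exactly one more, so at most N tasks are ever in flight.

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;

class PipelinedBatchRunner {
    private static final int NUM_REQUESTS_INFLIGHT = 20;

    private final AtomicInteger index = new AtomicInteger();
    private final List<Runnable> batches;
    private final ExecutorService executor;

    PipelinedBatchRunner(List<Runnable> batches, ExecutorService executor) {
        this.batches = batches;
        this.executor = executor;
    }

    void run() {
        // Prime the pipeline: the first task runs alone and, once done,
        // fans out NUM_REQUESTS_INFLIGHT follow-up tasks.
        executor.execute(() -> runBatch(NUM_REQUESTS_INFLIGHT));
    }

    private void runBatch(int fanOut) {
        // Claim the next batch atomically; concurrent workers never collide.
        int batchIndex = index.getAndIncrement();
        if (batchIndex >= batches.size()) {
            return; // all batches have been claimed
        }
        try {
            batches.get(batchIndex).run();
        } finally {
            // Each completed task schedules `fanOut` more (1 in steady state),
            // so at most NUM_REQUESTS_INFLIGHT tasks are in flight at once.
            for (int i = 0; i < fanOut; i++) {
                executor.execute(() -> runBatch(1));
            }
        }
    }
}

The AtomicInteger is what makes the fan-out safe: getAndIncrement() hands each worker a unique batch index, so parallel workers never process the same batch twice and draining the list terminates cleanly.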

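The executeRequest rewrite also replaces two near-duplicate overloads with one method that conditionally decorates a single listener. The sketch below illustrates that decorator idea with a simplified, hypothetical Listener interface; Elasticsearch's real ActionListener.runAfter and delegateResponse have richer contracts, so treat this as an illustration of the composition, not the actual API.

import java.util.function.Consumer;

// Hypothetical, simplified stand-in for org.elasticsearch.action.ActionListener.
interface Listener<T> {
    void onResponse(T response);
    void onFailure(Exception e);

    // In the spirit of ActionListener.runAfter: run `action` once the inner
    // listener completes, on either the success or the failure path.
    static <T> Listener<T> runAfter(Listener<T> inner, Runnable action) {
        return new Listener<T>() {
            @Override
            public void onResponse(T response) {
                try {
                    inner.onResponse(response);
                } finally {
                    action.run();
                }
            }

            @Override
            public void onFailure(Exception e) {
                try {
                    inner.onFailure(e);
                } finally {
                    action.run();
                }
            }
        };
    }

    // In the spirit of delegateResponse: keep the success path, reroute failures
    // (e.g. into "deploy the model, then retry with the original listener").
    default Listener<T> withFailureHandler(Consumer<Exception> handler) {
        Listener<T> inner = this;
        return new Listener<T>() {
            @Override
            public void onResponse(T response) {
                inner.onResponse(response);
            }

            @Override
            public void onFailure(Exception e) {
                handler.accept(e);
            }
        };
    }
}

Wrapping order matters in the committed code: runAfter is applied before the deploy-on-failure decorator, so when maybeStartDeployment reroutes a failure into a deployment-and-retry, the follow-up scheduling appears to fire only once the retried request completes through the inner listener.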