
Commit cca70d7

[ML] Batch the chunks (#115477)
Models running on an ML node have a request queue; when that queue is full, new requests are rejected. A large document can chunk into hundreds of inference requests, and in extreme cases a single large document can overflow the queue. Avoid this by batching the chunks and keeping only a limited number of requests in flight at a time.
Parent: bcf1bd4
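To illustrate the scheduling idea in isolation, here is a minimal, hypothetical Java sketch of the same "run one request first, then keep N in flight" pattern that the commit's BatchIterator (shown in the diff below) uses. The class name InFlightRunner, the Runnable task list, and the executor setup are all illustrative assumptions, not code from this commit:

// Hypothetical sketch of the throttling pattern (not part of the commit):
// run the first task alone, then keep MAX_IN_FLIGHT tasks running by
// having each completed task schedule one more.
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

class InFlightRunner {
    private static final int MAX_IN_FLIGHT = 20;
    private final AtomicInteger next = new AtomicInteger();
    private final List<Runnable> tasks;
    private final ExecutorService executor = Executors.newFixedThreadPool(MAX_IN_FLIGHT);

    InFlightRunner(List<Runnable> tasks) {
        this.tasks = tasks;
    }

    void run() {
        // The first task runs alone (in the commit it may also trigger a
        // model deployment); on completion it fans out MAX_IN_FLIGHT successors.
        executor.execute(() -> runNext(MAX_IN_FLIGHT));
    }

    private void runNext(int fanOut) {
        int i = next.getAndIncrement();
        if (i >= tasks.size()) {
            return; // nothing left to schedule
        }
        tasks.get(i).run();
        // Each completion schedules fanOut more tasks, so the number in
        // flight stays near MAX_IN_FLIGHT until the list is drained.
        for (int j = 0; j < fanOut; j++) {
            executor.execute(() -> runNext(1));
        }
    }
}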

File tree: 3 files changed (+205, -33 lines)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 84 additions & 19 deletions
@@ -68,6 +68,7 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.stream.Stream;
@@ -680,25 +681,13 @@ public void chunkedInfer(
                 esModel.getConfigurations().getChunkingSettings()
             ).batchRequestsWithListeners(listener);
 
-            for (var batch : batchedRequests) {
-                var inferenceRequest = buildInferenceRequest(
-                    esModel.mlNodeDeploymentId(),
-                    EmptyConfigUpdate.INSTANCE,
-                    batch.batch().inputs(),
-                    inputType,
-                    timeout
-                );
-
-                ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
-                    .delegateFailureAndWrap(
-                        (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l)
-                    );
-
-                var maybeDeployListener = mlResultsListener.delegateResponse(
-                    (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener)
-                );
-
-                client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
+            if (batchedRequests.isEmpty()) {
+                listener.onResponse(List.of());
+            } else {
+                // Avoid filling the inference queue by executing the batches in series
+                // Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference request
+                var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests);
+                sequentialRunner.run();
             }
         } else {
             listener.onFailure(notElasticsearchModelException(model));
@@ -1018,6 +1007,82 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
         }
     }
 
+    /**
+     * Iterates over the batch executing a limited number requests at a time to avoid
+     * filling the ML node inference queue.
+     *
+     * First, a single request is executed, which can also trigger deploying a model
+     * if necessary. When this request is successfully executed, a callback executes
+     * N requests in parallel next. Each of these requests also has a callback that
+     * executes one more request, so that at all time N requests are in-flight. This
+     * continues until all requests are executed.
+     */
+    class BatchIterator {
+        private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200
+
+        private final AtomicInteger index = new AtomicInteger();
+        private final ElasticsearchInternalModel esModel;
+        private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
+        private final InputType inputType;
+        private final TimeValue timeout;
+
+        BatchIterator(
+            ElasticsearchInternalModel esModel,
+            InputType inputType,
+            TimeValue timeout,
+            List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners
+        ) {
+            this.esModel = esModel;
+            this.requestAndListeners = requestAndListeners;
+            this.inputType = inputType;
+            this.timeout = timeout;
+        }
+
+        void run() {
+            // The first request may deploy the model, and upon completion runs
+            // NUM_REQUESTS_INFLIGHT in parallel.
+            inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true));
+        }
+
+        private void inferBatch(int runAfterCount, boolean maybeDeploy) {
+            int batchIndex = index.getAndIncrement();
+            if (batchIndex >= requestAndListeners.size()) {
+                return;
+            }
+            executeRequest(batchIndex, maybeDeploy, () -> {
+                for (int i = 0; i < runAfterCount; i++) {
+                    // Subsequent requests may not deploy the model, because the first request
+                    // already did so. Upon completion, it runs one more request.
+                    inferenceExecutor.execute(() -> inferBatch(1, false));
+                }
+            });
+        }
+
+        private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) {
+            EmbeddingRequestChunker.BatchRequestAndListener batch = requestAndListeners.get(batchIndex);
+            var inferenceRequest = buildInferenceRequest(
+                esModel.mlNodeDeploymentId(),
+                EmptyConfigUpdate.INSTANCE,
+                batch.batch().inputs(),
+                inputType,
+                timeout
+            );
+            logger.trace("Executing batch index={}", batchIndex);
+
+            ActionListener<InferModelAction.Response> listener = batch.listener()
+                .delegateFailureAndWrap(
+                    (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
+                );
+            if (runAfter != null) {
+                listener = ActionListener.runAfter(listener, runAfter);
+            }
+            if (maybeDeploy) {
+                listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l));
+            }
+            client.execute(InferModelAction.INSTANCE, inferenceRequest, listener);
+        }
+    }
+
     public static class Configuration {
         public static InferenceServiceConfiguration get() {
             return configuration.getOrCompute();
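A note on the sizing and listener order in this diff: NUM_REQUESTS_INFLIGHT is 20, and the "// * batch size = 200" comment implies each batch holds up to 10 inputs (EMBEDDING_MAX_BATCH_SIZE), so roughly 20 × 10 = 200 chunk inputs are outstanding at once. In executeRequest, the runAfter wrapper fires whether the request succeeds or fails, so the pipeline keeps draining after an error, while the maybeStartDeployment fallback is attached only to the very first request (maybeDeploy is true only on the call from run()).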

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 13 additions & 0 deletions
@@ -24,12 +24,25 @@
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.startsWith;
 
 public class EmbeddingRequestChunkerTests extends ESTestCase {
 
+    public void testEmptyInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, empty());
+    }
+
+    public void testBlankInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+    }
+
     public void testShortInputsAreSingleBatch() {
         String input = "one chunk";
         var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
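For context, a rough usage sketch of the chunker as exercised by these tests. The roles of the numeric constructor arguments (max inputs per batch, words per chunk, chunk overlap) are inferred from the test code above, not confirmed by this diff, and overallListener stands in for the test suite's testListener() helper:

// Inferred usage sketch; parameter roles are assumptions based on the
// tests above, not documented API.
var chunker = new EmbeddingRequestChunker(
    List.of("first document", "second document"), // inputs to embed
    100,  // assumed: max inputs per batch
    100,  // assumed: words per chunk
    10,   // assumed: word overlap between consecutive chunks
    EmbeddingRequestChunker.EmbeddingType.FLOAT
);
// Each returned batch carries its own listener. Per this commit, an empty
// input list now yields zero batches, which chunkedInfer short-circuits
// with listener.onResponse(List.of()).
var batches = chunker.batchRequestsWithListeners(overallListener);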
