Commit 5be2aca

[ML] Batch the chunks (#115477) (#116655)

Models running on an ML node have a queue of requests; when that queue is full, new requests are rejected. A large document can chunk into hundreds of requests, and in extreme cases a single large document can overflow the queue. Avoid this by batching the chunks and keeping a fixed number of requests in flight.

1 parent 6cbb16f commit 5be2aca
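
The pattern the commit adds is easier to see in isolation. Below is a minimal, standalone Java sketch of the same idea (it is not the Elasticsearch code; the class and method names BoundedInFlightSketch and runNext are illustrative): one seed task is submitted first, its completion fans out to the in-flight cap, and every later completion schedules exactly one successor, so the number of outstanding batches never exceeds the cap.

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class BoundedInFlightSketch {
    private static final int NUM_REQUESTS_INFLIGHT = 3; // cap on concurrently executing batches

    public static void main(String[] args) throws InterruptedException {
        List<String> batches = List.of("b0", "b1", "b2", "b3", "b4", "b5", "b6", "b7");
        ExecutorService executor = Executors.newFixedThreadPool(NUM_REQUESTS_INFLIGHT);
        AtomicInteger index = new AtomicInteger();
        CountDownLatch done = new CountDownLatch(batches.size());

        // Seed one batch; when it completes it fans out to NUM_REQUESTS_INFLIGHT batches.
        runNext(batches, executor, index, done, NUM_REQUESTS_INFLIGHT);

        done.await();
        executor.shutdown();
    }

    private static void runNext(List<String> batches, ExecutorService executor, AtomicInteger index, CountDownLatch done, int fanOut) {
        int i = index.getAndIncrement();
        if (i >= batches.size()) {
            return; // every batch has already been claimed
        }
        executor.execute(() -> {
            System.out.println("executing " + batches.get(i));
            for (int n = 0; n < fanOut; n++) {
                runNext(batches, executor, index, done, 1); // each completion triggers at most one more
            }
            done.countDown();
        });
    }
}

With the committed cap of NUM_REQUESTS_INFLIGHT = 20 and the batch size implied by its inline comment (20 * batch size = 200), at most about 200 chunk inputs are outstanding at any time, however large the document.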

File tree

3 files changed: +205 -33 lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 84 additions & 19 deletions

@@ -68,6 +68,7 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.stream.Stream;

@@ -680,25 +681,13 @@ public void chunkedInfer(
                 esModel.getConfigurations().getChunkingSettings()
             ).batchRequestsWithListeners(listener);

-            for (var batch : batchedRequests) {
-                var inferenceRequest = buildInferenceRequest(
-                    esModel.mlNodeDeploymentId(),
-                    EmptyConfigUpdate.INSTANCE,
-                    batch.batch().inputs(),
-                    inputType,
-                    timeout
-                );
-
-                ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
-                    .delegateFailureAndWrap(
-                        (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l)
-                    );
-
-                var maybeDeployListener = mlResultsListener.delegateResponse(
-                    (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener)
-                );
-
-                client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
+            if (batchedRequests.isEmpty()) {
+                listener.onResponse(List.of());
+            } else {
+                // Avoid filling the inference queue by executing the batches in series.
+                // Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference requests.
+                var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests);
+                sequentialRunner.run();
             }
         } else {
             listener.onFailure(notElasticsearchModelException(model));

@@ -1017,6 +1006,82 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
         }
     }

+    /**
+     * Iterates over the batches, executing a limited number of requests at a time to avoid
+     * filling the ML node inference queue.
+     *
+     * First, a single request is executed, which can also trigger deploying a model
+     * if necessary. When this request is successfully executed, a callback executes
+     * N requests in parallel next. Each of these requests also has a callback that
+     * executes one more request, so that at all times N requests are in flight. This
+     * continues until all requests are executed.
+     */
+    class BatchIterator {
+        private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200
+
+        private final AtomicInteger index = new AtomicInteger();
+        private final ElasticsearchInternalModel esModel;
+        private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
+        private final InputType inputType;
+        private final TimeValue timeout;
+
+        BatchIterator(
+            ElasticsearchInternalModel esModel,
+            InputType inputType,
+            TimeValue timeout,
+            List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners
+        ) {
+            this.esModel = esModel;
+            this.requestAndListeners = requestAndListeners;
+            this.inputType = inputType;
+            this.timeout = timeout;
+        }
+
+        void run() {
+            // The first request may deploy the model, and upon completion runs
+            // NUM_REQUESTS_INFLIGHT requests in parallel.
+            inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true));
+        }
+
+        private void inferBatch(int runAfterCount, boolean maybeDeploy) {
+            int batchIndex = index.getAndIncrement();
+            if (batchIndex >= requestAndListeners.size()) {
+                return;
+            }
+            executeRequest(batchIndex, maybeDeploy, () -> {
+                for (int i = 0; i < runAfterCount; i++) {
+                    // Subsequent requests may not deploy the model, because the first request
+                    // already did so. Upon completion, each runs one more request.
+                    inferenceExecutor.execute(() -> inferBatch(1, false));
+                }
+            });
+        }
+
+        private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) {
+            EmbeddingRequestChunker.BatchRequestAndListener batch = requestAndListeners.get(batchIndex);
+            var inferenceRequest = buildInferenceRequest(
+                esModel.mlNodeDeploymentId(),
+                EmptyConfigUpdate.INSTANCE,
+                batch.batch().inputs(),
+                inputType,
+                timeout
+            );
+            logger.trace("Executing batch index={}", batchIndex);
+
+            ActionListener<InferModelAction.Response> listener = batch.listener()
+                .delegateFailureAndWrap(
+                    (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
+                );
+            if (runAfter != null) {
+                listener = ActionListener.runAfter(listener, runAfter);
+            }
+            if (maybeDeploy) {
+                listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l));
+            }
+            client.execute(InferModelAction.INSTANCE, inferenceRequest, listener);
+        }
+    }
+
     public static class Configuration {
         public static InferenceServiceConfiguration get() {
             return configuration.getOrCompute();
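
One edge case the new isEmpty() branch covers: when the chunker produces no batches (for example, an empty input list), there is no per-batch listener left to complete the aggregate result, so chunkedInfer must answer the caller directly with an empty list. A minimal sketch of that reasoning, using hypothetical names (dispatch, onDone) rather than the real API:

import java.util.List;
import java.util.function.Consumer;

class EmptyBatchesSketch {
    // With zero batches nothing will ever call back, so the aggregate result
    // must be completed immediately instead of waiting forever.
    static void dispatch(List<Runnable> batches, Consumer<List<String>> onDone) {
        if (batches.isEmpty()) {
            onDone.accept(List.of()); // mirrors listener.onResponse(List.of()) in the diff above
            return;
        }
        batches.forEach(Runnable::run); // each batch reports through its own listener
    }
}

The new testEmptyInput below checks exactly this chunker behaviour: an empty input yields an empty batch list.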

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 13 additions & 0 deletions

@@ -24,12 +24,25 @@
 import java.util.concurrent.atomic.AtomicReference;

 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.startsWith;

 public class EmbeddingRequestChunkerTests extends ESTestCase {

+    public void testEmptyInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, empty());
+    }
+
+    public void testBlankInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+    }
+
     public void testShortInputsAreSingleBatch() {
         String input = "one chunk";
         var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
