
Commit dfa1c87

[ML] Batch the chunks (#115477) (#116823)
Models running on an ML node have a queue of requests; when that queue is full, new requests are rejected. A large document can chunk into hundreds of requests, and in extreme cases a single large document can overflow the queue. Avoid this by batching the chunks and keeping a limited number of requests in flight.
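
To make the approach concrete, here is a minimal, self-contained sketch of the rolling-window pattern this commit uses, assuming plain `Runnable` batches and a fixed thread pool. The class and names (`RollingWindowRunner`, `IN_FLIGHT`) are invented for illustration and are not part of the commit:

```java
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Illustrative sketch only. Keeps at most IN_FLIGHT tasks running by having
 * each completed task schedule exactly one successor. In the real change the
 * "task" is an asynchronous inference request, and the successor is scheduled
 * from the request's completion listener rather than by a synchronous call.
 */
class RollingWindowRunner {
    private static final int IN_FLIGHT = 20;

    private final AtomicInteger nextIndex = new AtomicInteger();
    private final List<Runnable> batches;
    private final ExecutorService executor = Executors.newFixedThreadPool(IN_FLIGHT);

    RollingWindowRunner(List<Runnable> batches) {
        this.batches = batches;
    }

    void run() {
        // The first batch runs alone; once it completes it fans out IN_FLIGHT successors.
        executor.execute(() -> runBatch(IN_FLIGHT));
    }

    private void runBatch(int successors) {
        int i = nextIndex.getAndIncrement();
        if (i >= batches.size()) {
            return; // every batch has been claimed
        }
        batches.get(i).run();
        // In steady state successors == 1, so the number of concurrently
        // running batches never exceeds IN_FLIGHT.
        for (int s = 0; s < successors; s++) {
            executor.execute(() -> runBatch(1));
        }
    }
}
```

In the commit itself, the same idea is wired into `ActionListener` callbacks (see `BatchIterator` in the diff below), so a successor is scheduled when the inference response arrives, not when the request is submitted.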
1 parent 187aecf commit dfa1c87

File tree

3 files changed: +205 −33 lines


x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 84 additions & 19 deletions
```diff
@@ -60,6 +60,7 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 import java.util.function.Function;

@@ -656,25 +657,13 @@ public void chunkedInfer(
                 esModel.getConfigurations().getChunkingSettings()
             ).batchRequestsWithListeners(listener);

-            for (var batch : batchedRequests) {
-                var inferenceRequest = buildInferenceRequest(
-                    esModel.mlNodeDeploymentId(),
-                    EmptyConfigUpdate.INSTANCE,
-                    batch.batch().inputs(),
-                    inputType,
-                    timeout
-                );
-
-                ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
-                    .delegateFailureAndWrap(
-                        (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l)
-                    );
-
-                var maybeDeployListener = mlResultsListener.delegateResponse(
-                    (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener)
-                );
-
-                client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
+            if (batchedRequests.isEmpty()) {
+                listener.onResponse(List.of());
+            } else {
+                // Avoid filling the inference queue by executing the batches in series.
+                // Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference requests.
+                var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests);
+                sequentialRunner.run();
             }
         } else {
             listener.onFailure(notElasticsearchModelException(model));

@@ -990,4 +979,80 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
             return null;
         }
     }
+
+    /**
+     * Iterates over the batches, executing a limited number of requests at a time to avoid
+     * filling the ML node inference queue.
+     *
+     * First, a single request is executed, which can also trigger deploying a model
+     * if necessary. When this request is successfully executed, a callback executes
+     * N requests in parallel next. Each of these requests also has a callback that
+     * executes one more request, so that at all times N requests are in flight. This
+     * continues until all requests are executed.
+     */
+    class BatchIterator {
+        private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200
+
+        private final AtomicInteger index = new AtomicInteger();
+        private final ElasticsearchInternalModel esModel;
+        private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
+        private final InputType inputType;
+        private final TimeValue timeout;
+
+        BatchIterator(
+            ElasticsearchInternalModel esModel,
+            InputType inputType,
+            TimeValue timeout,
+            List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners
+        ) {
+            this.esModel = esModel;
+            this.requestAndListeners = requestAndListeners;
+            this.inputType = inputType;
+            this.timeout = timeout;
+        }
+
+        void run() {
+            // The first request may deploy the model, and upon completion runs
+            // NUM_REQUESTS_INFLIGHT requests in parallel.
+            inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true));
+        }
+
+        private void inferBatch(int runAfterCount, boolean maybeDeploy) {
+            int batchIndex = index.getAndIncrement();
+            if (batchIndex >= requestAndListeners.size()) {
+                return;
+            }
+            executeRequest(batchIndex, maybeDeploy, () -> {
+                for (int i = 0; i < runAfterCount; i++) {
+                    // Subsequent requests may not deploy the model, because the first
+                    // request already did so. Upon completion, each runs one more request.
+                    inferenceExecutor.execute(() -> inferBatch(1, false));
+                }
+            });
+        }
+
+        private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) {
+            EmbeddingRequestChunker.BatchRequestAndListener batch = requestAndListeners.get(batchIndex);
+            var inferenceRequest = buildInferenceRequest(
+                esModel.mlNodeDeploymentId(),
+                EmptyConfigUpdate.INSTANCE,
+                batch.batch().inputs(),
+                inputType,
+                timeout
+            );
+            logger.trace("Executing batch index={}", batchIndex);
+
+            ActionListener<InferModelAction.Response> listener = batch.listener()
+                .delegateFailureAndWrap(
+                    (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
+                );
+            if (runAfter != null) {
+                listener = ActionListener.runAfter(listener, runAfter);
+            }
+            if (maybeDeploy) {
+                listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l));
+            }
+            client.execute(InferModelAction.INSTANCE, inferenceRequest, listener);
+        }
+    }
 }
```

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 13 additions & 0 deletions
```diff
@@ -24,12 +24,25 @@
 import java.util.concurrent.atomic.AtomicReference;

 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.startsWith;

 public class EmbeddingRequestChunkerTests extends ESTestCase {

+    public void testEmptyInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, empty());
+    }
+
+    public void testBlankInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+    }
+
     public void testShortInputsAreSingleBatch() {
         String input = "one chunk";
         var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
```
