|
68 | 68 | import java.util.Map; |
69 | 69 | import java.util.Optional; |
70 | 70 | import java.util.Set; |
| 71 | +import java.util.concurrent.atomic.AtomicInteger; |
71 | 72 | import java.util.function.Consumer; |
72 | 73 | import java.util.function.Function; |
73 | 74 | import java.util.stream.Stream; |
@@ -680,25 +681,13 @@ public void chunkedInfer( |
680 | 681 | esModel.getConfigurations().getChunkingSettings() |
681 | 682 | ).batchRequestsWithListeners(listener); |
682 | 683 |
|
683 | | - for (var batch : batchedRequests) { |
684 | | - var inferenceRequest = buildInferenceRequest( |
685 | | - esModel.mlNodeDeploymentId(), |
686 | | - EmptyConfigUpdate.INSTANCE, |
687 | | - batch.batch().inputs(), |
688 | | - inputType, |
689 | | - timeout |
690 | | - ); |
691 | | - |
692 | | - ActionListener<InferModelAction.Response> mlResultsListener = batch.listener() |
693 | | - .delegateFailureAndWrap( |
694 | | - (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l) |
695 | | - ); |
696 | | - |
697 | | - var maybeDeployListener = mlResultsListener.delegateResponse( |
698 | | - (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener) |
699 | | - ); |
700 | | - |
701 | | - client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener); |
| 684 | + if (batchedRequests.isEmpty()) { |
| 685 | + listener.onResponse(List.of()); |
| 686 | + } else { |
| 687 | + // Avoid filling the inference queue by executing the batches in series |
| 688 | + // Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference requests |
| 689 | + var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests); |
| 690 | + sequentialRunner.run(); |
702 | 691 | } |
703 | 692 | } else { |
704 | 693 | listener.onFailure(notElasticsearchModelException(model)); |
@@ -1017,6 +1006,82 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) { |
1017 | 1006 | } |
1018 | 1007 | } |
1019 | 1008 |
|
| 1009 | + /** |
| 1010 | + * Iterates over the batches, executing a limited number of requests at a time to avoid |
| 1011 | + * filling the ML node inference queue. |
| 1012 | + * |
| 1013 | + * First, a single request is executed, which can also trigger deploying a model |
| 1014 | + * if necessary. When this request completes successfully, a callback executes |
| 1015 | + * N requests in parallel next. Each of these requests also has a callback that |
| 1016 | + * executes one more request, so that at all times N requests are in flight. This |
| 1017 | + * continues until all requests have been executed. |
| 1018 | + */ |
| 1019 | + class BatchIterator { |
| 1020 | + private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200 |
| 1021 | + |
| 1022 | + private final AtomicInteger index = new AtomicInteger(); |
| 1023 | + private final ElasticsearchInternalModel esModel; |
| 1024 | + private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners; |
| 1025 | + private final InputType inputType; |
| 1026 | + private final TimeValue timeout; |
| 1027 | + |
| 1028 | + BatchIterator( |
| 1029 | + ElasticsearchInternalModel esModel, |
| 1030 | + InputType inputType, |
| 1031 | + TimeValue timeout, |
| 1032 | + List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners |
| 1033 | + ) { |
| 1034 | + this.esModel = esModel; |
| 1035 | + this.requestAndListeners = requestAndListeners; |
| 1036 | + this.inputType = inputType; |
| 1037 | + this.timeout = timeout; |
| 1038 | + } |
| 1039 | + |
| 1040 | + void run() { |
| 1041 | + // The first request may deploy the model, and upon completion runs |
| 1042 | + // NUM_REQUESTS_INFLIGHT in parallel. |
| 1043 | + inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true)); |
| 1044 | + } |
| 1045 | + |
| 1046 | + private void inferBatch(int runAfterCount, boolean maybeDeploy) { |
| 1047 | + int batchIndex = index.getAndIncrement(); |
| 1048 | + if (batchIndex >= requestAndListeners.size()) { |
| 1049 | + return; |
| 1050 | + } |
| 1051 | + executeRequest(batchIndex, maybeDeploy, () -> { |
| 1052 | + for (int i = 0; i < runAfterCount; i++) { |
| 1053 | + // Subsequent requests may not deploy the model, because the first request |
| 1054 | + // already did so. Upon completion, it runs one more request. |
| 1055 | + inferenceExecutor.execute(() -> inferBatch(1, false)); |
| 1056 | + } |
| 1057 | + }); |
| 1058 | + } |
| 1059 | + |
| 1060 | + private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) { |
| 1061 | + EmbeddingRequestChunker.BatchRequestAndListener batch = requestAndListeners.get(batchIndex); |
| 1062 | + var inferenceRequest = buildInferenceRequest( |
| 1063 | + esModel.mlNodeDeploymentId(), |
| 1064 | + EmptyConfigUpdate.INSTANCE, |
| 1065 | + batch.batch().inputs(), |
| 1066 | + inputType, |
| 1067 | + timeout |
| 1068 | + ); |
| 1069 | + logger.trace("Executing batch index={}", batchIndex); |
| 1070 | + |
| 1071 | + ActionListener<InferModelAction.Response> listener = batch.listener() |
| 1072 | + .delegateFailureAndWrap( |
| 1073 | + (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l) |
| 1074 | + ); |
| 1075 | + if (runAfter != null) { |
| 1076 | + listener = ActionListener.runAfter(listener, runAfter); |
| 1077 | + } |
| 1078 | + if (maybeDeploy) { |
| 1079 | + listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l)); |
| 1080 | + } |
| 1081 | + client.execute(InferModelAction.INSTANCE, inferenceRequest, listener); |
| 1082 | + } |
| 1083 | + } |
| 1084 | + |
1020 | 1085 | public static class Configuration { |
1021 | 1086 | public static InferenceServiceConfiguration get() { |
1022 | 1087 | return configuration.getOrCompute(); |
|
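For readers skimming the diff, the control flow of `BatchIterator` may be easier to follow in isolation. Below is a minimal, self-contained sketch of the same sliding-window idea; the names (`SlidingWindowRunner`, `MAX_IN_FLIGHT`) and the fixed thread pool are illustrative assumptions rather than part of the Elasticsearch code, and the tasks here run synchronously, whereas `BatchIterator` schedules the next batch from an `ActionListener.runAfter` callback once the asynchronous inference response arrives.

```java
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

// Illustrative sketch (not Elasticsearch code): run one task first, then keep
// MAX_IN_FLIGHT tasks running until the shared index passes the end of the list.
class SlidingWindowRunner {
    private static final int MAX_IN_FLIGHT = 3; // BatchIterator uses 20

    private final ExecutorService executor = Executors.newFixedThreadPool(MAX_IN_FLIGHT);
    private final AtomicInteger index = new AtomicInteger();
    private final List<Runnable> tasks;

    SlidingWindowRunner(List<Runnable> tasks) {
        this.tasks = tasks;
    }

    void run() {
        // The first task runs alone; when it finishes it fans out MAX_IN_FLIGHT tasks.
        executor.execute(() -> runNext(MAX_IN_FLIGHT));
    }

    private void runNext(int fanOut) {
        int i = index.getAndIncrement();
        if (i >= tasks.size()) {
            return; // list drained, nothing more to schedule
        }
        tasks.get(i).run();
        // After the initial fan-out, each completed task schedules exactly one
        // successor, so the number of concurrently running tasks stays bounded.
        for (int j = 0; j < fanOut; j++) {
            executor.execute(() -> runNext(1));
        }
    }

    void shutdown() {
        executor.shutdown();
    }

    public static void main(String[] args) throws InterruptedException {
        var runner = new SlidingWindowRunner(
            IntStream.range(0, 10)
                .<Runnable>mapToObj(n -> () -> System.out.println("task " + n))
                .toList()
        );
        runner.run();
        Thread.sleep(500); // let the demo tasks drain before shutting the pool down
        runner.shutdown();
    }
}
```

The shared `AtomicInteger` is what keeps the pattern coordination-free: each finished task claims the next unclaimed index, so at most `MAX_IN_FLIGHT` (or `NUM_REQUESTS_INFLIGHT` in the PR) tasks run concurrently and the list is processed exactly once.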