|
60 | 60 | import java.util.Map; |
61 | 61 | import java.util.Optional; |
62 | 62 | import java.util.Set; |
| 63 | +import java.util.concurrent.atomic.AtomicInteger; |
63 | 64 | import java.util.function.Consumer; |
64 | 65 | import java.util.function.Function; |
65 | 66 |
|
@@ -656,25 +657,13 @@ public void chunkedInfer( |
656 | 657 | esModel.getConfigurations().getChunkingSettings() |
657 | 658 | ).batchRequestsWithListeners(listener); |
658 | 659 |
|
659 | | - for (var batch : batchedRequests) { |
660 | | - var inferenceRequest = buildInferenceRequest( |
661 | | - esModel.mlNodeDeploymentId(), |
662 | | - EmptyConfigUpdate.INSTANCE, |
663 | | - batch.batch().inputs(), |
664 | | - inputType, |
665 | | - timeout |
666 | | - ); |
667 | | - |
668 | | - ActionListener<InferModelAction.Response> mlResultsListener = batch.listener() |
669 | | - .delegateFailureAndWrap( |
670 | | - (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l) |
671 | | - ); |
672 | | - |
673 | | - var maybeDeployListener = mlResultsListener.delegateResponse( |
674 | | - (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener) |
675 | | - ); |
676 | | - |
677 | | - client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener); |
| 660 | + if (batchedRequests.isEmpty()) { |
| 661 | + listener.onResponse(List.of()); |
| 662 | + } else { |
| 663 | + // Avoid filling the inference queue by executing the batches in series |
|  664 | + // Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference requests |
| 665 | + var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests); |
| 666 | + sequentialRunner.run(); |
678 | 667 | } |
679 | 668 | } else { |
680 | 669 | listener.onFailure(notElasticsearchModelException(model)); |
@@ -990,4 +979,80 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) { |
990 | 979 | return null; |
991 | 980 | } |
992 | 981 | } |
| 982 | + |
| 983 | + /** |
|  984 | + * Iterates over the batches, executing a limited number of requests at a time to avoid |
| 985 | + * filling the ML node inference queue. |
| 986 | + * |
| 987 | + * First, a single request is executed, which can also trigger deploying a model |
|  988 | + * if necessary. When this request completes successfully, a callback executes the |
|  989 | + * next N requests in parallel. Each of these requests also has a callback that |
|  990 | + * executes one more request, so that at all times N requests are in flight. This |
|  991 | + * continues until all requests have been executed. |
| 992 | + */ |
| 993 | + class BatchIterator { |
|  994 | + private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200 inputs in flight |
| 995 | + |
| 996 | + private final AtomicInteger index = new AtomicInteger(); |
| 997 | + private final ElasticsearchInternalModel esModel; |
| 998 | + private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners; |
| 999 | + private final InputType inputType; |
| 1000 | + private final TimeValue timeout; |
| 1001 | + |
| 1002 | + BatchIterator( |
| 1003 | + ElasticsearchInternalModel esModel, |
| 1004 | + InputType inputType, |
| 1005 | + TimeValue timeout, |
| 1006 | + List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners |
| 1007 | + ) { |
| 1008 | + this.esModel = esModel; |
| 1009 | + this.requestAndListeners = requestAndListeners; |
| 1010 | + this.inputType = inputType; |
| 1011 | + this.timeout = timeout; |
| 1012 | + } |
| 1013 | + |
| 1014 | + void run() { |
| 1015 | + // The first request may deploy the model, and upon completion runs |
| 1016 | + // NUM_REQUESTS_INFLIGHT requests in parallel. |
| 1017 | + inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true)); |
| 1018 | + } |
| 1019 | + |
| 1020 | + private void inferBatch(int runAfterCount, boolean maybeDeploy) { |
| 1021 | + int batchIndex = index.getAndIncrement(); |
| 1022 | + if (batchIndex >= requestAndListeners.size()) { |
| 1023 | + return; |
| 1024 | + } |
| 1025 | + executeRequest(batchIndex, maybeDeploy, () -> { |
| 1026 | + for (int i = 0; i < runAfterCount; i++) { |
| 1027 | + // Subsequent requests do not trigger a model deployment, because the first |
| 1028 | + // request already did so. Upon completion, each runs one more request. |
| 1029 | + inferenceExecutor.execute(() -> inferBatch(1, false)); |
| 1030 | + } |
| 1031 | + }); |
| 1032 | + } |
| 1033 | + |
| 1034 | + private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) { |
| 1035 | + EmbeddingRequestChunker.BatchRequestAndListener batch = requestAndListeners.get(batchIndex); |
| 1036 | + var inferenceRequest = buildInferenceRequest( |
| 1037 | + esModel.mlNodeDeploymentId(), |
| 1038 | + EmptyConfigUpdate.INSTANCE, |
| 1039 | + batch.batch().inputs(), |
| 1040 | + inputType, |
| 1041 | + timeout |
| 1042 | + ); |
| 1043 | + logger.trace("Executing batch index={}", batchIndex); |
| 1044 | + |
| 1045 | + ActionListener<InferModelAction.Response> listener = batch.listener() |
| 1046 | + .delegateFailureAndWrap( |
| 1047 | + (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l) |
| 1048 | + ); |
| 1049 | + if (runAfter != null) { |
| 1050 | + listener = ActionListener.runAfter(listener, runAfter); |
| 1051 | + } |
| 1052 | + if (maybeDeploy) { |
| 1053 | + listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l)); |
| 1054 | + } |
| 1055 | + client.execute(InferModelAction.INSTANCE, inferenceRequest, listener); |
| 1056 | + } |
| 1057 | + } |
993 | 1058 | } |
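
For context, here is a minimal, self-contained sketch (not part of this change) of the "keep N requests in flight" pattern that the BatchIterator Javadoc describes. All names below (InFlightThrottleSketch, tasks, pool, the fan-out value of 3) are illustrative assumptions; the production class drives the same idea through ActionListener.runAfter callbacks on the inference responses rather than the try/finally used here.

// Standalone illustration of the BatchIterator fan-out pattern (illustrative names).
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

class InFlightThrottleSketch {
    private static final int NUM_REQUESTS_INFLIGHT = 3;

    private final AtomicInteger index = new AtomicInteger();
    private final List<Runnable> tasks;
    private final ExecutorService executor;
    final CountDownLatch done;

    InFlightThrottleSketch(List<Runnable> tasks, ExecutorService executor) {
        this.tasks = tasks;
        this.executor = executor;
        this.done = new CountDownLatch(tasks.size());
    }

    void run() {
        // The first task fans out to NUM_REQUESTS_INFLIGHT tasks when it completes.
        executor.execute(() -> runNext(NUM_REQUESTS_INFLIGHT));
    }

    private void runNext(int fanOut) {
        int i = index.getAndIncrement();
        if (i >= tasks.size()) {
            return; // every task has already been claimed
        }
        try {
            tasks.get(i).run();
        } finally {
            done.countDown();
            // Each completed task schedules `fanOut` more. After the first task,
            // fanOut is always 1, so at most NUM_REQUESTS_INFLIGHT tasks run at once.
            for (int j = 0; j < fanOut; j++) {
                executor.execute(() -> runNext(1));
            }
        }
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        List<Runnable> tasks = IntStream.range(0, 10)
            .<Runnable>mapToObj(n -> () -> System.out.println("task " + n + " on " + Thread.currentThread().getName()))
            .toList();
        var sketch = new InFlightThrottleSketch(tasks, pool);
        sketch.run();
        sketch.done.await();
        pool.shutdown();
    }
}

In the actual BatchIterator the fan-out and the "one more" follow-up are attached with ActionListener.runAfter, and only the first request is allowed to trigger a model deployment via maybeStartDeployment.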