@@ -994,9 +994,19 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
         }
     }
 
-    // Iterates over the batch sending 1 request at a time to avoid
-    // filling the ml node inference queue.
+    /**
+     * Iterates over the batch executing a limited number of requests at a time to avoid
+     * filling the ML node inference queue.
+     *
+     * First, a single request is executed, which can also trigger deploying a model
+     * if necessary. When this request completes successfully, a callback executes
+     * N requests in parallel next. Each of these requests also has a callback that
+     * executes one more request, so that at all times N requests are in flight. This
+     * continues until all requests are executed.
+     */
     class BatchIterator {
+        private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200
+
         private final AtomicInteger index = new AtomicInteger();
         private final ElasticsearchInternalModel esModel;
         private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
@@ -1016,67 +1026,47 @@ class BatchIterator {
         }
 
         void run() {
-            inferenceExecutor.execute(this::inferBatchAndRunAfter);
+            // The first request may deploy the model and, upon completion, runs
+            // NUM_REQUESTS_INFLIGHT requests in parallel.
+            inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true));
         }
 
-        private void inferBatchAndRunAfter() {
-            int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200
-            int requestCount = 0;
-            // loop does not include the final request
-            while (requestCount < NUM_REQUESTS_INFLIGHT - 1 && index.get() < requestAndListeners.size() - 1) {
-
-                var batch = requestAndListeners.get(index.get());
-                executeRequest(batch);
-                requestCount++;
-                index.incrementAndGet();
+        private void inferBatch(int runAfterCount, boolean maybeDeploy) {
+            int batchIndex = index.getAndIncrement();
+            if (batchIndex >= requestAndListeners.size()) {
+                return;
             }
-
-            var batch = requestAndListeners.get(index.get());
-            executeRequest(batch, () -> {
-                if (index.incrementAndGet() < requestAndListeners.size()) {
-                    run(); // start the next batch
+            executeRequest(batchIndex, maybeDeploy, () -> {
+                for (int i = 0; i < runAfterCount; i++) {
+                    // Subsequent requests need not deploy the model, because the first
+                    // request already did so. Upon completion, each runs one more request.
+                    inferenceExecutor.execute(() -> inferBatch(1, false));
                 }
             });
         }
 
-        private void executeRequest(EmbeddingRequestChunker.BatchRequestAndListener batch) {
+        private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) {
+            EmbeddingRequestChunker.BatchRequestAndListener batch = requestAndListeners.get(batchIndex);
             var inferenceRequest = buildInferenceRequest(
                 esModel.mlNodeDeploymentId(),
                 EmptyConfigUpdate.INSTANCE,
                 batch.batch().inputs(),
                 inputType,
                 timeout
             );
+            logger.trace("Executing batch index={}", batchIndex);
 
-            ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
+            ActionListener<InferModelAction.Response> listener = batch.listener()
                 .delegateFailureAndWrap(
                     (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
                 );
-
-            var maybeDeployListener = mlResultsListener.delegateResponse(
-                (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l)
-            );
-
-            client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
-        }
-
-        private void executeRequest(EmbeddingRequestChunker.BatchRequestAndListener batch, Runnable runAfter) {
-            var inferenceRequest = buildInferenceRequest(
-                esModel.mlNodeDeploymentId(),
-                EmptyConfigUpdate.INSTANCE,
-                batch.batch().inputs(),
-                inputType,
-                timeout
-            );
-
-            ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
-                .delegateFailureAndWrap(
-                    (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
-                );
-
-            // schedule the next request once the results have been processed
-            var runNextListener = ActionListener.runAfter(mlResultsListener, runAfter);
-            client.execute(InferModelAction.INSTANCE, inferenceRequest, runNextListener);
+            if (runAfter != null) {
+                listener = ActionListener.runAfter(listener, runAfter);
+            }
+            if (maybeDeploy) {
+                listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l));
+            }
+            client.execute(InferModelAction.INSTANCE, inferenceRequest, listener);
         }
     }
 }
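
The control flow is easier to follow outside the diff. Below is a minimal, self-contained sketch of the same sliding-window scheduling idea using a plain ExecutorService. The names (SlidingWindowDemo, processNext, process) are invented for illustration, and where the real code resumes in an ActionListener callback when the asynchronous response arrives, this sketch simply continues after a synchronous process() call.

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

public class SlidingWindowDemo {
    private static final int NUM_REQUESTS_INFLIGHT = 4;

    private final AtomicInteger index = new AtomicInteger();
    private final List<String> work;
    private final ExecutorService executor;
    private final CountDownLatch done;

    SlidingWindowDemo(List<String> work, ExecutorService executor) {
        this.work = work;
        this.executor = executor;
        this.done = new CountDownLatch(work.size());
    }

    void run() {
        // Prime the pipeline with a single task; in the real code this first
        // request may also trigger a model deployment.
        executor.execute(() -> processNext(NUM_REQUESTS_INFLIGHT));
    }

    private void processNext(int runAfterCount) {
        int i = index.getAndIncrement();
        if (i >= work.size()) {
            return; // all work claimed; the in-flight chains drain and stop
        }
        process(work.get(i));
        // The "callback": fan out to the full window after the first task,
        // then one-for-one, so at most NUM_REQUESTS_INFLIGHT tasks run at once.
        for (int j = 0; j < runAfterCount; j++) {
            executor.execute(() -> processNext(1));
        }
        done.countDown(); // counted after submitting, so shutdown() cannot race execute()
    }

    private void process(String item) {
        System.out.println(Thread.currentThread().getName() + " -> " + item);
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(NUM_REQUESTS_INFLIGHT);
        List<String> items = IntStream.range(0, 20).mapToObj(n -> "item-" + n).toList();
        SlidingWindowDemo demo = new SlidingWindowDemo(items, executor);
        demo.run();
        demo.done.await(); // wait until every item has been processed
        executor.shutdown();
    }
}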
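
A side note on the listener composition in the new executeRequest, based on my reading of ActionListener's semantics (runAfter runs its Runnable after the wrapped listener completes, on either the success or failure path): because the runAfter wrapper is applied before the delegateResponse wrapper, a failure is first routed to maybeStartDeployment, which receives the runAfter-wrapped listener. The next request is therefore scheduled only once the deployment-and-retry path has resolved, which keeps the number of in-flight requests bounded even while a model is still deploying.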