1212import org .elasticsearch .common .settings .Settings ;
1313import org .elasticsearch .common .util .concurrent .EsExecutors ;
1414import org .elasticsearch .core .TimeValue ;
15+ import org .elasticsearch .logging .LogManager ;
1516import org .elasticsearch .test .ESTestCase ;
1617import org .elasticsearch .test .client .NoOpClient ;
1718import org .elasticsearch .threadpool .FixedExecutorBuilder ;
2728import java .util .ArrayList ;
2829import java .util .Iterator ;
2930import java .util .List ;
31+ import java .util .concurrent .CountDownLatch ;
3032import java .util .concurrent .atomic .AtomicReference ;
3133
3234import static org .hamcrest .Matchers .allOf ;
@@ -63,7 +65,7 @@ public void shutdownThreadPool() {
6365 }
6466
6567 public void testSuccessfulBulkExecution () throws Exception {
66- List <InferenceAction .Request > requests = randomInferenceRequestList (between (1 , 1000 ));
68+ List <InferenceAction .Request > requests = randomInferenceRequestList (between (1 , 1_000 ));
6769 List <InferenceAction .Response > responses = randomInferenceResponseList (requests .size ());
6870
6971 Client client = mockClient (invocation -> {
@@ -117,7 +119,7 @@ public void testBulkExecutionWhenInferenceRunnerAlwaysFails() throws Exception {
117119 }
118120
119121 public void testBulkExecutionWhenInferenceRunnerSometimesFails () throws Exception {
120- List <InferenceAction .Request > requests = randomInferenceRequestList (between (1 , 1000 ));
122+ List <InferenceAction .Request > requests = randomInferenceRequestList (between (1 , 1_000 ));
121123
122124 Client client = mockClient (invocation -> {
123125 ActionListener <InferenceAction .Response > listener = invocation .getArgument (2 );
@@ -143,6 +145,34 @@ public void testBulkExecutionWhenInferenceRunnerSometimesFails() throws Exceptio
143145 });
144146 }
145147
148+ public void testParallelBulkExecution () throws Exception {
149+ int batches = between (50 , 100 );
150+ CountDownLatch latch = new CountDownLatch (batches );
151+
152+ for (int i = 0 ; i < batches ; i ++) {
153+ List <InferenceAction .Request > requests = randomInferenceRequestList (between (1 , 1_000 ));
154+ List <InferenceAction .Response > responses = randomInferenceResponseList (requests .size ());
155+
156+ Client client = mockClient (invocation -> {
157+ runWithRandomDelay (() -> {
158+ ActionListener <InferenceAction .Response > l = invocation .getArgument (2 );
159+ l .onResponse (responses .get (requests .indexOf (invocation .getArgument (1 , InferenceAction .Request .class ))));
160+ });
161+ return null ;
162+ });
163+
164+ ActionListener <List <InferenceAction .Response >> listener = ActionListener .wrap (r -> {
165+ assertThat (r , equalTo (responses ));
166+ LogManager .getLogger (BulkInferenceRunnerTests .class ).warn ("Received [{}] responses" , responses .size ());
167+ latch .countDown ();
168+ }, ESTestCase ::fail );
169+
170+ inferenceRunnerFactory (client ).create (randomBulkExecutionConfig ()).executeBulk (requestIterator (requests ), listener );
171+ }
172+
173+ latch .await ();
174+ }
175+
146176 private BulkInferenceRunner .Factory inferenceRunnerFactory (Client client ) {
147177 return BulkInferenceRunner .factory (client );
148178 }
@@ -198,7 +228,7 @@ private void runWithRandomDelay(Runnable runnable) {
198228 if (randomBoolean ()) {
199229 runnable .run ();
200230 } else {
201- threadPool .schedule (runnable , TimeValue .timeValueNanos (between (1 , 1_000 )), threadPool .generic ());
231+ threadPool .schedule (runnable , TimeValue .timeValueNanos (between (1 , 100_000 )), threadPool .generic ());
202232 }
203233 }
204234}
0 commit comments