|
28 | 28 | import java.util.Iterator; |
29 | 29 | import java.util.List; |
30 | 30 | import java.util.concurrent.CountDownLatch; |
| 31 | +import java.util.concurrent.TimeUnit; |
31 | 32 | import java.util.concurrent.atomic.AtomicReference; |
32 | 33 |
|
33 | | -import static org.hamcrest.Matchers.allOf; |
34 | | -import static org.hamcrest.Matchers.empty; |
35 | | -import static org.hamcrest.Matchers.equalTo; |
36 | | -import static org.hamcrest.Matchers.notNullValue; |
| 34 | +import static org.hamcrest.Matchers.*; |
37 | 35 | import static org.mockito.ArgumentMatchers.any; |
38 | 36 | import static org.mockito.ArgumentMatchers.eq; |
39 | | -import static org.mockito.Mockito.doAnswer; |
40 | | -import static org.mockito.Mockito.mock; |
41 | | -import static org.mockito.Mockito.when; |
| 37 | +import static org.mockito.Mockito.*; |
42 | 38 |
|
43 | 39 | public class BulkInferenceRunnerTests extends ESTestCase { |
44 | 40 | private ThreadPool threadPool; |
@@ -149,26 +145,28 @@ public void testParallelBulkExecution() throws Exception { |
149 | 145 | CountDownLatch latch = new CountDownLatch(batches); |
150 | 146 |
|
151 | 147 | for (int i = 0; i < batches; i++) { |
152 | | - List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 1_000)); |
153 | | - List<InferenceAction.Response> responses = randomInferenceResponseList(requests.size()); |
154 | | - |
155 | | - Client client = mockClient(invocation -> { |
156 | | - runWithRandomDelay(() -> { |
157 | | - ActionListener<InferenceAction.Response> l = invocation.getArgument(2); |
158 | | - l.onResponse(responses.get(requests.indexOf(invocation.getArgument(1, InferenceAction.Request.class)))); |
| 148 | + runWithRandomDelay(() -> { |
| 149 | + List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 1_000)); |
| 150 | + List<InferenceAction.Response> responses = randomInferenceResponseList(requests.size()); |
| 151 | + |
| 152 | + Client client = mockClient(invocation -> { |
| 153 | + runWithRandomDelay(() -> { |
| 154 | + ActionListener<InferenceAction.Response> l = invocation.getArgument(2); |
| 155 | + l.onResponse(responses.get(requests.indexOf(invocation.getArgument(1, InferenceAction.Request.class)))); |
| 156 | + }); |
| 157 | + return null; |
159 | 158 | }); |
160 | | - return null; |
161 | | - }); |
162 | 159 |
|
163 | | - ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> { |
164 | | - assertThat(r, equalTo(responses)); |
165 | | - latch.countDown(); |
166 | | - }, ESTestCase::fail); |
| 160 | + ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> { |
| 161 | + assertThat(r, equalTo(responses)); |
| 162 | + latch.countDown(); |
| 163 | + }, ESTestCase::fail); |
167 | 164 |
|
168 | | - inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); |
| 165 | + inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); |
| 166 | + }); |
169 | 167 | } |
170 | 168 |
|
171 | | - latch.await(); |
| 169 | + latch.await(10, TimeUnit.SECONDS); |
172 | 170 | } |
173 | 171 |
|
174 | 172 | private BulkInferenceRunner.Factory inferenceRunnerFactory(Client client) { |
|
0 commit comments