|
29 | 29 | import java.util.Iterator; |
30 | 30 | import java.util.List; |
31 | 31 | import java.util.concurrent.CountDownLatch; |
| 32 | +import java.util.concurrent.TimeUnit; |
32 | 33 | import java.util.concurrent.atomic.AtomicReference; |
33 | 34 |
|
34 | | -import static org.hamcrest.Matchers.allOf; |
35 | | -import static org.hamcrest.Matchers.empty; |
36 | | -import static org.hamcrest.Matchers.equalTo; |
37 | | -import static org.hamcrest.Matchers.notNullValue; |
| 35 | +import static org.hamcrest.Matchers.*; |
38 | 36 | import static org.mockito.ArgumentMatchers.any; |
39 | 37 | import static org.mockito.ArgumentMatchers.eq; |
40 | | -import static org.mockito.Mockito.doAnswer; |
41 | | -import static org.mockito.Mockito.mock; |
42 | | -import static org.mockito.Mockito.when; |
| 38 | +import static org.mockito.Mockito.*; |
43 | 39 |
|
44 | 40 | public class BulkInferenceRunnerTests extends ESTestCase { |
45 | 41 | private ThreadPool threadPool; |
@@ -150,26 +146,28 @@ public void testParallelBulkExecution() throws Exception { |
150 | 146 | CountDownLatch latch = new CountDownLatch(batches); |
151 | 147 |
|
152 | 148 | for (int i = 0; i < batches; i++) { |
153 | | - List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 1_000)); |
154 | | - List<InferenceAction.Response> responses = randomInferenceResponseList(requests.size()); |
155 | | - |
156 | | - Client client = mockClient(invocation -> { |
157 | | - runWithRandomDelay(() -> { |
158 | | - ActionListener<InferenceAction.Response> l = invocation.getArgument(2); |
159 | | - l.onResponse(responses.get(requests.indexOf(invocation.getArgument(1, InferenceAction.Request.class)))); |
| 149 | + runWithRandomDelay(() -> { |
| 150 | + List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 1_000)); |
| 151 | + List<InferenceAction.Response> responses = randomInferenceResponseList(requests.size()); |
| 152 | + |
| 153 | + Client client = mockClient(invocation -> { |
| 154 | + runWithRandomDelay(() -> { |
| 155 | + ActionListener<InferenceAction.Response> l = invocation.getArgument(2); |
| 156 | + l.onResponse(responses.get(requests.indexOf(invocation.getArgument(1, InferenceAction.Request.class)))); |
| 157 | + }); |
| 158 | + return null; |
160 | 159 | }); |
161 | | - return null; |
162 | | - }); |
163 | 160 |
|
164 | | - ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> { |
165 | | - assertThat(r, equalTo(responses)); |
166 | | - latch.countDown(); |
167 | | - }, ESTestCase::fail); |
| 161 | + ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> { |
| 162 | + assertThat(r, equalTo(responses)); |
| 163 | + latch.countDown(); |
| 164 | + }, ESTestCase::fail); |
168 | 165 |
|
169 | | - inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); |
| 166 | + inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); |
| 167 | + }); |
170 | 168 | } |
171 | 169 |
|
172 | | - latch.await(); |
| 170 | + latch.await(10, TimeUnit.SECONDS); |
173 | 171 | } |
174 | 172 |
|
175 | 173 | private BulkInferenceRunner.Factory inferenceRunnerFactory(Client client) { |
|
0 commit comments