|
30 | 30 | import java.util.stream.Stream; |
31 | 31 |
|
32 | 32 | import static org.hamcrest.Matchers.allOf; |
33 | | -import static org.hamcrest.Matchers.arrayContaining; |
34 | | -import static org.hamcrest.Matchers.emptyArray; |
| 33 | +import static org.hamcrest.Matchers.empty; |
35 | 34 | import static org.hamcrest.Matchers.equalTo; |
36 | 35 | import static org.hamcrest.Matchers.notNullValue; |
37 | 36 | import static org.mockito.ArgumentMatchers.any; |
@@ -63,68 +62,74 @@ public void shutdownThreadPool() { |
63 | 62 | } |
64 | 63 |
|
65 | 64 | public void testSuccessfulExecution() throws Exception { |
66 | | - List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 50)); |
67 | | - InferenceAction.Response[] responses = randomInferenceResponseList(requests.size()).toArray(InferenceAction.Response[]::new); |
| 65 | + List<InferenceAction.Request> requests = randomInferenceRequestList(between(90_000, 100_000)); |
| 66 | + List<InferenceAction.Response> responses = randomInferenceResponseList(requests.size()); |
68 | 67 |
|
69 | 68 | InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> { |
70 | | - ActionListener<InferenceAction.Response> l = invocation.getArgument(1); |
71 | | - l.onResponse(responses[requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class))]); |
| 69 | + runWithRandomDelay(() -> { |
| 70 | + ActionListener<InferenceAction.Response> l = invocation.getArgument(1); |
| 71 | + l.onResponse(responses.get(requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)))); |
| 72 | + }); |
72 | 73 | return null; |
73 | 74 | }); |
74 | 75 |
|
75 | | - AtomicReference<InferenceAction.Response[]> output = new AtomicReference<>(); |
76 | | - ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception")); |
| 76 | + AtomicReference<List<InferenceAction.Response>> output = new AtomicReference<>(); |
| 77 | + ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception")); |
77 | 78 |
|
78 | 79 | bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); |
79 | 80 |
|
80 | | - assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), arrayContaining(responses)))); |
| 81 | + assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), equalTo(responses)))); |
81 | 82 | } |
82 | 83 |
|
83 | 84 | public void testSuccessfulExecutionOnEmptyRequest() throws Exception { |
84 | 85 | BulkInferenceRequestIterator requestIterator = mock(BulkInferenceRequestIterator.class); |
85 | 86 | when(requestIterator.hasNext()).thenReturn(false); |
86 | 87 |
|
87 | | - AtomicReference<InferenceAction.Response[]> output = new AtomicReference<>(); |
88 | | - ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception")); |
| 88 | + AtomicReference<List<InferenceAction.Response>> output = new AtomicReference<>(); |
| 89 | + ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception")); |
89 | 90 |
|
90 | 91 | bulkExecutor(mock(InferenceRunner.class)).execute(requestIterator, listener); |
91 | 92 |
|
92 | | - assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), emptyArray()))); |
| 93 | + assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), empty()))); |
93 | 94 | } |
94 | 95 |
|
95 | 96 | public void testInferenceRunnerAlwaysFails() throws Exception { |
96 | | - List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 30)); |
| 97 | + List<InferenceAction.Request> requests = randomInferenceRequestList(between(90_000, 100_000)); |
97 | 98 |
|
98 | 99 | InferenceRunner inferenceRunner = mock(invocation -> { |
99 | | - ActionListener<InferenceAction.Response> listener = invocation.getArgument(1); |
100 | | - listener.onFailure(new RuntimeException("inference failure")); |
| 100 | + runWithRandomDelay(() -> { |
| 101 | + ActionListener<InferenceAction.Response> listener = invocation.getArgument(1); |
| 102 | + listener.onFailure(new RuntimeException("inference failure")); |
| 103 | + }); |
101 | 104 | return null; |
102 | 105 | }); |
103 | 106 |
|
104 | 107 | AtomicReference<Exception> exception = new AtomicReference<>(); |
105 | | - ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set); |
| 108 | + ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set); |
106 | 109 |
|
107 | 110 | bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); |
108 | 111 |
|
109 | 112 | assertBusy(() -> assertThat(exception.get().getMessage(), equalTo("inference failure"))); |
110 | 113 | } |
111 | 114 |
|
112 | 115 | public void testInferenceRunnerSometimesFails() throws Exception { |
113 | | - List<InferenceAction.Request> requests = randomInferenceRequestList(between(2, 30)); |
| 116 | + List<InferenceAction.Request> requests = randomInferenceRequestList(between(90_000, 100_000)); |
114 | 117 |
|
115 | 118 | InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> { |
116 | 119 | ActionListener<InferenceAction.Response> listener = invocation.getArgument(1); |
117 | | - if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) { |
118 | | - listener.onFailure(new RuntimeException("inference failure")); |
119 | | - } else { |
120 | | - listener.onResponse(mockInferenceResponse()); |
121 | | - } |
| 120 | + runWithRandomDelay(() -> { |
| 121 | + if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) { |
| 122 | + listener.onFailure(new RuntimeException("inference failure")); |
| 123 | + } else { |
| 124 | + listener.onResponse(mockInferenceResponse()); |
| 125 | + } |
| 126 | + }); |
122 | 127 |
|
123 | 128 | return null; |
124 | 129 | }); |
125 | 130 |
|
126 | 131 | AtomicReference<Exception> exception = new AtomicReference<>(); |
127 | | - ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set); |
| 132 | + ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set); |
128 | 133 |
|
129 | 134 | bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); |
130 | 135 |
|
@@ -170,4 +175,16 @@ private InferenceRunner mockInferenceRunner(Answer<Void> doInferenceAnswer) { |
170 | 175 | doAnswer(doInferenceAnswer).when(inferenceRunner).doInference(any(), any()); |
171 | 176 | return inferenceRunner; |
172 | 177 | } |
| 178 | + |
| 179 | + private void runWithRandomDelay(Runnable runnable) { |
| 180 | + if (randomBoolean()) { |
| 181 | + runnable.run(); |
| 182 | + } else { |
| 183 | + threadPool.schedule( |
| 184 | + runnable, |
| 185 | + TimeValue.timeValueNanos(between(1, 1_000)), |
| 186 | + threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME) |
| 187 | + ); |
| 188 | + } |
| 189 | + } |
173 | 190 | } |
0 commit comments