Skip to content

Commit fe091b4

Browse files
committed
Improving BulkInferenceExecutorTests
1 parent 2764b19 commit fe091b4

File tree

1 file changed

+80
-62
lines changed

1 file changed

+80
-62
lines changed

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java

Lines changed: 80 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
2323
import org.junit.After;
2424
import org.junit.Before;
25+
import org.mockito.stubbing.Answer;
2526

2627
import java.util.ArrayList;
2728
import java.util.Iterator;
@@ -30,15 +31,17 @@
3031
import java.util.concurrent.atomic.AtomicReference;
3132
import java.util.stream.Stream;
3233

34+
import static org.hamcrest.Matchers.allOf;
3335
import static org.hamcrest.Matchers.contains;
3436
import static org.hamcrest.Matchers.empty;
3537
import static org.hamcrest.Matchers.equalTo;
3638
import static org.hamcrest.Matchers.hasSize;
39+
import static org.hamcrest.Matchers.notNullValue;
3740
import static org.mockito.ArgumentMatchers.any;
38-
import static org.mockito.ArgumentMatchers.eq;
3941
import static org.mockito.Mockito.doAnswer;
4042
import static org.mockito.Mockito.doThrow;
4143
import static org.mockito.Mockito.mock;
44+
import static org.mockito.Mockito.never;
4245
import static org.mockito.Mockito.verify;
4346
import static org.mockito.Mockito.when;
4447

@@ -65,42 +68,54 @@ public void shutdownThreadPool() {
6568
terminate(threadPool);
6669
}
6770

71+
@SuppressWarnings("unchecked")
72+
private <T extends InferenceServiceResults> BulkInferenceOutputBuilder<T, List<T>> mockOutputBuilder(Class<T> resultClass)
73+
throws Exception {
74+
BulkInferenceOutputBuilder<T, List<T>> outputBuilder = mock(BulkInferenceOutputBuilder.class);
75+
List<T> output = new ArrayList<>();
76+
doAnswer(invocation -> {
77+
output.add(invocation.getArgument(0, resultClass));
78+
return null;
79+
}).when(outputBuilder).onInferenceResults(any());
80+
when(outputBuilder.buildOutput()).thenReturn(output);
81+
when(outputBuilder.inferenceResultsClass()).thenReturn(resultClass);
82+
83+
return outputBuilder;
84+
}
85+
6886
@SuppressWarnings("unchecked")
6987
public void testSuccessfulExecution() throws Exception {
70-
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 50)).toList();
71-
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
72-
List<InferenceAction.Response> responses = Stream.generate(() -> mockInferenceResponse(RankedDocsResults.class))
73-
.limit(requests.size())
74-
.toList();
88+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 50));
89+
List<InferenceAction.Response> responses = randomInferenceResponseList(requests.size(), RankedDocsResults.class);
7590

76-
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
77-
doAnswer((invocation) -> {
91+
InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> {
7892
ActionListener<InferenceAction.Response> l = invocation.getArgument(1);
79-
if (randomBoolean()) {
80-
Thread.sleep(between(0, 5));
81-
}
8293
l.onResponse(responses.get(requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class))));
8394
return null;
84-
}).when(inferenceRunner).doInference(any(), any());
95+
});
8596

97+
AtomicReference<List<RankedDocsResults>> output = new AtomicReference<>();
8698
ActionListener<List<RankedDocsResults>> listener = mock(ActionListener.class);
87-
88-
List<RankedDocsResults> output = new ArrayList<>();
89-
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mock(BulkInferenceOutputBuilder.class);
9099
doAnswer(invocation -> {
91-
output.add(invocation.getArgument(0, RankedDocsResults.class));
100+
output.set(invocation.getArgument(0, List.class));
92101
return null;
93-
}).when(outputBuilder).onInferenceResults(any());
94-
when(outputBuilder.buildOutput()).thenReturn(output);
95-
when(outputBuilder.inferenceResultsClass()).thenReturn(RankedDocsResults.class);
102+
}).when(listener).onResponse(any());
96103

97-
BulkInferenceExecutor executor = bulkExecutor(inferenceRunner);
98-
executor.execute(requestIterator, outputBuilder, listener);
104+
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mockOutputBuilder(RankedDocsResults.class);
105+
106+
bulkExecutor(inferenceRunner).execute(requestIterator(requests), outputBuilder, listener);
99107

100108
assertBusy(() -> {
101-
assertThat(output, hasSize(requests.size()));
102-
assertThat(output, contains(responses.stream().map(InferenceAction.Response::getResults).toArray()));
103-
verify(listener).onResponse(eq(output));
109+
verify(listener).onResponse(any());
110+
verify(listener, never()).onFailure(any());
111+
assertThat(
112+
output.get(),
113+
allOf(
114+
notNullValue(),
115+
hasSize(requests.size()),
116+
contains(responses.stream().map(InferenceAction.Response::getResults).toArray())
117+
)
118+
);
104119
});
105120
}
106121

@@ -109,35 +124,33 @@ public void testSuccessfulExecutionOnEmptyRequest() throws Exception {
109124
BulkInferenceRequestIterator requestIterator = mock(BulkInferenceRequestIterator.class);
110125
when(requestIterator.hasNext()).thenReturn(false);
111126

127+
AtomicReference<List<RankedDocsResults>> output = new AtomicReference<>();
112128
ActionListener<List<RankedDocsResults>> listener = mock(ActionListener.class);
129+
doAnswer(invocation -> {
130+
output.set(invocation.getArgument(0, List.class));
131+
return null;
132+
}).when(listener).onResponse(any());
113133

114-
List<RankedDocsResults> output = new ArrayList<>();
115-
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mock(BulkInferenceOutputBuilder.class);
116-
when(outputBuilder.buildOutput()).thenReturn(output);
134+
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mockOutputBuilder(RankedDocsResults.class);
117135

118-
BulkInferenceExecutor executor = bulkExecutor(mock(InferenceRunner.class));
119-
executor.execute(requestIterator, outputBuilder, listener);
136+
bulkExecutor(mock(InferenceRunner.class)).execute(requestIterator, outputBuilder, listener);
120137

121138
assertBusy(() -> {
122-
assertThat(output, empty());
123-
verify(listener).onResponse(eq(output));
139+
verify(listener).onResponse(any());
140+
verify(listener, never()).onFailure(any());
141+
assertThat(output.get(), allOf(notNullValue(), empty()));
124142
});
125143
}
126144

127145
@SuppressWarnings("unchecked")
128146
public void testInferenceRunnerAlwaysFails() throws Exception {
129-
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 30)).toList();
130-
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
147+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 30));
131148

132-
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
133-
doAnswer(invocation -> {
149+
InferenceRunner inferenceRunner = mock(invocation -> {
134150
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
135-
if (randomBoolean()) {
136-
Thread.sleep(between(0, 5));
137-
}
138151
listener.onFailure(new RuntimeException("inference failure"));
139152
return null;
140-
}).when(inferenceRunner).doInference(any(), any());
153+
});
141154

142155
ActionListener<List<RankedDocsResults>> listener = mock(ActionListener.class);
143156
AtomicReference<Exception> e = new AtomicReference<>();
@@ -146,37 +159,31 @@ public void testInferenceRunnerAlwaysFails() throws Exception {
146159
return null;
147160
}).when(listener).onFailure(any());
148161

149-
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mock(BulkInferenceOutputBuilder.class);
162+
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mockOutputBuilder(RankedDocsResults.class);
150163

151-
BulkInferenceExecutor executor = bulkExecutor(inferenceRunner);
152-
executor.execute(requestIterator, outputBuilder, listener);
164+
bulkExecutor(inferenceRunner).execute(requestIterator(requests), outputBuilder, listener);
153165

154166
assertBusy(() -> {
155167
verify(listener).onFailure(any(RuntimeException.class));
168+
verify(listener, never()).onResponse(any());
156169
assertThat(e.get().getMessage(), equalTo("inference failure"));
157170
});
158171
}
159172

160173
@SuppressWarnings("unchecked")
161174
public void testInferenceRunnerSometimesFails() throws Exception {
162-
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(2, 30)).toList();
163-
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
175+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(2, 30));
164176

165-
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
166-
doAnswer(invocation -> {
177+
InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> {
167178
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
168-
if (randomBoolean()) {
169-
Thread.sleep(between(0, 5));
170-
}
171-
172179
if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) {
173180
listener.onFailure(new RuntimeException("inference failure"));
174181
} else {
175182
listener.onResponse(mockInferenceResponse(RankedDocsResults.class));
176183
}
177184

178185
return null;
179-
}).when(inferenceRunner).doInference(any(), any());
186+
});
180187

181188
ActionListener<List<RankedDocsResults>> listener = mock(ActionListener.class);
182189
AtomicReference<Exception> e = new AtomicReference<>();
@@ -185,29 +192,25 @@ public void testInferenceRunnerSometimesFails() throws Exception {
185192
return null;
186193
}).when(listener).onFailure(any());
187194

188-
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mock(BulkInferenceOutputBuilder.class);
189-
when(outputBuilder.inferenceResultsClass()).thenReturn(RankedDocsResults.class);
190-
191-
BulkInferenceExecutor executor = bulkExecutor(inferenceRunner);
192-
executor.execute(requestIterator, outputBuilder, listener);
195+
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mockOutputBuilder(RankedDocsResults.class);
196+
bulkExecutor(inferenceRunner).execute(requestIterator(requests), outputBuilder, listener);
193197

194198
assertBusy(() -> {
195199
verify(listener).onFailure(any(RuntimeException.class));
200+
verify(listener, never()).onResponse(any());
196201
assertThat(e.get().getMessage(), equalTo("inference failure"));
197202
});
198203
}
199204

200205
@SuppressWarnings("unchecked")
201206
public void testBuildOutputFailure() throws Exception {
202-
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 30)).toList();
203-
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
207+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 30));
204208

205-
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
206-
doAnswer(invocation -> {
209+
InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> {
207210
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
208211
listener.onResponse(mockInferenceResponse(RankedDocsResults.class));
209212
return null;
210-
}).when(inferenceRunner).doInference(any(), any());
213+
});
211214

212215
ActionListener<List<RankedDocsResults>> listener = mock(ActionListener.class);
213216
AtomicReference<Exception> e = new AtomicReference<>();
@@ -222,10 +225,11 @@ public void testBuildOutputFailure() throws Exception {
222225

223226
BulkInferenceExecutor executor = bulkExecutor(inferenceRunner);
224227

225-
executor.execute(requestIterator, outputBuilder, listener);
228+
bulkExecutor(inferenceRunner).execute(requestIterator(requests), outputBuilder, listener);
226229

227230
assertBusy(() -> {
228231
verify(listener).onFailure(any(IllegalStateException.class));
232+
verify(listener, never()).onResponse(any());
229233
assertThat(e.get().getMessage(), equalTo("build output failure"));
230234
});
231235
}
@@ -255,4 +259,18 @@ private BulkInferenceRequestIterator requestIterator(List<InferenceAction.Reques
255259
doAnswer(i -> delegate.next()).when(iterator).next();
256260
return iterator;
257261
}
262+
263+
private List<InferenceAction.Request> randomInferenceRequestList(int size) {
264+
return Stream.generate(this::mockInferenceRequest).limit(size).toList();
265+
}
266+
267+
private List<InferenceAction.Response> randomInferenceResponseList(int size, Class<? extends InferenceServiceResults> resultClass) {
268+
return Stream.generate(() -> this.mockInferenceResponse(resultClass)).limit(size).toList();
269+
}
270+
271+
private InferenceRunner mockInferenceRunner(Answer<Void> doInferenceAnswer) {
272+
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
273+
doAnswer(doInferenceAnswer).when(inferenceRunner).doInference(any(), any());
274+
return inferenceRunner;
275+
}
258276
}

0 commit comments

Comments
 (0)