Skip to content

Commit 8db5cd2

Browse files
committed
BulkInferenceExecutor unit tests.
1 parent ce48c97 commit 8db5cd2

File tree

3 files changed

+258
-2
lines changed

3 files changed

+258
-2
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ private void doExecute(
4848
) {
4949
final BulkInferenceExecutionState<OutputType> bulkExecutionState = new BulkInferenceExecutionState<>();
5050

51+
if (requests.hasNext() == false) {
52+
bulkExecutionState.markAllRequestsSent();
53+
bulkExecutionState.maybeSendResponse(outputBuilder::buildOutput, listener);
54+
}
55+
5156
while (requests.hasNext() && bulkExecutionState.responseSent() == false) {
5257
long seqNo = bulkExecutionState.generateSeqNo();
5358
InferenceAction.Request request = requests.next();

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceOutputBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ public abstract class BulkInferenceOutputBuilder<InferenceResults extends Infere
1818

1919
public abstract OutputType buildOutput();
2020

21-
public abstract void onInferenceResults(InferenceResults results);
21+
public abstract void onInferenceResults(InferenceResults results) throws Exception;
2222

23-
public void onInferenceResponse(InferenceAction.Response response) throws Exception {
23+
public final void onInferenceResponse(InferenceAction.Response response) throws Exception {
2424
InferenceServiceResults results = response.getResults();
2525
if (inferenceResultsClass().isInstance(response.getResults()) == false) {
2626
throw new IllegalStateException(
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference.bulk;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.core.TimeValue;
12+
import org.elasticsearch.inference.InferenceServiceResults;
13+
import org.elasticsearch.test.ESTestCase;
14+
import org.elasticsearch.threadpool.TestThreadPool;
15+
import org.elasticsearch.threadpool.ThreadPool;
16+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
17+
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
18+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
19+
import org.junit.After;
20+
import org.junit.Before;
21+
22+
import java.util.ArrayList;
23+
import java.util.Iterator;
24+
import java.util.List;
25+
import java.util.concurrent.TimeUnit;
26+
import java.util.concurrent.atomic.AtomicReference;
27+
import java.util.stream.Stream;
28+
29+
import static org.hamcrest.Matchers.contains;
30+
import static org.hamcrest.Matchers.empty;
31+
import static org.hamcrest.Matchers.equalTo;
32+
import static org.hamcrest.Matchers.hasSize;
33+
import static org.mockito.ArgumentMatchers.any;
34+
import static org.mockito.ArgumentMatchers.eq;
35+
import static org.mockito.Mockito.doAnswer;
36+
import static org.mockito.Mockito.doThrow;
37+
import static org.mockito.Mockito.mock;
38+
import static org.mockito.Mockito.verify;
39+
import static org.mockito.Mockito.when;
40+
41+
public class BulkInferenceExecutorTests extends ESTestCase {
42+
private ThreadPool threadPool;
43+
44+
@Before
45+
public void setThreadPool() {
46+
threadPool = new TestThreadPool("test");
47+
}
48+
49+
@After
50+
public void shutdownThreadPool() {
51+
terminate(threadPool);
52+
}
53+
54+
@SuppressWarnings("unchecked")
55+
public void testSuccessfulExecution() throws Exception {
56+
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 100)).toList();
57+
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
58+
List<InferenceAction.Response> responses = Stream.generate(() -> mockInferenceResponse(RankedDocsResults.class))
59+
.limit(requests.size())
60+
.toList();
61+
62+
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
63+
doAnswer((invocation) -> {
64+
ActionListener<InferenceAction.Response> l = invocation.getArgument(1);
65+
if (randomBoolean()) {
66+
Thread.sleep(between(0, 500));
67+
}
68+
l.onResponse(responses.get(requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class))));
69+
return null;
70+
}).when(inferenceRunner).doInference(any(), any());
71+
72+
ActionListener<List<RankedDocsResults>> listener = mock(ActionListener.class);
73+
74+
List<RankedDocsResults> output = new ArrayList<>();
75+
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mock(BulkInferenceOutputBuilder.class);
76+
doAnswer(invocation -> {
77+
output.add((RankedDocsResults) invocation.getArgument(0, InferenceAction.Response.class).getResults());
78+
return null;
79+
}).when(outputBuilder).onInferenceResponse(any());
80+
when(outputBuilder.buildOutput()).thenReturn(output);
81+
82+
BulkInferenceExecutor<RankedDocsResults, List<RankedDocsResults>> executor = bulkExecutor(inferenceRunner);
83+
executor.execute(requestIterator, outputBuilder, listener);
84+
85+
assertBusy(() -> {
86+
assertThat(output, hasSize(requests.size()));
87+
assertThat(output, contains(responses.stream().map(InferenceAction.Response::getResults).toArray()));
88+
verify(listener).onResponse(eq(output));
89+
verify(outputBuilder).close();
90+
verify(requestIterator).close();
91+
});
92+
}
93+
94+
@SuppressWarnings("unchecked")
95+
public void testSuccessfulExecutionOnEmptyRequest() throws Exception {
96+
BulkInferenceRequestIterator requestIterator = mock(BulkInferenceRequestIterator.class);
97+
when(requestIterator.hasNext()).thenReturn(false);
98+
99+
100+
ActionListener<List<RankedDocsResults>> listener = mock(ActionListener.class);
101+
102+
List<RankedDocsResults> output = new ArrayList<>();
103+
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mock(BulkInferenceOutputBuilder.class);
104+
when(outputBuilder.buildOutput()).thenReturn(output);
105+
106+
BulkInferenceExecutor<RankedDocsResults, List<RankedDocsResults>> executor = bulkExecutor(mock(InferenceRunner.class));
107+
executor.execute(requestIterator, outputBuilder, listener);
108+
109+
assertBusy(() -> {
110+
assertThat(output, empty());
111+
verify(listener).onResponse(eq(output));
112+
verify(outputBuilder).close();
113+
verify(requestIterator).close();
114+
});
115+
}
116+
117+
@SuppressWarnings("unchecked")
118+
public void testInferenceRunnerAlwaysFails() throws Exception {
119+
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 20)).toList();
120+
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
121+
122+
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
123+
doAnswer(invocation -> {
124+
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
125+
if (randomBoolean()) {
126+
Thread.sleep(between(0, 500));
127+
}
128+
listener.onFailure(new RuntimeException("inference failure"));
129+
return null;
130+
}).when(inferenceRunner).doInference(any(), any());
131+
132+
ActionListener<List<RankedDocsResults>> listener = mock(ActionListener.class);
133+
AtomicReference<Exception> e = new AtomicReference<>();
134+
doAnswer(i -> {
135+
e.set(i.getArgument(0, Exception.class));
136+
return null;
137+
}).when(listener).onFailure(any());
138+
139+
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mock(BulkInferenceOutputBuilder.class);
140+
141+
BulkInferenceExecutor<RankedDocsResults, List<RankedDocsResults>> executor = bulkExecutor(inferenceRunner);
142+
executor.execute(requestIterator, outputBuilder, listener);
143+
144+
assertBusy(() -> {
145+
verify(listener).onFailure(any(RuntimeException.class));
146+
assertThat(e.get().getMessage(), equalTo("inference failure"));
147+
verify(outputBuilder).close();
148+
verify(requestIterator).close();
149+
});
150+
}
151+
152+
@SuppressWarnings("unchecked")
153+
public void testInferenceRunnerSometimesFails() throws Exception {
154+
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(2, 20)).toList();
155+
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
156+
157+
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
158+
doAnswer(invocation -> {
159+
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
160+
if (randomBoolean()) {
161+
Thread.sleep(between(0, 500));
162+
}
163+
164+
if (requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size() == 0) {
165+
listener.onFailure(new RuntimeException("inference failure"));
166+
} else {
167+
listener.onResponse(mockInferenceResponse(RankedDocsResults.class));
168+
}
169+
170+
return null;
171+
}).when(inferenceRunner).doInference(any(), any());
172+
173+
ActionListener<List<RankedDocsResults>> listener = mock(ActionListener.class);
174+
AtomicReference<Exception> e = new AtomicReference<>();
175+
doAnswer(i -> {
176+
e.set(i.getArgument(0, Exception.class));
177+
return null;
178+
}).when(listener).onFailure(any());
179+
180+
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mock(BulkInferenceOutputBuilder.class);
181+
182+
BulkInferenceExecutor<RankedDocsResults, List<RankedDocsResults>> executor = bulkExecutor(inferenceRunner);
183+
executor.execute(requestIterator, outputBuilder, listener);
184+
185+
assertBusy(() -> {
186+
verify(listener).onFailure(any(RuntimeException.class));
187+
assertThat(e.get().getMessage(), equalTo("inference failure"));
188+
verify(outputBuilder).close();
189+
verify(requestIterator).close();
190+
});
191+
}
192+
193+
@SuppressWarnings("unchecked")
194+
public void testBuildOutputFailure() throws Exception {
195+
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 20)).toList();
196+
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
197+
198+
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
199+
doAnswer(invocation -> {
200+
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
201+
listener.onResponse(mockInferenceResponse(RankedDocsResults.class));
202+
return null;
203+
}).when(inferenceRunner).doInference(any(), any());
204+
205+
ActionListener<List<RankedDocsResults>> listener = mock(ActionListener.class);
206+
AtomicReference<Exception> e = new AtomicReference<>();
207+
doAnswer(i -> {
208+
e.set(i.getArgument(0, Exception.class));
209+
return null;
210+
}).when(listener).onFailure(any());
211+
212+
BulkInferenceOutputBuilder<RankedDocsResults, List<RankedDocsResults>> outputBuilder = mock(BulkInferenceOutputBuilder.class);
213+
doThrow(new IllegalStateException("build output failure")).when(outputBuilder.buildOutput());
214+
215+
BulkInferenceExecutor<RankedDocsResults, List<RankedDocsResults>> executor = bulkExecutor(inferenceRunner);
216+
executor.execute(requestIterator, outputBuilder, listener);
217+
218+
assertBusy(() -> {
219+
verify(listener).onFailure(any(IllegalStateException.class));
220+
assertThat(e.get().getMessage(), equalTo("build output failure"));
221+
verify(outputBuilder).close();
222+
verify(requestIterator).close();
223+
});
224+
}
225+
226+
private BulkInferenceExecutor<RankedDocsResults, List<RankedDocsResults>> bulkExecutor(InferenceRunner inferenceRunner) {
227+
return new BulkInferenceExecutor<>(inferenceRunner, threadPool, randomBulkExecutionConfig());
228+
}
229+
230+
private InferenceAction.Request mockInferenceRequest() {
231+
return mock(InferenceAction.Request.class);
232+
}
233+
234+
private InferenceAction.Response mockInferenceResponse(Class<? extends InferenceServiceResults> resultClass) {
235+
InferenceAction.Response response = mock(InferenceAction.Response.class);
236+
when(response.getResults()).thenReturn(mock(resultClass));
237+
return response;
238+
}
239+
240+
private BulkInferenceExecutionConfig randomBulkExecutionConfig() {
241+
return new BulkInferenceExecutionConfig(new TimeValue(between(1, 30), TimeUnit.SECONDS), between(1, 100));
242+
}
243+
244+
private BulkInferenceRequestIterator requestIterator(List<InferenceAction.Request> requests) {
245+
final Iterator<InferenceAction.Request> delegate = requests.iterator();
246+
BulkInferenceRequestIterator iterator = mock(BulkInferenceRequestIterator.class);
247+
doAnswer(i -> delegate.hasNext()).when(iterator).hasNext();
248+
doAnswer(i -> delegate.next()).when(iterator).next();
249+
return iterator;
250+
}
251+
}

0 commit comments

Comments
 (0)