Skip to content

Commit ef60c09

Browse files
committed
Add more test case.
1 parent 15ec18d commit ef60c09

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.common.settings.Settings;
1313
import org.elasticsearch.common.util.concurrent.EsExecutors;
1414
import org.elasticsearch.core.TimeValue;
15+
import org.elasticsearch.logging.LogManager;
1516
import org.elasticsearch.test.ESTestCase;
1617
import org.elasticsearch.test.client.NoOpClient;
1718
import org.elasticsearch.threadpool.FixedExecutorBuilder;
@@ -27,6 +28,7 @@
2728
import java.util.ArrayList;
2829
import java.util.Iterator;
2930
import java.util.List;
31+
import java.util.concurrent.CountDownLatch;
3032
import java.util.concurrent.atomic.AtomicReference;
3133

3234
import static org.hamcrest.Matchers.allOf;
@@ -63,7 +65,7 @@ public void shutdownThreadPool() {
6365
}
6466

6567
public void testSuccessfulBulkExecution() throws Exception {
66-
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 1000));
68+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 1_000));
6769
List<InferenceAction.Response> responses = randomInferenceResponseList(requests.size());
6870

6971
Client client = mockClient(invocation -> {
@@ -117,7 +119,7 @@ public void testBulkExecutionWhenInferenceRunnerAlwaysFails() throws Exception {
117119
}
118120

119121
public void testBulkExecutionWhenInferenceRunnerSometimesFails() throws Exception {
120-
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 1000));
122+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 1_000));
121123

122124
Client client = mockClient(invocation -> {
123125
ActionListener<InferenceAction.Response> listener = invocation.getArgument(2);
@@ -143,6 +145,34 @@ public void testBulkExecutionWhenInferenceRunnerSometimesFails() throws Exceptio
143145
});
144146
}
145147

148+
public void testParallelBulkExecution() throws Exception {
149+
int batches = between(50, 100);
150+
CountDownLatch latch = new CountDownLatch(batches);
151+
152+
for (int i = 0; i < batches; i++) {
153+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 1_000));
154+
List<InferenceAction.Response> responses = randomInferenceResponseList(requests.size());
155+
156+
Client client = mockClient(invocation -> {
157+
runWithRandomDelay(() -> {
158+
ActionListener<InferenceAction.Response> l = invocation.getArgument(2);
159+
l.onResponse(responses.get(requests.indexOf(invocation.getArgument(1, InferenceAction.Request.class))));
160+
});
161+
return null;
162+
});
163+
164+
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> {
165+
assertThat(r, equalTo(responses));
166+
LogManager.getLogger(BulkInferenceRunnerTests.class).warn("Received [{}] responses", responses.size());
167+
latch.countDown();
168+
}, ESTestCase::fail);
169+
170+
inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener);
171+
}
172+
173+
latch.await();
174+
}
175+
146176
private BulkInferenceRunner.Factory inferenceRunnerFactory(Client client) {
147177
return BulkInferenceRunner.factory(client);
148178
}
@@ -198,7 +228,7 @@ private void runWithRandomDelay(Runnable runnable) {
198228
if (randomBoolean()) {
199229
runnable.run();
200230
} else {
201-
threadPool.schedule(runnable, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.generic());
231+
threadPool.schedule(runnable, TimeValue.timeValueNanos(between(1, 100_000)), threadPool.generic());
202232
}
203233
}
204234
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorTests.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,7 @@ protected Matcher<String> expectedDescriptionOfSimple() {
9393
@Override
9494
protected Matcher<String> expectedToStringOfSimple() {
9595
return equalTo(
96-
"RerankOperator[inference_id=["
97-
+ SIMPLE_INFERENCE_ID
98-
+ "], query=["
99-
+ SIMPLE_QUERY
100-
+ "], score_channel=["
101-
+ scoreChannel
102-
+ "]]"
96+
"RerankOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "], query=[" + SIMPLE_QUERY + "], score_channel=[" + scoreChannel + "]]"
10397
);
10498
}
10599

0 commit comments

Comments
 (0)