Skip to content

Commit 2764b19

Browse files
committed
Fixing some tests.
1 parent 87bfc0c commit 2764b19

File tree

3 files changed

+17
-20
lines changed

3 files changed

+17
-20
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,14 @@ public void markSeqNoAsPersisted(long seqNo) {
4747
checkpoint.markSeqNoAsPersisted(seqNo);
4848
}
4949

50-
public void persistsResponse(CheckedConsumer<InferenceAction.Response, Exception> persister) {
50+
public void persistsResponses(CheckedConsumer<InferenceAction.Response, Exception> persister) {
5151
synchronized (checkpoint) {
5252
long persistedSeqNo = checkpoint.getPersistedCheckpoint();
5353
while (persistedSeqNo < checkpoint.getProcessedCheckpoint()) {
5454
persistedSeqNo++;
5555
InferenceAction.Response response = bufferedResponses.remove(persistedSeqNo);
56-
if (hasFailure() == false) {
56+
assert response != null || hasFailure();
57+
if (hasFailure() == false && responseSent() == false) {
5758
try {
5859
persister.accept(response);
5960
} catch (Exception e) {
@@ -90,14 +91,14 @@ public <OutputType> void maybeSendResponse(CheckedSupplier<OutputType, Exception
9091
if (failureCollector.hasFailure() == false) {
9192
try {
9293
l.onResponse(responseBuilder.get());
94+
return;
9395
} catch (Exception e) {
94-
failureCollector.unwrapAndCollect(e);
96+
l.onFailure(e);
97+
return;
9598
}
9699
}
97100

98-
if (failureCollector.hasFailure()) {
99-
l.onFailure(failureCollector.getFailure());
100-
}
101+
l.onFailure(failureCollector.getFailure());
101102
}
102103
}
103104

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,7 @@ public <InferenceResult extends InferenceServiceResults, OutputType> void execut
5151
e -> bulkExecutionState.onInferenceException(seqNo, e)
5252
),
5353
() -> {
54-
if (bulkExecutionState.responseSent()) {
55-
return;
56-
}
57-
58-
bulkExecutionState.persistsResponse(outputBuilder::onInferenceResponse);
54+
bulkExecutionState.persistsResponses(outputBuilder::onInferenceResponse);
5955
bulkExecutionState.maybeSendResponse(outputBuilder::buildOutput, listener);
6056
}
6157
)

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public void shutdownThreadPool() {
6767

6868
@SuppressWarnings("unchecked")
6969
public void testSuccessfulExecution() throws Exception {
70-
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 100)).toList();
70+
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 50)).toList();
7171
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
7272
List<InferenceAction.Response> responses = Stream.generate(() -> mockInferenceResponse(RankedDocsResults.class))
7373
.limit(requests.size())
@@ -77,7 +77,7 @@ public void testSuccessfulExecution() throws Exception {
7777
doAnswer((invocation) -> {
7878
ActionListener<InferenceAction.Response> l = invocation.getArgument(1);
7979
if (randomBoolean()) {
80-
Thread.sleep(between(0, 50));
80+
Thread.sleep(between(0, 5));
8181
}
8282
l.onResponse(responses.get(requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class))));
8383
return null;
@@ -101,7 +101,7 @@ public void testSuccessfulExecution() throws Exception {
101101
assertThat(output, hasSize(requests.size()));
102102
assertThat(output, contains(responses.stream().map(InferenceAction.Response::getResults).toArray()));
103103
verify(listener).onResponse(eq(output));
104-
}, 60, TimeUnit.SECONDS);
104+
});
105105
}
106106

107107
@SuppressWarnings("unchecked")
@@ -126,14 +126,14 @@ public void testSuccessfulExecutionOnEmptyRequest() throws Exception {
126126

127127
@SuppressWarnings("unchecked")
128128
public void testInferenceRunnerAlwaysFails() throws Exception {
129-
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 10)).toList();
129+
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 30)).toList();
130130
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
131131

132132
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
133133
doAnswer(invocation -> {
134134
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
135135
if (randomBoolean()) {
136-
Thread.sleep(between(0, 500));
136+
Thread.sleep(between(0, 5));
137137
}
138138
listener.onFailure(new RuntimeException("inference failure"));
139139
return null;
@@ -159,17 +159,17 @@ public void testInferenceRunnerAlwaysFails() throws Exception {
159159

160160
@SuppressWarnings("unchecked")
161161
public void testInferenceRunnerSometimesFails() throws Exception {
162-
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(2, 10)).toList();
162+
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(2, 30)).toList();
163163
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
164164

165165
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
166166
doAnswer(invocation -> {
167167
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
168168
if (randomBoolean()) {
169-
Thread.sleep(between(0, 500));
169+
Thread.sleep(between(0, 5));
170170
}
171171

172-
if (requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size() == 0) {
172+
if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) {
173173
listener.onFailure(new RuntimeException("inference failure"));
174174
} else {
175175
listener.onResponse(mockInferenceResponse(RankedDocsResults.class));
@@ -199,7 +199,7 @@ public void testInferenceRunnerSometimesFails() throws Exception {
199199

200200
@SuppressWarnings("unchecked")
201201
public void testBuildOutputFailure() throws Exception {
202-
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 10)).toList();
202+
List<InferenceAction.Request> requests = Stream.generate(this::mockInferenceRequest).limit(between(1, 30)).toList();
203203
BulkInferenceRequestIterator requestIterator = requestIterator(requests);
204204

205205
InferenceRunner inferenceRunner = mock(InferenceRunner.class);

0 commit comments

Comments
 (0)