Skip to content

Commit 6fa285a

Browse files
author
afoucret
committed
[ESQL] Remove result BulkInferenceRunner::executeBulk(BulkInferenceRequestItemIterator requests, ActionListener<List<BulkInferenceResponse>> listener) method.
1 parent dee21ef commit 6fa285a

File tree

4 files changed

+48
-60
lines changed

4 files changed

+48
-60
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceResponse;
2121
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
2222

23+
import java.util.ArrayList;
2324
import java.util.List;
2425

2526
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
@@ -79,7 +80,13 @@ protected void performAsync(Page input, ActionListener<OngoingInferenceResult> l
7980
BulkInferenceRequestItemIterator requests = requests(input);
8081
listener = ActionListener.releaseBefore(requests, listener);
8182

82-
bulkInferenceRunner.executeBulk(requests, listener.map(responses -> new OngoingInferenceResult(input, responses)));
83+
OngoingInferenceResult result = new OngoingInferenceResult(input, new ArrayList<>());
84+
listener = listener.delegateResponse((l, e) -> {
85+
Releasables.close(result);
86+
l.onFailure(e);
87+
});
88+
89+
bulkInferenceRunner.executeBulk(requests, result.responses()::add, listener.map(responses -> result));
8390
} catch (Exception e) {
8491
listener.onFailure(e);
8592
}
@@ -170,9 +177,6 @@ default void releasePageOnAnyThread(Page page) {
170177
/**
171178
* Represents the result of an ongoing inference operation, including the original input page
172179
* and the list of inference responses.
173-
*
174-
* @param inputPage The input page used to generate inference requests.
175-
* @param responses The inference responses returned by the inference service.
176180
*/
177181
public record OngoingInferenceResult(Page inputPage, List<BulkInferenceResponse> responses) implements Releasable {
178182

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import org.elasticsearch.threadpool.ThreadPool;
1515
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1616

17-
import java.util.ArrayList;
18-
import java.util.List;
1917
import java.util.Queue;
2018
import java.util.Set;
2119
import java.util.concurrent.ConcurrentLinkedQueue;
@@ -91,17 +89,6 @@ public BulkInferenceRunner(Client client, int maxRunningTasks) {
9189
this.executor = client.threadPool().executor(ThreadPool.Names.SEARCH);
9290
}
9391

94-
/**
95-
* Executes multiple inference requests in bulk and collects all responses.
96-
*
97-
* @param requests An iterator over the inference requests to execute
98-
* @param listener Called with the list of all responses in request order
99-
*/
100-
public void executeBulk(BulkInferenceRequestItemIterator requests, ActionListener<List<BulkInferenceResponse>> listener) {
101-
List<BulkInferenceResponse> responses = new ArrayList<>();
102-
executeBulk(requests, responses::add, listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(responses)));
103-
}
104-
10592
/**
10693
* Executes multiple inference requests in bulk with streaming response handling.
10794
* <p>

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
import java.util.List;
3939
import java.util.concurrent.atomic.AtomicReference;
40+
import java.util.function.Consumer;
4041

4142
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
4243
import static org.hamcrest.Matchers.containsString;
@@ -80,19 +81,14 @@ public void testFoldTextEmbeddingFunction() throws Exception {
8081
BulkInferenceRunner bulkInferenceRunner = mock(BulkInferenceRunner.class);
8182

8283
doAnswer(i -> {
83-
threadPool.schedule(
84-
() -> i.getArgument(1, ActionListener.class)
85-
.onResponse(
86-
List.of(
87-
new BulkInferenceResponse(new BulkInferenceRequestItem(null, new int[] { 1 }), inferenceResponse(embedding))
88-
)
89-
),
90-
TimeValue.timeValueMillis(between(1, 10)),
91-
threadPool.generic()
92-
);
84+
threadPool.schedule(() -> {
85+
i.getArgument(1, Consumer.class)
86+
.accept(new BulkInferenceResponse(new BulkInferenceRequestItem(null, new int[] { 1 }), inferenceResponse(embedding)));
87+
i.getArgument(2, ActionListener.class).onResponse(null);
88+
}, TimeValue.timeValueMillis(between(1, 10)), threadPool.generic());
9389

9490
return null;
95-
}).when(bulkInferenceRunner).executeBulk(any(), any());
91+
}).when(bulkInferenceRunner).executeBulk(any(), any(), any());
9692
when(bulkInferenceRunner.threadPool()).thenReturn(threadPool);
9793

9894
when(inferenceService.bulkInferenceRunner()).thenReturn(bulkInferenceRunner);
@@ -275,23 +271,16 @@ public void testFoldCompletionFunction() throws Exception {
275271
BulkInferenceRunner bulkInferenceRunner = mock(BulkInferenceRunner.class);
276272

277273
doAnswer(i -> {
278-
threadPool.schedule(
279-
280-
() -> i.getArgument(1, ActionListener.class)
281-
.onResponse(
282-
List.of(
283-
new BulkInferenceResponse(
284-
new BulkInferenceRequestItem(null, new int[] { 1 }),
285-
completionResponse(completionText)
286-
)
287-
)
288-
),
289-
TimeValue.timeValueMillis(between(1, 10)),
290-
threadPool.generic()
291-
);
274+
threadPool.schedule(() -> {
275+
i.getArgument(1, Consumer.class)
276+
.accept(
277+
new BulkInferenceResponse(new BulkInferenceRequestItem(null, new int[] { 1 }), completionResponse(completionText))
278+
);
279+
i.getArgument(2, ActionListener.class).onResponse(null);
280+
}, TimeValue.timeValueMillis(between(1, 10)), threadPool.generic());
292281

293282
return null;
294-
}).when(bulkInferenceRunner).executeBulk(any(), any());
283+
}).when(bulkInferenceRunner).executeBulk(any(), any(), any());
295284
when(bulkInferenceRunner.threadPool()).thenReturn(threadPool);
296285

297286
when(inferenceService.bulkInferenceRunner()).thenReturn(bulkInferenceRunner);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@
2424
import java.util.List;
2525
import java.util.concurrent.CountDownLatch;
2626
import java.util.concurrent.TimeUnit;
27+
import java.util.concurrent.atomic.AtomicBoolean;
2728
import java.util.concurrent.atomic.AtomicReference;
2829

29-
import static org.hamcrest.Matchers.allOf;
3030
import static org.hamcrest.Matchers.empty;
3131
import static org.hamcrest.Matchers.equalTo;
32+
import static org.hamcrest.Matchers.not;
3233
import static org.hamcrest.Matchers.notNullValue;
3334
import static org.mockito.ArgumentMatchers.any;
3435
import static org.mockito.ArgumentMatchers.eq;
@@ -61,29 +62,35 @@ public void testSuccessfulBulkExecution() throws Exception {
6162
return null;
6263
});
6364

64-
AtomicReference<List<BulkInferenceResponse>> output = new AtomicReference<>();
65-
ActionListener<List<BulkInferenceResponse>> listener = ActionListener.wrap(output::set, ESTestCase::fail);
65+
List<BulkInferenceResponse> output = new ArrayList<>();
66+
AtomicBoolean completed = new AtomicBoolean(false);
67+
ActionListener<Void> listener = ActionListener.wrap(r -> completed.set(true), e -> fail("Did not expect an exception"));
6668

6769
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
68-
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));
70+
.executeBulk(requestIterator(requests), output::add, assertAnswerUsingSearchThreadPool(listener));
6971

7072
assertBusy(() -> {
71-
assertThat(output.get(), notNullValue());
72-
assertThat(output.get().stream().map(BulkInferenceResponse::response).toList(), equalTo(responses));
73+
assertThat(completed.get(), equalTo(true));
74+
assertThat(output, not(empty()));
75+
assertThat(output.stream().map(BulkInferenceResponse::response).toList(), equalTo(responses));
7376
});
7477
}
7578

7679
public void testSuccessfulBulkExecutionOnEmptyRequest() throws Exception {
7780
BulkInferenceRequestItemIterator requestIterator = mock(BulkInferenceRequestItemIterator.class);
7881
when(requestIterator.hasNext()).thenReturn(false);
7982

80-
AtomicReference<List<BulkInferenceResponse>> output = new AtomicReference<>();
81-
ActionListener<List<BulkInferenceResponse>> listener = ActionListener.wrap(output::set, ESTestCase::fail);
83+
List<BulkInferenceResponse> output = new ArrayList<>();
84+
AtomicBoolean completed = new AtomicBoolean(false);
85+
ActionListener<Void> listener = ActionListener.wrap(r -> completed.set(true), e -> fail("Did not expect an exception"));
8286

8387
inferenceRunnerFactory(new NoOpClient(threadPool)).create(randomBulkExecutionConfig())
84-
.executeBulk(requestIterator, assertAnswerUsingSearchThreadPool(listener));
88+
.executeBulk(requestIterator, output::add, assertAnswerUsingSearchThreadPool(listener));
8589

86-
assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), empty())));
90+
assertBusy(() -> {
91+
assertThat(completed.get(), equalTo(true));
92+
assertThat(output, empty());
93+
});
8794
}
8895

8996
public void testBulkExecutionWhenInferenceRunnerAlwaysFails() throws Exception {
@@ -98,10 +105,10 @@ public void testBulkExecutionWhenInferenceRunnerAlwaysFails() throws Exception {
98105
});
99106

100107
AtomicReference<Exception> exception = new AtomicReference<>();
101-
ActionListener<List<BulkInferenceResponse>> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set);
108+
ActionListener<Void> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set);
102109

103110
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
104-
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));
111+
.executeBulk(requestIterator(requests), r -> {}, assertAnswerUsingSearchThreadPool(listener));
105112

106113
assertBusy(() -> {
107114
assertThat(exception.get(), notNullValue());
@@ -126,10 +133,10 @@ public void testBulkExecutionWhenInferenceRunnerSometimesFails() throws Exceptio
126133
});
127134

128135
AtomicReference<Exception> exception = new AtomicReference<>();
129-
ActionListener<List<BulkInferenceResponse>> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set);
136+
ActionListener<Void> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set);
130137

131138
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
132-
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));
139+
.executeBulk(requestIterator(requests), r -> {}, assertAnswerUsingSearchThreadPool(listener));
133140

134141
assertBusy(() -> {
135142
assertThat(exception.get(), notNullValue());
@@ -153,14 +160,15 @@ public void testParallelBulkExecution() throws Exception {
153160
});
154161
return null;
155162
});
163+
ArrayList<BulkInferenceResponse> output = new ArrayList<>();
156164

157-
ActionListener<List<BulkInferenceResponse>> listener = ActionListener.wrap(r -> {
158-
assertThat(r.stream().map(BulkInferenceResponse::response).toList(), equalTo(responses));
165+
ActionListener<Void> listener = ActionListener.wrap(r -> {
166+
assertThat(output.stream().map(BulkInferenceResponse::response).toList(), equalTo(responses));
159167
latch.countDown();
160168
}, ESTestCase::fail);
161169

162170
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
163-
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));
171+
.executeBulk(requestIterator(requests), output::add, assertAnswerUsingSearchThreadPool(listener));
164172
});
165173
}
166174

0 commit comments

Comments
 (0)