Skip to content

Commit 1e95722

Browse files
committed
Fix inference throttling.
1 parent 115ee49 commit 1e95722

File tree

3 files changed

+48
-33
lines changed

3 files changed

+48
-33
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
2121
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
2222

23+
import java.util.List;
24+
2325
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
2426

2527
public abstract class InferenceOperator extends AsyncOperator<InferenceOperator.OngoingInference> {
@@ -66,9 +68,7 @@ public Page getOutput() {
6668
}
6769

6870
try (OutputBuilder outputBuilder = outputBuilder(ongoingInference.inputPage)) {
69-
for (int i = 0; i < ongoingInference.responses.length; i++) {
70-
outputBuilder.addInferenceResponse(ongoingInference.responses[i]);
71-
}
71+
ongoingInference.responses.forEach(outputBuilder::addInferenceResponse);
7272
return outputBuilder.buildOutput();
7373
} finally {
7474
releaseFetchedOnAnyThread(ongoingInference);
@@ -109,7 +109,7 @@ static <IR extends InferenceServiceResults> IR inferenceResults(InferenceAction.
109109
}
110110
}
111111

112-
public record OngoingInference(Page inputPage, InferenceAction.Response[] responses) {
112+
public record OngoingInference(Page inputPage, List<InferenceAction.Response> responses) {
113113

114114
}
115115
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadP
3131
throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService, bulkExecutionConfig);
3232
}
3333

34-
public void execute(BulkInferenceRequestIterator requests, ActionListener<InferenceAction.Response[]> listener) throws Exception {
34+
public void execute(BulkInferenceRequestIterator requests, ActionListener<List<InferenceAction.Response>> listener) throws Exception {
3535
final ResponseHandler responseHandler = new ResponseHandler();
3636
runInferenceRequests(requests, listener.delegateFailureAndWrap(responseHandler::handleResponses));
3737
}
@@ -64,7 +64,7 @@ private void runInferenceRequests(BulkInferenceRequestIterator requests, ActionL
6464
private static class ResponseHandler {
6565
private final List<InferenceAction.Response> responses = new ArrayList<>();
6666

67-
private void handleResponses(ActionListener<InferenceAction.Response[]> listener, BulkInferenceExecutionState bulkExecutionState) {
67+
private void handleResponses(ActionListener<List<InferenceAction.Response>> listener, BulkInferenceExecutionState bulkExecutionState) {
6868

6969
try {
7070
persistsInferenceResponses(bulkExecutionState);
@@ -75,7 +75,7 @@ private void handleResponses(ActionListener<InferenceAction.Response[]> listener
7575

7676
if (bulkExecutionState.hasFailure() == false) {
7777
try {
78-
listener.onResponse(responses.toArray(InferenceAction.Response[]::new));
78+
listener.onResponse(responses);
7979
return;
8080
} catch (Exception e) {
8181
bulkExecutionState.addFailure(e);
@@ -125,9 +125,7 @@ public static ThrottledInferenceRunner create(
125125

126126
public void doInference(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
127127
this.enqueueTask(listener.delegateFailureAndWrap((l, releasable) -> {
128-
try (releasable) {
129-
inferenceRunner.doInference(request, l);
130-
}
128+
inferenceRunner.doInference(request, ActionListener.releaseAfter(l, releasable));
131129
}));
132130
}
133131
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
import java.util.stream.Stream;
3131

3232
import static org.hamcrest.Matchers.allOf;
33-
import static org.hamcrest.Matchers.arrayContaining;
34-
import static org.hamcrest.Matchers.emptyArray;
33+
import static org.hamcrest.Matchers.empty;
3534
import static org.hamcrest.Matchers.equalTo;
3635
import static org.hamcrest.Matchers.notNullValue;
3736
import static org.mockito.ArgumentMatchers.any;
@@ -63,68 +62,74 @@ public void shutdownThreadPool() {
6362
}
6463

6564
public void testSuccessfulExecution() throws Exception {
66-
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 50));
67-
InferenceAction.Response[] responses = randomInferenceResponseList(requests.size()).toArray(InferenceAction.Response[]::new);
65+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(90_000, 100_000));
66+
List<InferenceAction.Response> responses = randomInferenceResponseList(requests.size());
6867

6968
InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> {
70-
ActionListener<InferenceAction.Response> l = invocation.getArgument(1);
71-
l.onResponse(responses[requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class))]);
69+
runWithRandomDelay(() -> {
70+
ActionListener<InferenceAction.Response> l = invocation.getArgument(1);
71+
l.onResponse(responses.get(requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class))));
72+
});
7273
return null;
7374
});
7475

75-
AtomicReference<InferenceAction.Response[]> output = new AtomicReference<>();
76-
ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception"));
76+
AtomicReference<List<InferenceAction.Response>> output = new AtomicReference<>();
77+
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception"));
7778

7879
bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener);
7980

80-
assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), arrayContaining(responses))));
81+
assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), equalTo(responses))));
8182
}
8283

8384
public void testSuccessfulExecutionOnEmptyRequest() throws Exception {
8485
BulkInferenceRequestIterator requestIterator = mock(BulkInferenceRequestIterator.class);
8586
when(requestIterator.hasNext()).thenReturn(false);
8687

87-
AtomicReference<InferenceAction.Response[]> output = new AtomicReference<>();
88-
ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception"));
88+
AtomicReference<List<InferenceAction.Response>> output = new AtomicReference<>();
89+
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception"));
8990

9091
bulkExecutor(mock(InferenceRunner.class)).execute(requestIterator, listener);
9192

92-
assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), emptyArray())));
93+
assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), empty())));
9394
}
9495

9596
public void testInferenceRunnerAlwaysFails() throws Exception {
96-
List<InferenceAction.Request> requests = randomInferenceRequestList(between(1, 30));
97+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(90_000, 100_000));
9798

9899
InferenceRunner inferenceRunner = mock(invocation -> {
99-
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
100-
listener.onFailure(new RuntimeException("inference failure"));
100+
runWithRandomDelay(() -> {
101+
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
102+
listener.onFailure(new RuntimeException("inference failure"));
103+
});
101104
return null;
102105
});
103106

104107
AtomicReference<Exception> exception = new AtomicReference<>();
105-
ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set);
108+
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set);
106109

107110
bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener);
108111

109112
assertBusy(() -> assertThat(exception.get().getMessage(), equalTo("inference failure")));
110113
}
111114

112115
public void testInferenceRunnerSometimesFails() throws Exception {
113-
List<InferenceAction.Request> requests = randomInferenceRequestList(between(2, 30));
116+
List<InferenceAction.Request> requests = randomInferenceRequestList(between(90_000, 100_000));
114117

115118
InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> {
116119
ActionListener<InferenceAction.Response> listener = invocation.getArgument(1);
117-
if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) {
118-
listener.onFailure(new RuntimeException("inference failure"));
119-
} else {
120-
listener.onResponse(mockInferenceResponse());
121-
}
120+
runWithRandomDelay(() -> {
121+
if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) {
122+
listener.onFailure(new RuntimeException("inference failure"));
123+
} else {
124+
listener.onResponse(mockInferenceResponse());
125+
}
126+
});
122127

123128
return null;
124129
});
125130

126131
AtomicReference<Exception> exception = new AtomicReference<>();
127-
ActionListener<InferenceAction.Response[]> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set);
132+
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set);
128133

129134
bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener);
130135

@@ -170,4 +175,16 @@ private InferenceRunner mockInferenceRunner(Answer<Void> doInferenceAnswer) {
170175
doAnswer(doInferenceAnswer).when(inferenceRunner).doInference(any(), any());
171176
return inferenceRunner;
172177
}
178+
179+
private void runWithRandomDelay(Runnable runnable) {
180+
if (randomBoolean()) {
181+
runnable.run();
182+
} else {
183+
threadPool.schedule(
184+
runnable,
185+
TimeValue.timeValueNanos(between(1, 1_000)),
186+
threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME)
187+
);
188+
}
189+
}
173190
}

0 commit comments

Comments
 (0)