Skip to content

Commit 3dcec80

Browse files
authored
[ES|QL] Ensure the Inference Service answers using only the SEARCH and SEARCH_COORDINATION thread pools. (#135071)
1 parent 2e69f17 commit 3dcec80

File tree

5 files changed

+69
-92
lines changed

5 files changed

+69
-92
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.action.support.CountDownActionListener;
12+
import org.elasticsearch.action.support.ThreadedActionListener;
1213
import org.elasticsearch.client.internal.Client;
1314
import org.elasticsearch.common.lucene.BytesRefs;
1415
import org.elasticsearch.inference.TaskType;
@@ -103,23 +104,23 @@ void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResol
103104

104105
final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();
105106

106-
final CountDownActionListener countdownListener = new CountDownActionListener(inferenceIds.size(), ActionListener.wrap(_r -> {
107-
threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION)
108-
.execute(() -> listener.onResponse(inferenceResolutionBuilder.build()));
109-
}, listener::onFailure));
107+
final CountDownActionListener countdownListener = new CountDownActionListener(
108+
inferenceIds.size(),
109+
listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(inferenceResolutionBuilder.build()))
110+
);
110111

111112
for (var inferenceId : inferenceIds) {
112113
client.execute(
113114
GetInferenceModelAction.INSTANCE,
114115
new GetInferenceModelAction.Request(inferenceId, TaskType.ANY),
115-
ActionListener.wrap(r -> {
116+
new ThreadedActionListener<>(threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), ActionListener.wrap(r -> {
116117
ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType());
117118
inferenceResolutionBuilder.withResolvedInference(resolvedInference);
118119
countdownListener.onResponse(null);
119120
}, e -> {
120121
inferenceResolutionBuilder.withError(inferenceId, e.getMessage());
121122
countdownListener.onResponse(null);
122-
})
123+
}))
123124
);
124125
}
125126
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.inference.bulk;
99

1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.support.ThreadedActionListener;
1112
import org.elasticsearch.client.internal.Client;
1213
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
1314
import org.elasticsearch.threadpool.ThreadPool;
@@ -25,7 +26,6 @@
2526

2627
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
2728
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
28-
import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME;
2929

3030
/**
3131
* Implementation of bulk inference execution with throttling and concurrency control.
@@ -88,7 +88,7 @@ public BulkInferenceRequest poll() {
8888
public BulkInferenceRunner(Client client, int maxRunningTasks) {
8989
this.permits = new Semaphore(maxRunningTasks);
9090
this.client = client;
91-
this.executor = client.threadPool().executor(ESQL_WORKER_THREAD_POOL_NAME);
91+
this.executor = client.threadPool().executor(ThreadPool.Names.SEARCH);
9292
}
9393

9494
/**
@@ -99,7 +99,7 @@ public BulkInferenceRunner(Client client, int maxRunningTasks) {
9999
*/
100100
public void executeBulk(BulkInferenceRequestIterator requests, ActionListener<List<InferenceAction.Response>> listener) {
101101
List<InferenceAction.Response> responses = new ArrayList<>();
102-
executeBulk(requests, responses::add, ActionListener.wrap(ignored -> listener.onResponse(responses), listener::onFailure));
102+
executeBulk(requests, responses::add, listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(responses)));
103103
}
104104

105105
/**
@@ -253,48 +253,51 @@ private void executePendingRequests(int recursionDepth) {
253253
executionState.finish();
254254
}
255255

256-
final ActionListener<InferenceAction.Response> inferenceResponseListener = ActionListener.runAfter(
257-
ActionListener.wrap(
258-
r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r),
259-
e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e)
260-
),
261-
() -> {
262-
// Release the permit we used
263-
permits.release();
264-
265-
try {
266-
synchronized (executionState) {
267-
persistPendingResponses();
268-
}
256+
final ActionListener<InferenceAction.Response> inferenceResponseListener = new ThreadedActionListener<>(
257+
executor,
258+
ActionListener.runAfter(
259+
ActionListener.wrap(
260+
r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r),
261+
e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e)
262+
),
263+
() -> {
264+
// Release the permit we used
265+
permits.release();
266+
267+
try {
268+
synchronized (executionState) {
269+
persistPendingResponses();
270+
}
269271

270-
if (executionState.finished() && responseSent.compareAndSet(false, true)) {
271-
onBulkCompletion();
272-
}
272+
if (executionState.finished() && responseSent.compareAndSet(false, true)) {
273+
onBulkCompletion();
274+
}
273275

274-
if (responseSent.get()) {
275-
// Response has already been sent
276-
// No need to continue processing this bulk.
277-
// Check if another bulk request is pending for execution.
278-
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
279-
if (nexBulkRequest != null) {
280-
executor.execute(nexBulkRequest::executePendingRequests);
276+
if (responseSent.get()) {
277+
// Response has already been sent
278+
// No need to continue processing this bulk.
279+
// Check if another bulk request is pending for execution.
280+
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
281+
if (nexBulkRequest != null) {
282+
executor.execute(nexBulkRequest::executePendingRequests);
283+
}
284+
return;
281285
}
282-
return;
283-
}
284-
if (executionState.finished() == false) {
285-
// Execute any pending requests if any
286-
if (recursionDepth > 100) {
287-
executor.execute(this::executePendingRequests);
288-
} else {
289-
this.executePendingRequests(recursionDepth + 1);
286+
if (executionState.finished() == false) {
287+
// Execute any pending requests if any
288+
if (recursionDepth > 100) {
289+
executor.execute(this::executePendingRequests);
290+
} else {
291+
this.executePendingRequests(recursionDepth + 1);
292+
}
293+
}
294+
} catch (Exception e) {
295+
if (responseSent.compareAndSet(false, true)) {
296+
completionListener.onFailure(e);
290297
}
291-
}
292-
} catch (Exception e) {
293-
if (responseSent.compareAndSet(false, true)) {
294-
completionListener.onFailure(e);
295298
}
296299
}
297-
}
300+
)
298301
);
299302

300303
// Handle null requests (edge case in some iterators)

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import org.elasticsearch.action.ActionType;
1515
import org.elasticsearch.client.internal.Client;
1616
import org.elasticsearch.common.logging.LoggerMessageFormat;
17-
import org.elasticsearch.common.settings.Settings;
18-
import org.elasticsearch.common.util.concurrent.EsExecutors;
1917
import org.elasticsearch.compute.data.Block;
2018
import org.elasticsearch.compute.data.BlockFactory;
2119
import org.elasticsearch.compute.data.BooleanBlock;
@@ -33,11 +31,8 @@
3331
import org.elasticsearch.core.TimeValue;
3432
import org.elasticsearch.inference.InferenceServiceResults;
3533
import org.elasticsearch.test.client.NoOpClient;
36-
import org.elasticsearch.threadpool.FixedExecutorBuilder;
37-
import org.elasticsearch.threadpool.TestThreadPool;
3834
import org.elasticsearch.threadpool.ThreadPool;
3935
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
40-
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
4136
import org.junit.After;
4237
import org.junit.Before;
4338

@@ -52,17 +47,7 @@ public abstract class InferenceOperatorTestCase<InferenceResultsType extends Inf
5247

5348
@Before
5449
public void setThreadPool() {
55-
threadPool = new TestThreadPool(
56-
getTestClass().getSimpleName(),
57-
new FixedExecutorBuilder(
58-
Settings.EMPTY,
59-
EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME,
60-
between(1, 10),
61-
1024,
62-
"esql",
63-
EsExecutors.TaskTrackingConfig.DEFAULT
64-
)
65-
);
50+
threadPool = createThreadPool();
6651
}
6752

6853
@Before

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,7 @@ public void testResolveInferenceIds() throws Exception {
103103

104104
inferenceResolver.resolveInferenceIds(
105105
inferenceIds,
106-
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
107-
throw new RuntimeException(e);
108-
}))
106+
assertAnswerUsingSearchCoordinationThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, ESTestCase::fail))
109107
);
110108

111109
assertBusy(() -> {
@@ -123,9 +121,7 @@ public void testResolveMultipleInferenceIds() throws Exception {
123121

124122
inferenceResolver.resolveInferenceIds(
125123
inferenceIds,
126-
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
127-
throw new RuntimeException(e);
128-
}))
124+
assertAnswerUsingSearchCoordinationThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, ESTestCase::fail))
129125
);
130126

131127
assertBusy(() -> {
@@ -151,9 +147,7 @@ public void testResolveMissingInferenceIds() throws Exception {
151147

152148
inferenceResolver.resolveInferenceIds(
153149
inferenceIds,
154-
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
155-
throw new RuntimeException(e);
156-
}))
150+
assertAnswerUsingSearchCoordinationThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, ESTestCase::fail))
157151
);
158152

159153
assertBusy(() -> {
@@ -189,7 +183,7 @@ private Client mockClient() {
189183
return client;
190184
}
191185

192-
private <T> ActionListener<T> assertAnswerUsingThreadPool(ActionListener<T> actionListener) {
186+
private <T> ActionListener<T> assertAnswerUsingSearchCoordinationThreadPool(ActionListener<T> actionListener) {
193187
return ActionListener.runBefore(actionListener, () -> ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION));
194188
}
195189

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,12 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.client.internal.Client;
12-
import org.elasticsearch.common.settings.Settings;
13-
import org.elasticsearch.common.util.concurrent.EsExecutors;
1412
import org.elasticsearch.core.TimeValue;
1513
import org.elasticsearch.test.ESTestCase;
1614
import org.elasticsearch.test.client.NoOpClient;
17-
import org.elasticsearch.threadpool.FixedExecutorBuilder;
18-
import org.elasticsearch.threadpool.TestThreadPool;
1915
import org.elasticsearch.threadpool.ThreadPool;
2016
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2117
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
22-
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
2318
import org.junit.After;
2419
import org.junit.Before;
2520
import org.mockito.stubbing.Answer;
@@ -46,17 +41,7 @@ public class BulkInferenceRunnerTests extends ESTestCase {
4641

4742
@Before
4843
public void setThreadPool() {
49-
threadPool = new TestThreadPool(
50-
getTestClass().getSimpleName(),
51-
new FixedExecutorBuilder(
52-
Settings.EMPTY,
53-
EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME,
54-
between(1, 20),
55-
1024,
56-
"esql",
57-
EsExecutors.TaskTrackingConfig.DEFAULT
58-
)
59-
);
44+
threadPool = createThreadPool();
6045
}
6146

6247
@After
@@ -79,7 +64,8 @@ public void testSuccessfulBulkExecution() throws Exception {
7964
AtomicReference<List<InferenceAction.Response>> output = new AtomicReference<>();
8065
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(output::set, ESTestCase::fail);
8166

82-
inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener);
67+
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
68+
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));
8369

8470
assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), equalTo(responses))));
8571
}
@@ -91,7 +77,8 @@ public void testSuccessfulBulkExecutionOnEmptyRequest() throws Exception {
9177
AtomicReference<List<InferenceAction.Response>> output = new AtomicReference<>();
9278
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(output::set, ESTestCase::fail);
9379

94-
inferenceRunnerFactory(new NoOpClient(threadPool)).create(randomBulkExecutionConfig()).executeBulk(requestIterator, listener);
80+
inferenceRunnerFactory(new NoOpClient(threadPool)).create(randomBulkExecutionConfig())
81+
.executeBulk(requestIterator, assertAnswerUsingSearchThreadPool(listener));
9582

9683
assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), empty())));
9784
}
@@ -110,7 +97,8 @@ public void testBulkExecutionWhenInferenceRunnerAlwaysFails() throws Exception {
11097
AtomicReference<Exception> exception = new AtomicReference<>();
11198
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set);
11299

113-
inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener);
100+
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
101+
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));
114102

115103
assertBusy(() -> {
116104
assertThat(exception.get(), notNullValue());
@@ -137,7 +125,8 @@ public void testBulkExecutionWhenInferenceRunnerSometimesFails() throws Exceptio
137125
AtomicReference<Exception> exception = new AtomicReference<>();
138126
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set);
139127

140-
inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener);
128+
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
129+
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));
141130

142131
assertBusy(() -> {
143132
assertThat(exception.get(), notNullValue());
@@ -167,13 +156,18 @@ public void testParallelBulkExecution() throws Exception {
167156
latch.countDown();
168157
}, ESTestCase::fail);
169158

170-
inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener);
159+
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
160+
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));
171161
});
172162
}
173163

174164
latch.await(10, TimeUnit.SECONDS);
175165
}
176166

167+
private <T> ActionListener<T> assertAnswerUsingSearchThreadPool(ActionListener<T> actionListener) {
168+
return ActionListener.runBefore(actionListener, () -> ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH));
169+
}
170+
177171
private BulkInferenceRunner.Factory inferenceRunnerFactory(Client client) {
178172
return BulkInferenceRunner.factory(client);
179173
}

0 commit comments

Comments
 (0)