Skip to content

Commit a9b8048

Browse files
authored
Fix assertion error after inference resolution. (#134777)
The error has been introduced in #132324 Now making sure that InferenceResolver answer though the SEARCH_COORDINATION threadpool so subsequent assertions will not fail.
1 parent 7847082 commit a9b8048

File tree

2 files changed

+40
-29
lines changed

2 files changed

+40
-29
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.client.internal.Client;
1313
import org.elasticsearch.common.lucene.BytesRefs;
1414
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.threadpool.ThreadPool;
1516
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
1617
import org.elasticsearch.xpack.esql.core.expression.Expression;
1718
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
@@ -30,13 +31,16 @@ public class InferenceResolver {
3031

3132
private final Client client;
3233

34+
private final ThreadPool threadPool;
35+
3336
/**
3437
* Constructs a new {@code InferenceResolver}.
3538
*
3639
* @param client The Elasticsearch client for executing inference deployment lookups
3740
*/
38-
public InferenceResolver(Client client) {
41+
public InferenceResolver(Client client, ThreadPool threadPool) {
3942
this.client = client;
43+
this.threadPool = threadPool;
4044
}
4145

4246
/**
@@ -99,10 +103,10 @@ void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResol
99103

100104
final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();
101105

102-
final CountDownActionListener countdownListener = new CountDownActionListener(
103-
inferenceIds.size(),
104-
ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure)
105-
);
106+
final CountDownActionListener countdownListener = new CountDownActionListener(inferenceIds.size(), ActionListener.wrap(_r -> {
107+
threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION)
108+
.execute(() -> listener.onResponse(inferenceResolutionBuilder.build()));
109+
}, listener::onFailure));
106110

107111
for (var inferenceId : inferenceIds) {
108112
client.execute(
@@ -145,18 +149,20 @@ private static String inferenceId(Expression e) {
145149
}
146150

147151
public static Factory factory(Client client) {
148-
return new Factory(client);
152+
return new Factory(client, client.threadPool());
149153
}
150154

151155
public static class Factory {
152156
private final Client client;
157+
private final ThreadPool threadPool;
153158

154-
private Factory(Client client) {
159+
private Factory(Client client, ThreadPool threadPool) {
155160
this.client = client;
161+
this.threadPool = threadPool;
156162
}
157163

158164
public InferenceResolver create() {
159-
return new InferenceResolver(client);
165+
return new InferenceResolver(client, threadPool);
160166
}
161167
}
162168
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
import org.elasticsearch.test.ESTestCase;
2222
import org.elasticsearch.threadpool.FixedExecutorBuilder;
2323
import org.elasticsearch.threadpool.TestThreadPool;
24+
import org.elasticsearch.threadpool.ThreadPool;
2425
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
2526
import org.elasticsearch.xpack.esql.parser.EsqlParser;
26-
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
2727
import org.junit.After;
2828
import org.junit.Before;
2929

@@ -51,7 +51,7 @@ public void setThreadPool() {
5151
getTestClass().getSimpleName(),
5252
new FixedExecutorBuilder(
5353
Settings.EMPTY,
54-
EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME,
54+
"inference_utility",
5555
between(1, 10),
5656
1024,
5757
"esql",
@@ -101,9 +101,12 @@ public void testResolveInferenceIds() throws Exception {
101101
List<String> inferenceIds = List.of("rerank-plan");
102102
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
103103

104-
inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
105-
throw new RuntimeException(e);
106-
}));
104+
inferenceResolver.resolveInferenceIds(
105+
inferenceIds,
106+
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
107+
throw new RuntimeException(e);
108+
}))
109+
);
107110

108111
assertBusy(() -> {
109112
InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
@@ -118,9 +121,12 @@ public void testResolveMultipleInferenceIds() throws Exception {
118121
List<String> inferenceIds = List.of("rerank-plan", "rerank-plan", "completion-plan");
119122
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
120123

121-
inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
122-
throw new RuntimeException(e);
123-
}));
124+
inferenceResolver.resolveInferenceIds(
125+
inferenceIds,
126+
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
127+
throw new RuntimeException(e);
128+
}))
129+
);
124130

125131
assertBusy(() -> {
126132
InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
@@ -143,9 +149,12 @@ public void testResolveMissingInferenceIds() throws Exception {
143149

144150
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
145151

146-
inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
147-
throw new RuntimeException(e);
148-
}));
152+
inferenceResolver.resolveInferenceIds(
153+
inferenceIds,
154+
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
155+
throw new RuntimeException(e);
156+
}))
157+
);
149158

150159
assertBusy(() -> {
151160
InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
@@ -173,21 +182,17 @@ private Client mockClient() {
173182
}
174183
};
175184

176-
if (randomBoolean()) {
177-
sendResponse.run();
178-
} else {
179-
threadPool.schedule(
180-
sendResponse,
181-
TimeValue.timeValueNanos(between(1, 1_000)),
182-
threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME)
183-
);
184-
}
185+
threadPool.schedule(sendResponse, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.executor("inference_utility"));
185186

186187
return null;
187188
}).when(client).execute(eq(GetInferenceModelAction.INSTANCE), any(), any());
188189
return client;
189190
}
190191

192+
private <T> ActionListener<T> assertAnswerUsingThreadPool(ActionListener<T> actionListener) {
193+
return ActionListener.runBefore(actionListener, () -> ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION));
194+
}
195+
191196
private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.Request request) {
192197
GetInferenceModelAction.Response response = mock(GetInferenceModelAction.Response.class);
193198

@@ -205,7 +210,7 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.
205210
}
206211

207212
private InferenceResolver inferenceResolver() {
208-
return new InferenceResolver(mockClient());
213+
return new InferenceResolver(mockClient(), threadPool);
209214
}
210215

211216
private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) {

0 commit comments

Comments
 (0)