diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java index f7d349281e004..abb4eef251374 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java @@ -12,6 +12,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; @@ -30,13 +31,16 @@ public class InferenceResolver { private final Client client; + private final ThreadPool threadPool; + /** * Constructs a new {@code InferenceResolver}. * * @param client The Elasticsearch client for executing inference deployment lookups */ - public InferenceResolver(Client client) { + public InferenceResolver(Client client, ThreadPool threadPool) { this.client = client; + this.threadPool = threadPool; } /** @@ -99,10 +103,10 @@ void resolveInferenceIds(Set inferenceIds, ActionListener listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure) - ); + final CountDownActionListener countdownListener = new CountDownActionListener(inferenceIds.size(), ActionListener.wrap(_r -> { + threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION) + .execute(() -> listener.onResponse(inferenceResolutionBuilder.build())); + }, listener::onFailure)); for (var inferenceId : inferenceIds) { client.execute( @@ -145,18 +149,20 @@ private static String inferenceId(Expression e) { } public static Factory factory(Client client) { - return new Factory(client); + return new Factory(client, client.threadPool()); } public static class Factory { private final Client client; + private final ThreadPool threadPool; - private Factory(Client client) { + private Factory(Client client, ThreadPool threadPool) { this.client = client; + this.threadPool = threadPool; } public InferenceResolver create() { - return new InferenceResolver(client); + return new InferenceResolver(client, threadPool); } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java index 8666eedbaeaaa..39917f849e6f2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java @@ -21,9 +21,9 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.esql.parser.EsqlParser; -import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.junit.After; import org.junit.Before; @@ -51,7 +51,7 @@ public void setThreadPool() { getTestClass().getSimpleName(), new FixedExecutorBuilder( Settings.EMPTY, - EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME, + "inference_utility", between(1, 10), 1024, "esql", @@ -101,9 +101,12 @@ public void testResolveInferenceIds() throws Exception { List inferenceIds = List.of("rerank-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { - throw new RuntimeException(e); - })); + inferenceResolver.resolveInferenceIds( + inferenceIds, + assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + throw new RuntimeException(e); + })) + ); assertBusy(() -> { InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get(); @@ -118,9 +121,12 @@ public void testResolveMultipleInferenceIds() throws Exception { List inferenceIds = List.of("rerank-plan", "rerank-plan", "completion-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { - throw new RuntimeException(e); - })); + inferenceResolver.resolveInferenceIds( + inferenceIds, + assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + throw new RuntimeException(e); + })) + ); assertBusy(() -> { InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get(); @@ -143,9 +149,12 @@ public void testResolveMissingInferenceIds() throws Exception { SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { - throw new RuntimeException(e); - })); + inferenceResolver.resolveInferenceIds( + inferenceIds, + assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + throw new RuntimeException(e); + })) + ); assertBusy(() -> { InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get(); @@ -173,21 +182,17 @@ private Client mockClient() { } }; - if (randomBoolean()) { - sendResponse.run(); - } else { - threadPool.schedule( - sendResponse, - TimeValue.timeValueNanos(between(1, 1_000)), - threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME) - ); - } + threadPool.schedule(sendResponse, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.executor("inference_utility")); return null; }).when(client).execute(eq(GetInferenceModelAction.INSTANCE), any(), any()); return client; } + private ActionListener assertAnswerUsingThreadPool(ActionListener actionListener) { + return ActionListener.runBefore(actionListener, () -> ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION)); + } + private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.Request request) { GetInferenceModelAction.Response response = mock(GetInferenceModelAction.Response.class); @@ -205,7 +210,7 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction. } private InferenceResolver inferenceResolver() { - return new InferenceResolver(mockClient()); + return new InferenceResolver(mockClient(), threadPool); } private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) {