diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java index abb4eef251374..637c4d3b1ad76 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.inference.TaskType; @@ -103,23 +104,23 @@ void resolveInferenceIds(Set inferenceIds, ActionListener { - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION) - .execute(() -> listener.onResponse(inferenceResolutionBuilder.build())); - }, listener::onFailure)); + final CountDownActionListener countdownListener = new CountDownActionListener( + inferenceIds.size(), + listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(inferenceResolutionBuilder.build())) + ); for (var inferenceId : inferenceIds) { client.execute( GetInferenceModelAction.INSTANCE, new GetInferenceModelAction.Request(inferenceId, TaskType.ANY), - ActionListener.wrap(r -> { + new ThreadedActionListener<>(threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), ActionListener.wrap(r -> { ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType()); inferenceResolutionBuilder.withResolvedInference(resolvedInference); countdownListener.onResponse(null); }, e -> { inferenceResolutionBuilder.withError(inferenceId, e.getMessage()); countdownListener.onResponse(null); - }) + })) ); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java index 203a3031bcad4..9e5011d77b307 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.inference.bulk; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.threadpool.ThreadPool; @@ -25,7 +26,6 @@ import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME; /** * Implementation of bulk inference execution with throttling and concurrency control. @@ -88,7 +88,7 @@ public BulkInferenceRequest poll() { public BulkInferenceRunner(Client client, int maxRunningTasks) { this.permits = new Semaphore(maxRunningTasks); this.client = client; - this.executor = client.threadPool().executor(ESQL_WORKER_THREAD_POOL_NAME); + this.executor = client.threadPool().executor(ThreadPool.Names.SEARCH); } /** @@ -99,7 +99,7 @@ public BulkInferenceRunner(Client client, int maxRunningTasks) { */ public void executeBulk(BulkInferenceRequestIterator requests, ActionListener> listener) { List responses = new ArrayList<>(); - executeBulk(requests, responses::add, ActionListener.wrap(ignored -> listener.onResponse(responses), listener::onFailure)); + executeBulk(requests, responses::add, listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(responses))); } /** @@ -253,48 +253,51 @@ private void executePendingRequests(int recursionDepth) { executionState.finish(); } - final ActionListener inferenceResponseListener = ActionListener.runAfter( - ActionListener.wrap( - r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r), - e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e) - ), - () -> { - // Release the permit we used - permits.release(); - - try { - synchronized (executionState) { - persistPendingResponses(); - } + final ActionListener inferenceResponseListener = new ThreadedActionListener<>( + executor, + ActionListener.runAfter( + ActionListener.wrap( + r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r), + e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e) + ), + () -> { + // Release the permit we used + permits.release(); + + try { + synchronized (executionState) { + persistPendingResponses(); + } - if (executionState.finished() && responseSent.compareAndSet(false, true)) { - onBulkCompletion(); - } + if (executionState.finished() && responseSent.compareAndSet(false, true)) { + onBulkCompletion(); + } - if (responseSent.get()) { - // Response has already been sent - // No need to continue processing this bulk. - // Check if another bulk request is pending for execution. - BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll(); - if (nexBulkRequest != null) { - executor.execute(nexBulkRequest::executePendingRequests); + if (responseSent.get()) { + // Response has already been sent + // No need to continue processing this bulk. + // Check if another bulk request is pending for execution. + BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll(); + if (nexBulkRequest != null) { + executor.execute(nexBulkRequest::executePendingRequests); + } + return; } - return; - } - if (executionState.finished() == false) { - // Execute any pending requests if any - if (recursionDepth > 100) { - executor.execute(this::executePendingRequests); - } else { - this.executePendingRequests(recursionDepth + 1); + if (executionState.finished() == false) { + // Execute any pending requests if any + if (recursionDepth > 100) { + executor.execute(this::executePendingRequests); + } else { + this.executePendingRequests(recursionDepth + 1); + } + } + } catch (Exception e) { + if (responseSent.compareAndSet(false, true)) { + completionListener.onFailure(e); } - } - } catch (Exception e) { - if (responseSent.compareAndSet(false, true)) { - completionListener.onFailure(e); } } - } + ) ); // Handle null requests (edge case in some iterators) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java index 89a9e6a1baf39..ac964efd892ff 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java @@ -14,8 +14,6 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.logging.LoggerMessageFormat; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanBlock; @@ -33,11 +31,8 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.client.NoOpClient; -import org.elasticsearch.threadpool.FixedExecutorBuilder; -import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.junit.After; import org.junit.Before; @@ -52,17 +47,7 @@ public abstract class InferenceOperatorTestCase { - throw new RuntimeException(e); - })) + assertAnswerUsingSearchCoordinationThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, ESTestCase::fail)) ); assertBusy(() -> { @@ -123,9 +121,7 @@ public void testResolveMultipleInferenceIds() throws Exception { inferenceResolver.resolveInferenceIds( inferenceIds, - assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { - throw new RuntimeException(e); - })) + assertAnswerUsingSearchCoordinationThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, ESTestCase::fail)) ); assertBusy(() -> { @@ -151,9 +147,7 @@ public void testResolveMissingInferenceIds() throws Exception { inferenceResolver.resolveInferenceIds( inferenceIds, - assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { - throw new RuntimeException(e); - })) + assertAnswerUsingSearchCoordinationThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, ESTestCase::fail)) ); assertBusy(() -> { @@ -189,7 +183,7 @@ private Client mockClient() { return client; } - private ActionListener assertAnswerUsingThreadPool(ActionListener actionListener) { + private ActionListener assertAnswerUsingSearchCoordinationThreadPool(ActionListener actionListener) { return ActionListener.runBefore(actionListener, () -> ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION)); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java index dedbf895860b9..896fde3b6759f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java @@ -9,17 +9,12 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; -import org.elasticsearch.threadpool.FixedExecutorBuilder; -import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; -import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.junit.After; import org.junit.Before; import org.mockito.stubbing.Answer; @@ -46,17 +41,7 @@ public class BulkInferenceRunnerTests extends ESTestCase { @Before public void setThreadPool() { - threadPool = new TestThreadPool( - getTestClass().getSimpleName(), - new FixedExecutorBuilder( - Settings.EMPTY, - EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME, - between(1, 20), - 1024, - "esql", - EsExecutors.TaskTrackingConfig.DEFAULT - ) - ); + threadPool = createThreadPool(); } @After @@ -79,7 +64,8 @@ public void testSuccessfulBulkExecution() throws Exception { AtomicReference> output = new AtomicReference<>(); ActionListener> listener = ActionListener.wrap(output::set, ESTestCase::fail); - inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); + inferenceRunnerFactory(client).create(randomBulkExecutionConfig()) + .executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener)); assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), equalTo(responses)))); } @@ -91,7 +77,8 @@ public void testSuccessfulBulkExecutionOnEmptyRequest() throws Exception { AtomicReference> output = new AtomicReference<>(); ActionListener> listener = ActionListener.wrap(output::set, ESTestCase::fail); - inferenceRunnerFactory(new NoOpClient(threadPool)).create(randomBulkExecutionConfig()).executeBulk(requestIterator, listener); + inferenceRunnerFactory(new NoOpClient(threadPool)).create(randomBulkExecutionConfig()) + .executeBulk(requestIterator, assertAnswerUsingSearchThreadPool(listener)); assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), empty()))); } @@ -110,7 +97,8 @@ public void testBulkExecutionWhenInferenceRunnerAlwaysFails() throws Exception { AtomicReference exception = new AtomicReference<>(); ActionListener> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set); - inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); + inferenceRunnerFactory(client).create(randomBulkExecutionConfig()) + .executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener)); assertBusy(() -> { assertThat(exception.get(), notNullValue()); @@ -137,7 +125,8 @@ public void testBulkExecutionWhenInferenceRunnerSometimesFails() throws Exceptio AtomicReference exception = new AtomicReference<>(); ActionListener> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set); - inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); + inferenceRunnerFactory(client).create(randomBulkExecutionConfig()) + .executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener)); assertBusy(() -> { assertThat(exception.get(), notNullValue()); @@ -167,13 +156,18 @@ public void testParallelBulkExecution() throws Exception { latch.countDown(); }, ESTestCase::fail); - inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); + inferenceRunnerFactory(client).create(randomBulkExecutionConfig()) + .executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener)); }); } latch.await(10, TimeUnit.SECONDS); } + private ActionListener assertAnswerUsingSearchThreadPool(ActionListener actionListener) { + return ActionListener.runBefore(actionListener, () -> ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH)); + } + private BulkInferenceRunner.Factory inferenceRunnerFactory(Client client) { return BulkInferenceRunner.factory(client); }