Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.inference.TaskType;
Expand Down Expand Up @@ -103,23 +104,23 @@ void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResol

final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();

final CountDownActionListener countdownListener = new CountDownActionListener(inferenceIds.size(), ActionListener.wrap(_r -> {
threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION)
.execute(() -> listener.onResponse(inferenceResolutionBuilder.build()));
}, listener::onFailure));
final CountDownActionListener countdownListener = new CountDownActionListener(
inferenceIds.size(),
listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(inferenceResolutionBuilder.build()))
);

for (var inferenceId : inferenceIds) {
client.execute(
GetInferenceModelAction.INSTANCE,
new GetInferenceModelAction.Request(inferenceId, TaskType.ANY),
ActionListener.wrap(r -> {
new ThreadedActionListener<>(threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), ActionListener.wrap(r -> {
ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType());
inferenceResolutionBuilder.withResolvedInference(resolvedInference);
countdownListener.onResponse(null);
}, e -> {
inferenceResolutionBuilder.withError(inferenceId, e.getMessage());
countdownListener.onResponse(null);
})
}))
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.inference.bulk;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.threadpool.ThreadPool;
Expand All @@ -25,7 +26,6 @@

import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME;

/**
* Implementation of bulk inference execution with throttling and concurrency control.
Expand Down Expand Up @@ -88,7 +88,7 @@ public BulkInferenceRequest poll() {
public BulkInferenceRunner(Client client, int maxRunningTasks) {
this.permits = new Semaphore(maxRunningTasks);
this.client = client;
this.executor = client.threadPool().executor(ESQL_WORKER_THREAD_POOL_NAME);
this.executor = client.threadPool().executor(ThreadPool.Names.SEARCH);
}

/**
Expand All @@ -99,7 +99,7 @@ public BulkInferenceRunner(Client client, int maxRunningTasks) {
*/
public void executeBulk(BulkInferenceRequestIterator requests, ActionListener<List<InferenceAction.Response>> listener) {
List<InferenceAction.Response> responses = new ArrayList<>();
executeBulk(requests, responses::add, ActionListener.wrap(ignored -> listener.onResponse(responses), listener::onFailure));
executeBulk(requests, responses::add, listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(responses)));
}

/**
Expand Down Expand Up @@ -253,48 +253,51 @@ private void executePendingRequests(int recursionDepth) {
executionState.finish();
}

final ActionListener<InferenceAction.Response> inferenceResponseListener = ActionListener.runAfter(
ActionListener.wrap(
r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r),
e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e)
),
() -> {
// Release the permit we used
permits.release();

try {
synchronized (executionState) {
persistPendingResponses();
}
final ActionListener<InferenceAction.Response> inferenceResponseListener = new ThreadedActionListener<>(
executor,
ActionListener.runAfter(
ActionListener.wrap(
r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r),
e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e)
),
() -> {
// Release the permit we used
permits.release();

try {
synchronized (executionState) {
persistPendingResponses();
}

if (executionState.finished() && responseSent.compareAndSet(false, true)) {
onBulkCompletion();
}
if (executionState.finished() && responseSent.compareAndSet(false, true)) {
onBulkCompletion();
}

if (responseSent.get()) {
// Response has already been sent
// No need to continue processing this bulk.
// Check if another bulk request is pending for execution.
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
if (nexBulkRequest != null) {
executor.execute(nexBulkRequest::executePendingRequests);
if (responseSent.get()) {
// Response has already been sent
// No need to continue processing this bulk.
// Check if another bulk request is pending for execution.
BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
if (nexBulkRequest != null) {
executor.execute(nexBulkRequest::executePendingRequests);
}
return;
}
return;
}
if (executionState.finished() == false) {
// Execute any pending requests if any
if (recursionDepth > 100) {
executor.execute(this::executePendingRequests);
} else {
this.executePendingRequests(recursionDepth + 1);
if (executionState.finished() == false) {
// Execute any pending requests if any
if (recursionDepth > 100) {
executor.execute(this::executePendingRequests);
} else {
this.executePendingRequests(recursionDepth + 1);
}
}
} catch (Exception e) {
if (responseSent.compareAndSet(false, true)) {
completionListener.onFailure(e);
}
}
} catch (Exception e) {
if (responseSent.compareAndSet(false, true)) {
completionListener.onFailure(e);
}
}
}
)
);

// Handle null requests (edge case in some iterators)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanBlock;
Expand All @@ -33,11 +31,8 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
import org.junit.After;
import org.junit.Before;

Expand All @@ -52,17 +47,7 @@ public abstract class InferenceOperatorTestCase<InferenceResultsType extends Inf

@Before
public void setThreadPool() {
threadPool = new TestThreadPool(
getTestClass().getSimpleName(),
new FixedExecutorBuilder(
Settings.EMPTY,
EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME,
between(1, 10),
1024,
"esql",
EsExecutors.TaskTrackingConfig.DEFAULT
)
);
threadPool = createThreadPool();
}

@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ public void testResolveInferenceIds() throws Exception {

inferenceResolver.resolveInferenceIds(
inferenceIds,
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}))
assertAnswerUsingSearchCoordinationThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, ESTestCase::fail))
);

assertBusy(() -> {
Expand All @@ -123,9 +121,7 @@ public void testResolveMultipleInferenceIds() throws Exception {

inferenceResolver.resolveInferenceIds(
inferenceIds,
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}))
assertAnswerUsingSearchCoordinationThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, ESTestCase::fail))
);

assertBusy(() -> {
Expand All @@ -151,9 +147,7 @@ public void testResolveMissingInferenceIds() throws Exception {

inferenceResolver.resolveInferenceIds(
inferenceIds,
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}))
assertAnswerUsingSearchCoordinationThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, ESTestCase::fail))
);

assertBusy(() -> {
Expand Down Expand Up @@ -189,7 +183,7 @@ private Client mockClient() {
return client;
}

private <T> ActionListener<T> assertAnswerUsingThreadPool(ActionListener<T> actionListener) {
private <T> ActionListener<T> assertAnswerUsingSearchCoordinationThreadPool(ActionListener<T> actionListener) {
return ActionListener.runBefore(actionListener, () -> ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,12 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
import org.junit.After;
import org.junit.Before;
import org.mockito.stubbing.Answer;
Expand All @@ -46,17 +41,7 @@ public class BulkInferenceRunnerTests extends ESTestCase {

@Before
public void setThreadPool() {
threadPool = new TestThreadPool(
getTestClass().getSimpleName(),
new FixedExecutorBuilder(
Settings.EMPTY,
EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME,
between(1, 20),
1024,
"esql",
EsExecutors.TaskTrackingConfig.DEFAULT
)
);
threadPool = createThreadPool();
}

@After
Expand All @@ -79,7 +64,8 @@ public void testSuccessfulBulkExecution() throws Exception {
AtomicReference<List<InferenceAction.Response>> output = new AtomicReference<>();
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(output::set, ESTestCase::fail);

inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener);
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));

assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), equalTo(responses))));
}
Expand All @@ -91,7 +77,8 @@ public void testSuccessfulBulkExecutionOnEmptyRequest() throws Exception {
AtomicReference<List<InferenceAction.Response>> output = new AtomicReference<>();
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(output::set, ESTestCase::fail);

inferenceRunnerFactory(new NoOpClient(threadPool)).create(randomBulkExecutionConfig()).executeBulk(requestIterator, listener);
inferenceRunnerFactory(new NoOpClient(threadPool)).create(randomBulkExecutionConfig())
.executeBulk(requestIterator, assertAnswerUsingSearchThreadPool(listener));

assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), empty())));
}
Expand All @@ -110,7 +97,8 @@ public void testBulkExecutionWhenInferenceRunnerAlwaysFails() throws Exception {
AtomicReference<Exception> exception = new AtomicReference<>();
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set);

inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener);
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));

assertBusy(() -> {
assertThat(exception.get(), notNullValue());
Expand All @@ -137,7 +125,8 @@ public void testBulkExecutionWhenInferenceRunnerSometimesFails() throws Exceptio
AtomicReference<Exception> exception = new AtomicReference<>();
ActionListener<List<InferenceAction.Response>> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set);

inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener);
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));

assertBusy(() -> {
assertThat(exception.get(), notNullValue());
Expand Down Expand Up @@ -167,13 +156,18 @@ public void testParallelBulkExecution() throws Exception {
latch.countDown();
}, ESTestCase::fail);

inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener);
inferenceRunnerFactory(client).create(randomBulkExecutionConfig())
.executeBulk(requestIterator(requests), assertAnswerUsingSearchThreadPool(listener));
});
}

latch.await(10, TimeUnit.SECONDS);
}

private <T> ActionListener<T> assertAnswerUsingSearchThreadPool(ActionListener<T> actionListener) {
return ActionListener.runBefore(actionListener, () -> ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH));
}

private BulkInferenceRunner.Factory inferenceRunnerFactory(Client client) {
return BulkInferenceRunner.factory(client);
}
Expand Down