(First changed file: the InferenceResolver class.)
@@ -12,6 +12,7 @@
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
@@ -30,13 +31,16 @@ public class InferenceResolver {

private final Client client;

private final ThreadPool threadPool;

/**
* Constructs a new {@code InferenceResolver}.
*
* @param client The Elasticsearch client for executing inference deployment lookups
* @param threadPool The thread pool providing the search coordination executor on which the resolution result is delivered
*/
public InferenceResolver(Client client) {
public InferenceResolver(Client client, ThreadPool threadPool) {
this.client = client;
this.threadPool = threadPool;
}

/**
@@ -99,10 +103,10 @@ void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResol

final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();

final CountDownActionListener countdownListener = new CountDownActionListener(
inferenceIds.size(),
ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure)
);
final CountDownActionListener countdownListener = new CountDownActionListener(inferenceIds.size(), ActionListener.wrap(_r -> {
threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION)
.execute(() -> listener.onResponse(inferenceResolutionBuilder.build()));
}, listener::onFailure));

for (var inferenceId : inferenceIds) {
client.execute(
@@ -145,18 +149,20 @@ private static String inferenceId(Expression e) {
}

public static Factory factory(Client client) {
return new Factory(client);
return new Factory(client, client.threadPool());
}

public static class Factory {
private final Client client;
private final ThreadPool threadPool;

private Factory(Client client) {
private Factory(Client client, ThreadPool threadPool) {
this.client = client;
this.threadPool = threadPool;
}

public InferenceResolver create() {
return new InferenceResolver(client);
return new InferenceResolver(client, threadPool);
}
}
}
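The core production change above: instead of completing the caller's listener on whatever thread delivered the last GetInferenceModelAction response, the resolver now always forks the final onResponse onto the SEARCH_COORDINATION executor, while failures are still propagated directly. A minimal sketch of that dispatch pattern, assuming a standalone helper; ForkingListeners and forkOnResponse are illustrative names, not part of this PR or of the Elasticsearch API:

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.threadpool.ThreadPool;

final class ForkingListeners {

    // Wraps a listener so onResponse always runs on the named executor, no matter
    // which thread the upstream callback arrives on; onFailure is passed through
    // unchanged, matching the wiring in the diff above.
    static <T> ActionListener<T> forkOnResponse(ThreadPool threadPool, String executorName, ActionListener<T> delegate) {
        return ActionListener.wrap(
            response -> threadPool.executor(executorName).execute(() -> delegate.onResponse(response)),
            delegate::onFailure
        );
    }
}

In the PR this forking is inlined into the CountDownActionListener wrapper rather than extracted into a helper, but the effect is the same: the resolution result is always handed off to the search coordination pool.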
(Second changed file: the InferenceResolver unit tests.)
@@ -21,9 +21,9 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
import org.junit.After;
import org.junit.Before;

@@ -51,7 +51,7 @@ public void setThreadPool() {
getTestClass().getSimpleName(),
new FixedExecutorBuilder(
Settings.EMPTY,
EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME,
"inference_utility",
between(1, 10),
1024,
"esql",
@@ -101,9 +101,12 @@ public void testResolveInferenceIds() throws Exception {
List<String> inferenceIds = List.of("rerank-plan");
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();

inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}));
inferenceResolver.resolveInferenceIds(
inferenceIds,
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}))
);

assertBusy(() -> {
InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
@@ -118,9 +121,12 @@ public void testResolveMultipleInferenceIds() throws Exception {
List<String> inferenceIds = List.of("rerank-plan", "rerank-plan", "completion-plan");
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();

inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}));
inferenceResolver.resolveInferenceIds(
inferenceIds,
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}))
);

assertBusy(() -> {
InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
@@ -143,9 +149,12 @@ public void testResolveMissingInferenceIds() throws Exception {

SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();

inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}));
inferenceResolver.resolveInferenceIds(
inferenceIds,
assertAnswerUsingThreadPool(ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
throw new RuntimeException(e);
}))
);

assertBusy(() -> {
InferenceResolution inferenceResolution = inferenceResolutionSetOnce.get();
@@ -173,21 +182,17 @@ private Client mockClient() {
}
};

if (randomBoolean()) {
sendResponse.run();
} else {
threadPool.schedule(
sendResponse,
TimeValue.timeValueNanos(between(1, 1_000)),
threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME)
);
}
threadPool.schedule(sendResponse, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.executor("inference_utility"));

return null;
}).when(client).execute(eq(GetInferenceModelAction.INSTANCE), any(), any());
return client;
}

private <T> ActionListener<T> assertAnswerUsingThreadPool(ActionListener<T> actionListener) {
return ActionListener.runBefore(actionListener, () -> ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION));
}
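The helper above is what gives the tests their teeth: the mocked client now always replies from the test-only "inference_utility" executor, so the listener can only observe a search coordination thread if InferenceResolver forked the callback itself. A hand-rolled equivalent of the same check, written as a sketch and assuming only that Elasticsearch executor threads carry the pool name in their thread name; assertForkedToSearchCoordination is an illustrative name, not from the PR:

private <T> ActionListener<T> assertForkedToSearchCoordination(ActionListener<T> delegate) {
    // Runs on the thread that delivers the result, just before the delegate is notified.
    return ActionListener.runBefore(
        delegate,
        () -> assertTrue(
            "expected delivery on the " + ThreadPool.Names.SEARCH_COORDINATION + " pool",
            Thread.currentThread().getName().contains(ThreadPool.Names.SEARCH_COORDINATION)
        )
    );
}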

private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.Request request) {
GetInferenceModelAction.Response response = mock(GetInferenceModelAction.Response.class);

Expand All @@ -205,7 +210,7 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.
}

private InferenceResolver inferenceResolver() {
return new InferenceResolver(mockClient());
return new InferenceResolver(mockClient(), threadPool);
}

private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) {
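For orientation, a sketch of how the pieces are intended to fit together after this change. The call site below is assumed for illustration and is not part of the diff; it mirrors how the tests drive the resolver, combined with the factory wiring from the first file:

// The factory captures client.threadPool(), so every resolver it creates can fork
// its final response onto the search coordination executor.
InferenceResolver resolver = InferenceResolver.factory(client).create();

resolver.resolveInferenceIds(
    List.of("rerank-plan"),                     // same inference ID the tests use
    ActionListener.wrap(
        resolution -> { /* delivered on a search_coordination thread */ },
        e -> { /* failures are not forked; they arrive on the client callback thread */ }
    )
);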