diff --git a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/CannedSourceOperator.java b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/CannedSourceOperator.java index 34ce21dad1030..9d168794c8cbb 100644 --- a/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/CannedSourceOperator.java +++ b/x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/CannedSourceOperator.java @@ -84,14 +84,20 @@ public static List deepCopyOf(BlockFactory blockFactory, List pages) try { for (Page p : pages) { Block[] blocks = new Block[p.getBlockCount()]; - for (int b = 0; b < blocks.length; b++) { - Block orig = p.getBlock(b); - try (Block.Builder builder = orig.elementType().newBlockBuilder(p.getPositionCount(), blockFactory)) { - builder.copyFrom(orig, 0, p.getPositionCount()); - blocks[b] = builder.build(); + try { + for (int b = 0; b < blocks.length; b++) { + Block orig = p.getBlock(b); + try (Block.Builder builder = orig.elementType().newBlockBuilder(p.getPositionCount(), blockFactory)) { + builder.copyFrom(orig, 0, p.getPositionCount()); + blocks[b] = builder.build(); + } } + out.add(new Page(blocks)); + } catch (Exception e) { + // Something went wrong, release the blocks. + Releasables.closeExpectNoException(blocks); + throw e; } - out.add(new Page(blocks)); } } finally { if (pages.size() != out.size()) { diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index ed15caa17ad3d..7f9f77509310e 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -12,7 +12,7 @@ import org.apache.lucene.sandbox.document.HalfFloatPoint; import org.apache.lucene.util.BytesRef; import org.elasticsearch.ExceptionsHelper; -import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.RemoteException; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.project.ProjectResolver; @@ -76,7 +76,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.inference.InferenceResolution; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceService; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.parser.QueryParam; import org.elasticsearch.xpack.esql.plan.logical.Enrich; @@ -161,8 +161,6 @@ import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; public final class EsqlTestUtils { @@ -422,20 +420,9 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() { mock(ProjectResolver.class), mock(IndexNameExpressionResolver.class), null, - mockInferenceRunner() + new InferenceService(mock(Client.class)) ); - @SuppressWarnings("unchecked") - private static InferenceRunner mockInferenceRunner() { - InferenceRunner inferenceRunner = mock(InferenceRunner.class); - doAnswer(i -> { - i.getArgument(1, ActionListener.class).onResponse(emptyInferenceResolution()); - return null; - }).when(inferenceRunner).resolveInferenceIds(any(), any()); - - return inferenceRunner; - } - private EsqlTestUtils() {} public static Configuration configuration(QueryPragmas pragmas, String query) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 72a68663e41b5..a1a13b72c6ab1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -188,9 +188,9 @@ public class Analyzer extends ParameterizedRuleExecutor( @@ -414,34 +414,6 @@ private static NamedExpression createEnrichFieldExpression( } } - private static class ResolveInference extends ParameterizedAnalyzerRule, AnalyzerContext> { - @Override - protected LogicalPlan rule(InferencePlan plan, AnalyzerContext context) { - assert plan.inferenceId().resolved() && plan.inferenceId().foldable(); - - String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small())); - ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId); - - if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) { - return plan; - } else if (resolvedInference != null) { - String error = "cannot use inference endpoint [" - + inferenceId - + "] with task type [" - + resolvedInference.taskType() - + "] within a " - + plan.nodeName() - + " command. Only inference endpoints with the task type [" - + plan.taskType() - + "] are supported."; - return plan.withInferenceResolutionError(inferenceId, error); - } else { - String error = context.inferenceResolution().getError(inferenceId); - return plan.withInferenceResolutionError(inferenceId, error); - } - } - } - private static class ResolveLookupTables extends ParameterizedAnalyzerRule { @Override @@ -1335,6 +1307,41 @@ public static org.elasticsearch.xpack.esql.core.expression.function.Function res } } + private static class ResolveInference extends ParameterizedRule { + + @Override + public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) { + return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context)); + } + + private LogicalPlan resolveInferencePlan(InferencePlan plan, AnalyzerContext context) { + assert plan.inferenceId().resolved() && plan.inferenceId().foldable(); + + String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small())); + ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId); + + if (resolvedInference == null) { + String error = context.inferenceResolution().getError(inferenceId); + return plan.withInferenceResolutionError(inferenceId, error); + } + + if (resolvedInference.taskType() != plan.taskType()) { + String error = "cannot use inference endpoint [" + + inferenceId + + "] with task type [" + + resolvedInference.taskType() + + "] within a " + + plan.nodeName() + + " command. Only inference endpoints with the task type [" + + plan.taskType() + + "] are supported."; + return plan.withInferenceResolutionError(inferenceId, error); + } + + return plan; + } + } + private static class AddImplicitLimit extends ParameterizedRule { @Override public LogicalPlan apply(LogicalPlan logicalPlan, AnalyzerContext context) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java index 5b9f41876d6e1..6b4570dbb2f6a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java @@ -13,7 +13,6 @@ import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation; -import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; import java.util.ArrayList; import java.util.HashSet; @@ -28,25 +27,17 @@ public class PreAnalyzer { public static class PreAnalysis { - public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptyList(), emptyList()); + public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptyList()); public final IndexMode indexMode; public final List indices; public final List enriches; - public final List> inferencePlans; public final List lookupIndices; - public PreAnalysis( - IndexMode indexMode, - List indices, - List enriches, - List> inferencePlans, - List lookupIndices - ) { + public PreAnalysis(IndexMode indexMode, List indices, List enriches, List lookupIndices) { this.indexMode = indexMode; this.indices = indices; this.enriches = enriches; - this.inferencePlans = inferencePlans; this.lookupIndices = lookupIndices; } } @@ -64,7 +55,7 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) { List unresolvedEnriches = new ArrayList<>(); List lookupIndices = new ArrayList<>(); - List> unresolvedInferencePlans = new ArrayList<>(); + Holder indexMode = new Holder<>(); plan.forEachUp(UnresolvedRelation.class, p -> { if (p.indexMode() == IndexMode.LOOKUP) { @@ -78,11 +69,11 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) { }); plan.forEachUp(Enrich.class, unresolvedEnriches::add); - plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add); // mark plan as preAnalyzed (if it were marked, there would be no analysis) plan.forEachUp(LogicalPlan::setPreAnalyzed); - return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices); + return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, lookupIndices); } + } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java index fe6ab6e9a998c..93085969415a6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java @@ -15,18 +15,16 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; -import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; import java.util.List; import static org.elasticsearch.common.logging.LoggerMessageFormat.format; /** - * An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceRunner}. + * An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceResolver}. *

* The {@code InferenceOperator} integrates with the compute framework supports throttled bulk execution of inference requests. It * transforms input {@link Page} into inference requests, asynchronously executes them, and converts the responses into a new {@link Page}. @@ -35,27 +33,25 @@ public abstract class InferenceOperator extends AsyncOperator { private final String inferenceId; private final BlockFactory blockFactory; - private final BulkInferenceExecutor bulkInferenceExecutor; + private final BulkInferenceRunner bulkInferenceRunner; /** * Constructs a new {@code InferenceOperator}. * * @param driverContext The driver context. - * @param inferenceRunner The runner used to execute inference requests. - * @param bulkExecutionConfig Configuration for inference execution. - * @param threadPool The thread pool used for executing async inference. + * @param bulkInferenceRunner Inference runner used to execute inference requests. * @param inferenceId The ID of the inference model to use. + * @param maxOutstandingPages The number of concurrent pages to process in parallel. */ public InferenceOperator( DriverContext driverContext, - InferenceRunner inferenceRunner, - BulkInferenceExecutionConfig bulkExecutionConfig, - ThreadPool threadPool, - String inferenceId + BulkInferenceRunner bulkInferenceRunner, + String inferenceId, + int maxOutstandingPages ) { - super(driverContext, inferenceRunner.threadPool().getThreadContext(), bulkExecutionConfig.workers()); + super(driverContext, bulkInferenceRunner.threadPool().getThreadContext(), maxOutstandingPages); this.blockFactory = driverContext.blockFactory(); - this.bulkInferenceExecutor = new BulkInferenceExecutor(inferenceRunner, threadPool, bulkExecutionConfig); + this.bulkInferenceRunner = bulkInferenceRunner; this.inferenceId = inferenceId; } @@ -81,7 +77,8 @@ protected void performAsync(Page input, ActionListener l try { BulkInferenceRequestIterator requests = requests(input); listener = ActionListener.releaseBefore(requests, listener); - bulkInferenceExecutor.execute(requests, listener.map(responses -> new OngoingInferenceResult(input, responses))); + + bulkInferenceRunner.executeBulk(requests, listener.map(responses -> new OngoingInferenceResult(input, responses))); } catch (Exception e) { listener.onFailure(e); } @@ -110,9 +107,9 @@ public Page getOutput() { outputBuilder.addInferenceResponse(response); } return outputBuilder.buildOutput(); - - } finally { + } catch (Exception e) { releaseFetchedOnAnyThread(ongoingInferenceResult); + throw e; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java new file mode 100644 index 0000000000000..f7d349281e004 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; + +/** + * Collects and resolves inference deployments inference IDs from ES|QL logical plans. + */ +public class InferenceResolver { + + private final Client client; + + /** + * Constructs a new {@code InferenceResolver}. + * + * @param client The Elasticsearch client for executing inference deployment lookups + */ + public InferenceResolver(Client client) { + this.client = client; + } + + /** + * Resolves inference IDs from the given logical plan. + *

+ * This method traverses the logical plan tree and identifies all inference operations, + * extracting their deployment IDs for subsequent validation. Currently, supports: + *

    + *
  • {@link InferencePlan} objects (Completion, etc.)
  • + *
+ * + * @param plan The logical plan to scan for inference operations + * @param listener Callback to receive the resolution results + */ + public void resolveInferenceIds(LogicalPlan plan, ActionListener listener) { + List inferenceIds = new ArrayList<>(); + collectInferenceIds(plan, inferenceIds::add); + resolveInferenceIds(inferenceIds, listener); + } + + /** + * Collects all inference IDs from the given logical plan. + *

+ * This method traverses the logical plan tree and identifies all inference operations, + * extracting their deployment IDs for subsequent validation. Currently, supports: + *

    + *
  • {@link InferencePlan} objects (Completion, etc.)
  • + *
+ * + * @param plan The logical plan to scan for inference operations + * @param c Consumer function to receive each discovered inference ID + */ + void collectInferenceIds(LogicalPlan plan, Consumer c) { + collectInferenceIdsFromInferencePlans(plan, c); + } + + /** + * Resolves a list of inference deployment IDs to their metadata. + *

+ * For each inference ID, this method: + *

    + *
  1. Queries the inference service to verify the deployment exists
  2. + *
  3. Retrieves the deployment's task type and configuration
  4. + *
  5. Builds an {@link InferenceResolution} containing resolved metadata or errors
  6. + *
+ * + * @param inferenceIds List of inference deployment IDs to resolve + * @param listener Callback to receive the resolution results + */ + void resolveInferenceIds(List inferenceIds, ActionListener listener) { + resolveInferenceIds(Set.copyOf(inferenceIds), listener); + } + + void resolveInferenceIds(Set inferenceIds, ActionListener listener) { + + if (inferenceIds.isEmpty()) { + listener.onResponse(InferenceResolution.EMPTY); + return; + } + + final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder(); + + final CountDownActionListener countdownListener = new CountDownActionListener( + inferenceIds.size(), + ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure) + ); + + for (var inferenceId : inferenceIds) { + client.execute( + GetInferenceModelAction.INSTANCE, + new GetInferenceModelAction.Request(inferenceId, TaskType.ANY), + ActionListener.wrap(r -> { + ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType()); + inferenceResolutionBuilder.withResolvedInference(resolvedInference); + countdownListener.onResponse(null); + }, e -> { + inferenceResolutionBuilder.withError(inferenceId, e.getMessage()); + countdownListener.onResponse(null); + }) + ); + } + } + + /** + * Collects inference IDs from InferencePlan objects within the logical plan. + * + * @param plan The logical plan to scan for InferencePlan objects + * @param c Consumer function to receive each discovered inference ID + */ + private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer c) { + plan.forEachUp(InferencePlan.class, inferencePlan -> c.accept(inferenceId(inferencePlan))); + } + + /** + * Extracts the inference ID from an InferencePlan object. + * + * @param plan The InferencePlan object to extract the ID from + * @return The inference ID as a string + */ + private static String inferenceId(InferencePlan plan) { + return inferenceId(plan.inferenceId()); + } + + private static String inferenceId(Expression e) { + return BytesRefs.toString(e.fold(FoldContext.small())); + } + + public static Factory factory(Client client) { + return new Factory(client); + } + + public static class Factory { + private final Client client; + + private Factory(Client client) { + this.client = client; + } + + public InferenceResolver create() { + return new InferenceResolver(client); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java deleted file mode 100644 index d67d6817742c0..0000000000000 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.inference; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.CountDownActionListener; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.lucene.BytesRefs; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; - -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; -import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; - -public class InferenceRunner { - - private final Client client; - private final ThreadPool threadPool; - - public InferenceRunner(Client client, ThreadPool threadPool) { - this.client = client; - this.threadPool = threadPool; - } - - public ThreadPool threadPool() { - return threadPool; - } - - public void resolveInferenceIds(List> plans, ActionListener listener) { - resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener); - - } - - private void resolveInferenceIds(Set inferenceIds, ActionListener listener) { - - if (inferenceIds.isEmpty()) { - listener.onResponse(InferenceResolution.EMPTY); - return; - } - - final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder(); - - final CountDownActionListener countdownListener = new CountDownActionListener( - inferenceIds.size(), - ActionListener.wrap(_r -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure) - ); - - for (var inferenceId : inferenceIds) { - client.execute( - GetInferenceModelAction.INSTANCE, - new GetInferenceModelAction.Request(inferenceId, TaskType.ANY), - ActionListener.wrap(r -> { - ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType()); - inferenceResolutionBuilder.withResolvedInference(resolvedInference); - countdownListener.onResponse(null); - }, e -> { - inferenceResolutionBuilder.withError(inferenceId, e.getMessage()); - countdownListener.onResponse(null); - }) - ); - } - } - - private static String planInferenceId(InferencePlan plan) { - return BytesRefs.toString(plan.inferenceId().fold(FoldContext.small())); - } - - public void doInference(InferenceAction.Request request, ActionListener listener) { - executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, request, listener); - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java new file mode 100644 index 0000000000000..37c163beaecda --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference; + +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig; + +public class InferenceService { + private final InferenceResolver.Factory inferenceResolverFactory; + + private final BulkInferenceRunner.Factory bulkInferenceRunnerFactory; + + /** + * Creates a new inference service with the given client. + * + * @param client the Elasticsearch client for inference operations + */ + public InferenceService(Client client) { + this(InferenceResolver.factory(client), BulkInferenceRunner.factory(client)); + } + + private InferenceService(InferenceResolver.Factory inferenceResolverFactory, BulkInferenceRunner.Factory bulkInferenceRunnerFactory) { + this.inferenceResolverFactory = inferenceResolverFactory; + this.bulkInferenceRunnerFactory = bulkInferenceRunnerFactory; + } + + /** + * Creates an inference resolver for resolving inference IDs in logical plans. + * + * @return a new inference resolver instance + */ + public InferenceResolver inferenceResolver() { + return inferenceResolverFactory.create(); + } + + public BulkInferenceRunner bulkInferenceRunner() { + return bulkInferenceRunner(BulkInferenceRunnerConfig.DEFAULT); + } + + public BulkInferenceRunner bulkInferenceRunner(BulkInferenceRunnerConfig bulkInferenceRunnerConfig) { + return bulkInferenceRunnerFactory.create(bulkInferenceRunnerConfig); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionConfig.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionConfig.java deleted file mode 100644 index 8bc48a908fe22..0000000000000 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionConfig.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.inference.bulk; - -public record BulkInferenceExecutionConfig(int workers, int maxOutstandingRequests) { - public static final int DEFAULT_WORKERS = 10; - public static final int DEFAULT_MAX_OUTSTANDING_REQUESTS = 50; - - public static final BulkInferenceExecutionConfig DEFAULT = new BulkInferenceExecutionConfig( - DEFAULT_WORKERS, - DEFAULT_MAX_OUTSTANDING_REQUESTS - ); -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java index 55f1f49f68c21..48303be617286 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutionState.java @@ -27,8 +27,8 @@ public class BulkInferenceExecutionState { private final Map bufferedResponses; private final AtomicBoolean finished = new AtomicBoolean(false); - public BulkInferenceExecutionState(int bufferSize) { - this.bufferedResponses = new ConcurrentHashMap<>(bufferSize); + public BulkInferenceExecutionState() { + this.bufferedResponses = new ConcurrentHashMap<>(); } /** @@ -125,7 +125,7 @@ public void addFailure(Exception e) { * Indicates whether the entire bulk execution is marked as finished and all responses have been successfully persisted. */ public boolean finished() { - return finished.get() && getMaxSeqNo() == getPersistedCheckpoint(); + return hasFailure() || (finished.get() && getMaxSeqNo() == getPersistedCheckpoint()); } /** diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java deleted file mode 100644 index 257799962dda7..0000000000000 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java +++ /dev/null @@ -1,260 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.inference.bulk; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.util.concurrent.AbstractRunnable; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; -import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; - -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Semaphore; -import java.util.concurrent.atomic.AtomicBoolean; - -/** - * Executes a sequence of inference requests in bulk with throttling and concurrency control. - */ -public class BulkInferenceExecutor { - private final ThrottledInferenceRunner throttledInferenceRunner; - private final BulkInferenceExecutionConfig bulkExecutionConfig; - - /** - * Constructs a new {@code BulkInferenceExecutor}. - * - * @param inferenceRunner The inference runner used to execute individual inference requests. - * @param threadPool The thread pool for executing inference tasks. - * @param bulkExecutionConfig Configuration options (throttling and concurrency limits). - */ - public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadPool, BulkInferenceExecutionConfig bulkExecutionConfig) { - this.throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService(threadPool), bulkExecutionConfig); - this.bulkExecutionConfig = bulkExecutionConfig; - } - - /** - * Executes the provided bulk inference requests. - *

- * Each request is sent to the {@link ThrottledInferenceRunner} to be executed. - * The final listener is notified with all successful responses once all requests are completed. - * - * @param requests An iterator over the inference requests to be executed. - * @param listener A listener notified with the complete list of responses or a failure. - */ - public void execute(BulkInferenceRequestIterator requests, ActionListener> listener) { - if (requests.hasNext() == false) { - listener.onResponse(List.of()); - return; - } - - final BulkInferenceExecutionState bulkExecutionState = new BulkInferenceExecutionState( - bulkExecutionConfig.maxOutstandingRequests() - ); - final ResponseHandler responseHandler = new ResponseHandler(bulkExecutionState, listener, requests.estimatedSize()); - - while (bulkExecutionState.finished() == false && requests.hasNext()) { - InferenceAction.Request request = requests.next(); - long seqNo = bulkExecutionState.generateSeqNo(); - - if (requests.hasNext() == false) { - bulkExecutionState.finish(); - } - - ActionListener inferenceResponseListener = ActionListener.runAfter( - ActionListener.wrap( - r -> bulkExecutionState.onInferenceResponse(seqNo, r), - e -> bulkExecutionState.onInferenceException(seqNo, e) - ), - responseHandler::persistPendingResponses - ); - - if (request == null) { - inferenceResponseListener.onResponse(null); - } else { - throttledInferenceRunner.doInference(request, inferenceResponseListener); - } - } - } - - /** - * Handles collection and delivery of inference responses once they are complete. - */ - private static class ResponseHandler { - private final List responses; - private final ActionListener> listener; - private final BulkInferenceExecutionState bulkExecutionState; - private final AtomicBoolean responseSent = new AtomicBoolean(false); - - private ResponseHandler( - BulkInferenceExecutionState bulkExecutionState, - ActionListener> listener, - int estimatedSize - ) { - this.listener = listener; - this.bulkExecutionState = bulkExecutionState; - this.responses = new ArrayList<>(estimatedSize); - } - - /** - * Persists all buffered responses that can be delivered in order, and sends the final response if all requests are finished. - */ - public synchronized void persistPendingResponses() { - long persistedSeqNo = bulkExecutionState.getPersistedCheckpoint(); - - while (persistedSeqNo < bulkExecutionState.getProcessedCheckpoint()) { - persistedSeqNo++; - if (bulkExecutionState.hasFailure() == false) { - try { - InferenceAction.Response response = bulkExecutionState.fetchBufferedResponse(persistedSeqNo); - responses.add(response); - } catch (Exception e) { - bulkExecutionState.addFailure(e); - } - } - bulkExecutionState.markSeqNoAsPersisted(persistedSeqNo); - } - - sendResponseOnCompletion(); - } - - /** - * Sends the final response or failure once all inference tasks have completed. - */ - private void sendResponseOnCompletion() { - if (bulkExecutionState.finished() && responseSent.compareAndSet(false, true)) { - if (bulkExecutionState.hasFailure() == false) { - try { - listener.onResponse(responses); - return; - } catch (Exception e) { - bulkExecutionState.addFailure(e); - } - } - - listener.onFailure(bulkExecutionState.getFailure()); - } - } - } - - /** - * Manages throttled inference tasks execution. - */ - private static class ThrottledInferenceRunner { - private final InferenceRunner inferenceRunner; - private final ExecutorService executorService; - private final BlockingQueue pendingRequestsQueue; - private final Semaphore permits; - - private ThrottledInferenceRunner(InferenceRunner inferenceRunner, ExecutorService executorService, int maxRunningTasks) { - this.executorService = executorService; - this.permits = new Semaphore(maxRunningTasks); - this.inferenceRunner = inferenceRunner; - this.pendingRequestsQueue = new ArrayBlockingQueue<>(maxRunningTasks); - } - - /** - * Creates a new {@code ThrottledInferenceRunner} with the specified configuration. - * - * @param inferenceRunner TThe inference runner used to execute individual inference requests. - * @param executorService The executor used for asynchronous execution. - * @param bulkExecutionConfig Configuration options (throttling and concurrency limits). - */ - public static ThrottledInferenceRunner create( - InferenceRunner inferenceRunner, - ExecutorService executorService, - BulkInferenceExecutionConfig bulkExecutionConfig - ) { - return new ThrottledInferenceRunner(inferenceRunner, executorService, bulkExecutionConfig.maxOutstandingRequests()); - } - - /** - * Schedules the inference task for execution. If a permit is available, the task runs immediately; otherwise, it is queued. - * - * @param request The inference request. - * @param listener The listener to notify on response or failure. - */ - public void doInference(InferenceAction.Request request, ActionListener listener) { - enqueueTask(request, listener); - executePendingRequests(); - } - - /** - * Attempts to execute as many pending inference tasks as possible, limited by available permits. - */ - private void executePendingRequests() { - while (permits.tryAcquire()) { - AbstractRunnable task = pendingRequestsQueue.poll(); - - if (task == null) { - permits.release(); - return; - } - - try { - executorService.execute(task); - } catch (Exception e) { - task.onFailure(e); - permits.release(); - } - } - } - - /** - * Add an inference task to the queue. - * - * @param request The inference request. - * * @param listener The listener to notify on response or failure. - */ - private void enqueueTask(InferenceAction.Request request, ActionListener listener) { - try { - pendingRequestsQueue.put(createTask(request, listener)); - } catch (Exception e) { - listener.onFailure(new IllegalStateException("An error occurred while adding the inference request to the queue", e)); - } - } - - /** - * Wraps an inference request into an {@link AbstractRunnable} that releases its permit on completion and triggers any remaining - * queued tasks. - * - * @param request The inference request. - * @param listener The listener to notify on completion. - * @return A runnable task encapsulating the request. - */ - private AbstractRunnable createTask(InferenceAction.Request request, ActionListener listener) { - final ActionListener completionListener = ActionListener.runAfter(listener, () -> { - permits.release(); - executePendingRequests(); - }); - - return new AbstractRunnable() { - @Override - protected void doRun() { - try { - inferenceRunner.doInference(request, completionListener); - } catch (Throwable e) { - listener.onFailure(new RuntimeException("Unexpected failure while running inference", e)); - } - } - - @Override - public void onFailure(Exception e) { - completionListener.onFailure(e); - } - }; - } - } - - private static ExecutorService executorService(ThreadPool threadPool) { - return threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME); - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java new file mode 100644 index 0000000000000..203a3031bcad4 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java @@ -0,0 +1,394 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.bulk; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; + +import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME; + +/** + * Implementation of bulk inference execution with throttling and concurrency control. + *

+ * This runner limits the number of concurrent inference requests using a semaphore-based + * permit system. When all permits are exhausted, additional requests are queued and + * executed as permits become available. + *

+ *

+ * Response processing is always executed in the ESQL worker thread pool to ensure + * consistent thread context and avoid thread safety issues with circuit breakers + * and other non-thread-safe components. + *

+ */ +public class BulkInferenceRunner { + + private final Client client; + private final Semaphore permits; + private final ExecutorService executor; + + /** + * Custom concurrent queue that prevents duplicate bulk requests from being queued. + *

+ * This queue implementation ensures fairness among multiple concurrent bulk operations + * by preventing the same bulk request from being queued multiple times. It uses a + * backing concurrent set to track which requests are already queued. + *

+ */ + private final Queue pendingBulkRequests = new ConcurrentLinkedQueue<>() { + private final Set requests = ConcurrentCollections.newConcurrentSet(); + + @Override + public boolean offer(BulkInferenceRequest bulkInferenceRequest) { + synchronized (requests) { + if (requests.add(bulkInferenceRequest)) { + return super.offer(bulkInferenceRequest); + } + return false; // Already exists, don't add duplicate + } + } + + @Override + public BulkInferenceRequest poll() { + synchronized (requests) { + BulkInferenceRequest request = super.poll(); + if (request != null) { + requests.remove(request); + } + return request; + } + } + }; + + /** + * Constructs a new throttled inference runner with the specified configuration. + * + * @param client Client for executing inference requests + * @param maxRunningTasks The maximum number of concurrent inference requests allowed + */ + public BulkInferenceRunner(Client client, int maxRunningTasks) { + this.permits = new Semaphore(maxRunningTasks); + this.client = client; + this.executor = client.threadPool().executor(ESQL_WORKER_THREAD_POOL_NAME); + } + + /** + * Executes multiple inference requests in bulk and collects all responses. + * + * @param requests An iterator over the inference requests to execute + * @param listener Called with the list of all responses in request order + */ + public void executeBulk(BulkInferenceRequestIterator requests, ActionListener> listener) { + List responses = new ArrayList<>(); + executeBulk(requests, responses::add, ActionListener.wrap(ignored -> listener.onResponse(responses), listener::onFailure)); + } + + /** + * Executes multiple inference requests in bulk with streaming response handling. + *

+ * This method orchestrates the entire bulk inference process: + * 1. Creates execution state to track progress and responses + * 2. Sets up response handling pipeline + * 3. Initiates asynchronous request processing + *

+ * + * @param requests An iterator over the inference requests to execute + * @param responseConsumer Called for each successful inference response as they complete + * @param completionListener Called when all requests are complete or if any error occurs + */ + public void executeBulk( + BulkInferenceRequestIterator requests, + Consumer responseConsumer, + ActionListener completionListener + ) { + if (requests.hasNext() == false) { + completionListener.onResponse(null); + return; + } + + new BulkInferenceRequest(requests, responseConsumer, completionListener).executePendingRequests(); + } + + /** + * Returns the thread pool used for executing inference requests. + */ + public ThreadPool threadPool() { + return client.threadPool(); + } + + /** + * Encapsulates the execution state and logic for a single bulk inference operation. + *

+ * This inner class manages the complete lifecycle of a bulk inference request, including: + * - Request iteration and permit-based concurrency control + * - Asynchronous execution with hybrid recursion strategy + * - Response collection and ordering via execution state + * - Error handling and completion notification + *

+ *

+ * Each BulkInferenceRequest instance represents one bulk operation that may contain + * multiple individual inference requests. Multiple BulkInferenceRequest instances + * can execute concurrently, with fairness ensured through the pending queue mechanism. + *

+ */ + private class BulkInferenceRequest { + private final BulkInferenceRequestIterator requests; + private final Consumer responseConsumer; + private final ActionListener completionListener; + + private final BulkInferenceExecutionState executionState = new BulkInferenceExecutionState(); + private final AtomicBoolean responseSent = new AtomicBoolean(false); + + BulkInferenceRequest( + BulkInferenceRequestIterator requests, + Consumer responseConsumer, + ActionListener completionListener + ) { + this.requests = requests; + this.responseConsumer = responseConsumer; + this.completionListener = completionListener; + } + + /** + * Attempts to poll the next request from the iterator and acquire a permit for execution. + *

+ * Because multiple threads may call this concurrently via async callbacks, this method is synchronized to ensure thread-safe access + * to the request iterator. + *

+ * + * @return A BulkRequestItem if a request and permit are available, null otherwise + */ + private BulkRequestItem pollPendingRequest() { + synchronized (requests) { + if (requests.hasNext()) { + return new BulkRequestItem(executionState.generateSeqNo(), requests.next()); + } + } + + return null; + } + + /** + * Main execution loop that processes inference requests asynchronously with hybrid recursion strategy. + *

+ * This method implements a continuation-based asynchronous pattern with the following features: + * - Queue-based fairness: Multiple bulk requests can be queued and processed fairly + * - Permit-based concurrency control: Limits concurrent inference requests using semaphores + * - Hybrid recursion strategy: Uses direct recursion for performance up to 100 levels, + * then switches to executor-based continuation to prevent stack overflow + * - Duplicate prevention: Custom queue prevents the same bulk request from being queued multiple times + *

+ *

+ * Execution flow: + * 1. Attempts to acquire a permit for concurrent execution + * 2. If no permit available, queues this bulk request for later execution + * 3. Polls for the next available request from the iterator + * 4. If no requests available, schedules the next queued bulk request + * 5. Executes the request asynchronously with proper continuation handling + * 6. Uses hybrid recursion: direct calls up to 100 levels, executor-based beyond that + *

+ *

+ * The loop terminates when: + * - No more requests are available and no permits can be acquired + * - The bulk execution is marked as finished (due to completion or failure) + * - An unrecoverable error occurs during processing + *

+ */ + private void executePendingRequests() { + executePendingRequests(0); + } + + private void executePendingRequests(int recursionDepth) { + try { + while (executionState.finished() == false) { + if (permits.tryAcquire() == false) { + if (requests.hasNext()) { + pendingBulkRequests.add(this); + } + return; + } else { + BulkRequestItem bulkRequestItem = pollPendingRequest(); + + if (bulkRequestItem == null) { + // No more requests available + // Release the permit we didn't used and stop processing + permits.release(); + + // Check if another bulk request is pending for execution. + BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll(); + + while (nexBulkRequest == this) { + nexBulkRequest = pendingBulkRequests.poll(); + } + + if (nexBulkRequest != null) { + executor.execute(nexBulkRequest::executePendingRequests); + } + + return; + } + + if (requests.hasNext() == false) { + // This is the last request - mark bulk execution as finished + // to prevent further processing attempts + executionState.finish(); + } + + final ActionListener inferenceResponseListener = ActionListener.runAfter( + ActionListener.wrap( + r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r), + e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e) + ), + () -> { + // Release the permit we used + permits.release(); + + try { + synchronized (executionState) { + persistPendingResponses(); + } + + if (executionState.finished() && responseSent.compareAndSet(false, true)) { + onBulkCompletion(); + } + + if (responseSent.get()) { + // Response has already been sent + // No need to continue processing this bulk. + // Check if another bulk request is pending for execution. + BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll(); + if (nexBulkRequest != null) { + executor.execute(nexBulkRequest::executePendingRequests); + } + return; + } + if (executionState.finished() == false) { + // Execute any pending requests if any + if (recursionDepth > 100) { + executor.execute(this::executePendingRequests); + } else { + this.executePendingRequests(recursionDepth + 1); + } + } + } catch (Exception e) { + if (responseSent.compareAndSet(false, true)) { + completionListener.onFailure(e); + } + } + } + ); + + // Handle null requests (edge case in some iterators) + if (bulkRequestItem.request() == null) { + inferenceResponseListener.onResponse(null); + return; + } + + // Execute the inference request with proper origin context + executeAsyncWithOrigin( + client, + INFERENCE_ORIGIN, + InferenceAction.INSTANCE, + bulkRequestItem.request(), + inferenceResponseListener + ); + } + } + } catch (Exception e) { + executionState.addFailure(e); + } + } + + /** + * Processes and delivers buffered responses in order, ensuring proper sequencing. + *

+ * This method is synchronized to ensure thread-safe access to the execution state + * and prevent concurrent response processing which could cause ordering issues. + * Processing stops immediately if a failure is detected to implement fail-fast behavior. + *

+ */ + private void persistPendingResponses() { + long persistedSeqNo = executionState.getPersistedCheckpoint(); + + while (persistedSeqNo < executionState.getProcessedCheckpoint()) { + persistedSeqNo++; + if (executionState.hasFailure() == false) { + try { + InferenceAction.Response response = executionState.fetchBufferedResponse(persistedSeqNo); + responseConsumer.accept(response); + } catch (Exception e) { + executionState.addFailure(e); + } + } + executionState.markSeqNoAsPersisted(persistedSeqNo); + } + } + + /** + * Call the completion listener when all requests have completed. + */ + private void onBulkCompletion() { + if (executionState.hasFailure() == false) { + try { + completionListener.onResponse(null); + return; + } catch (Exception e) { + executionState.addFailure(e); + } + } + + completionListener.onFailure(executionState.getFailure()); + } + } + + /** + * Encapsulates an inference request with its associated sequence number. + *

+ * The sequence number is used for ordering responses and tracking completion + * in the bulk execution state. + *

+ * + * @param seqNo Unique sequence number for this request in the bulk operation + * @param request The actual inference request to execute + */ + private record BulkRequestItem(long seqNo, InferenceAction.Request request) { + + } + + public static Factory factory(Client client) { + return inferenceRunnerConfig -> new BulkInferenceRunner(client, inferenceRunnerConfig.maxOutstandingBulkRequests()); + } + + /** + * Factory interface for creating {@link BulkInferenceRunner} instances. + */ + @FunctionalInterface + public interface Factory { + /** + * Creates a new inference runner with the specified execution configuration. + * + * @param bulkInferenceRunnerConfig Configuration defining concurrency limits and execution parameters + * @return A configured inference runner implementation + */ + BulkInferenceRunner create(BulkInferenceRunnerConfig bulkInferenceRunnerConfig); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerConfig.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerConfig.java new file mode 100644 index 0000000000000..4d4b5ed0575af --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerConfig.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.bulk; + +/** + * Configuration record for inference execution parameters. + *

+ * This record defines the concurrency and resource limits for inference operations, + * including the number of worker threads and the maximum number of outstanding requests + * that can be queued or processed simultaneously. + *

+ * + * @param maxOutstandingBulkRequests The maximum number of concurrent bulk inference requests allowed + * @param maxOutstandingRequests The maximum number of concurrent inference requests allowed + */ +public record BulkInferenceRunnerConfig(int maxOutstandingRequests, int maxOutstandingBulkRequests) { + + /** + * Default number of worker threads for inference execution. + */ + public static final int DEFAULT_MAX_OUTSTANDING_BULK_REQUESTS = 10; + + /** Default maximum number of outstanding inference requests. */ + public static final int DEFAULT_MAX_OUTSTANDING_REQUESTS = 50; + + /** + * Default configuration instance using standard values for most use cases. + */ + public static final BulkInferenceRunnerConfig DEFAULT = new BulkInferenceRunnerConfig( + DEFAULT_MAX_OUTSTANDING_REQUESTS, + DEFAULT_MAX_OUTSTANDING_BULK_REQUESTS + ); + + public BulkInferenceRunnerConfig { + if (maxOutstandingRequests <= 0) throw new IllegalArgumentException("maxOutstandingRequests must be positive"); + if (maxOutstandingBulkRequests <= 0) throw new IllegalArgumentException("maxOutstandingBulkRequests must be positive"); + + if (maxOutstandingBulkRequests > maxOutstandingRequests) { + maxOutstandingBulkRequests = maxOutstandingRequests; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java index e53fda90c88b3..65b560f3cf9ce 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java @@ -13,13 +13,11 @@ import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.Releasables; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.esql.inference.InferenceOperator; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; -import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; +import org.elasticsearch.xpack.esql.inference.InferenceService; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; - -import java.util.stream.IntStream; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig; /** * {@link CompletionOperator} is an {@link InferenceOperator} that performs inference using prompt-based model (e.g., text completion). @@ -31,12 +29,12 @@ public class CompletionOperator extends InferenceOperator { public CompletionOperator( DriverContext driverContext, - InferenceRunner inferenceRunner, - ThreadPool threadPool, + BulkInferenceRunner bulkInferenceRunner, String inferenceId, - ExpressionEvaluator promptEvaluator + ExpressionEvaluator promptEvaluator, + int maxOutstandingPages ) { - super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId); + super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages); this.promptEvaluator = promptEvaluator; } @@ -50,16 +48,6 @@ public String toString() { return "CompletionOperator[inference_id=[" + inferenceId() + "]]"; } - @Override - public void addInput(Page input) { - try { - super.addInput(input.appendBlock(promptEvaluator.eval(input))); - } catch (Exception e) { - releasePageOnAnyThread(input); - throw e; - } - } - /** * Constructs the completion inference requests iterator for the given input page by evaluating the prompt expression. * @@ -67,8 +55,7 @@ public void addInput(Page input) { */ @Override protected BulkInferenceRequestIterator requests(Page inputPage) { - int inputBlockChannel = inputPage.getBlockCount() - 1; - return new CompletionOperatorRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId()); + return new CompletionOperatorRequestIterator((BytesRefBlock) promptEvaluator.eval(inputPage), inferenceId()); } /** @@ -79,16 +66,13 @@ protected BulkInferenceRequestIterator requests(Page inputPage) { @Override protected CompletionOperatorOutputBuilder outputBuilder(Page input) { BytesRefBlock.Builder outputBlockBuilder = blockFactory().newBytesRefBlockBuilder(input.getPositionCount()); - return new CompletionOperatorOutputBuilder( - outputBlockBuilder, - input.projectBlocks(IntStream.range(0, input.getBlockCount() - 1).toArray()) - ); + return new CompletionOperatorOutputBuilder(outputBlockBuilder, input); } /** * Factory for creating {@link CompletionOperator} instances. */ - public record Factory(InferenceRunner inferenceRunner, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory) + public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory) implements OperatorFactory { @Override @@ -100,10 +84,10 @@ public String describe() { public Operator get(DriverContext driverContext) { return new CompletionOperator( driverContext, - inferenceRunner, - inferenceRunner.threadPool(), + inferenceService.bulkInferenceRunner(), inferenceId, - promptEvaluatorFactory.get(driverContext) + promptEvaluatorFactory.get(driverContext), + BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests() ); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java index cfb587c6451d8..3e9106f9a1cf6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java @@ -33,7 +33,6 @@ public CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder, @Override public void close() { Releasables.close(outputBlockBuilder); - releasePageOnAnyThread(inputPage); } /** @@ -72,13 +71,13 @@ public void addInferenceResponse(InferenceAction.Response inferenceResponse) { } /** - * Builds the final output page by appending the completion output block to a shallow copy of the input page. + * Builds the final output page by appending the completion output block to the input page. */ @Override public Page buildOutput() { Block outputBlock = outputBlockBuilder.build(); assert outputBlock.getPositionCount() == inputPage.getPositionCount(); - return inputPage.shallowCopy().appendBlock(outputBlock); + return inputPage.appendBlock(outputBlock); } private ChatCompletionResults inferenceResults(InferenceAction.Response inferenceResponse) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java index 6893130425edf..f8ae6da7a35a7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java @@ -119,7 +119,8 @@ public int estimatedSize() { @Override public void close() { - + promptBlock.allowPassingToDifferentDriver(); + Releasables.close(promptBlock); } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java index ca628fdba8a8f..404f60b6c7142 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java @@ -7,19 +7,17 @@ package org.elasticsearch.xpack.esql.inference.rerank; -import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.core.Releasables; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.esql.inference.InferenceOperator; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; -import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig; - -import java.util.stream.IntStream; +import org.elasticsearch.xpack.esql.inference.InferenceService; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig; /** * {@link RerankOperator} is an inference operator that compute scores for rows using a reranking model. @@ -40,30 +38,19 @@ public class RerankOperator extends InferenceOperator { public RerankOperator( DriverContext driverContext, - InferenceRunner inferenceRunner, - ThreadPool threadPool, + BulkInferenceRunner bulkInferenceRunner, String inferenceId, String queryText, ExpressionEvaluator rowEncoder, - int scoreChannel + int scoreChannel, + int maxOutstandingPages ) { - super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId); + super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages); this.queryText = queryText; this.rowEncoder = rowEncoder; this.scoreChannel = scoreChannel; } - @Override - public void addInput(Page input) { - try { - Block inputBlock = rowEncoder.eval(input); - super.addInput(input.appendBlock(inputBlock)); - } catch (Exception e) { - releasePageOnAnyThread(input); - throw e; - } - } - @Override protected void doClose() { Releasables.close(rowEncoder); @@ -79,8 +66,7 @@ public String toString() { */ @Override protected RerankOperatorRequestIterator requests(Page inputPage) { - int inputBlockChannel = inputPage.getBlockCount() - 1; - return new RerankOperatorRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId(), queryText, batchSize); + return new RerankOperatorRequestIterator((BytesRefBlock) rowEncoder.eval(inputPage), inferenceId(), queryText, batchSize); } /** @@ -89,18 +75,14 @@ protected RerankOperatorRequestIterator requests(Page inputPage) { @Override protected RerankOperatorOutputBuilder outputBuilder(Page input) { DoubleBlock.Builder outputBlockBuilder = blockFactory().newDoubleBlockBuilder(input.getPositionCount()); - return new RerankOperatorOutputBuilder( - outputBlockBuilder, - input.projectBlocks(IntStream.range(0, input.getBlockCount() - 1).toArray()), - scoreChannel - ); + return new RerankOperatorOutputBuilder(outputBlockBuilder, input, scoreChannel); } /** * Factory for creating {@link RerankOperator} instances */ public record Factory( - InferenceRunner inferenceRunner, + InferenceService inferenceService, String inferenceId, String queryText, ExpressionEvaluator.Factory rowEncoderFactory, @@ -116,12 +98,12 @@ public String describe() { public Operator get(DriverContext driverContext) { return new RerankOperator( driverContext, - inferenceRunner, - inferenceRunner.threadPool(), + inferenceService.bulkInferenceRunner(), inferenceId, queryText, - rowEncoderFactory().get(driverContext), - scoreChannel + rowEncoderFactory.get(driverContext), + scoreChannel, + BulkInferenceRunnerConfig.DEFAULT.maxOutstandingRequests() ); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java index 1813aa3e9fb59..188986a29bc0a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.inference.rerank; -import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasables; @@ -18,6 +17,7 @@ import java.util.Comparator; import java.util.Iterator; +import java.util.stream.IntStream; /** * Builds the output page for the {@link RerankOperator} by adding @@ -39,7 +39,6 @@ public RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page i @Override public void close() { Releasables.close(scoreBlockBuilder); - releasePageOnAnyThread(inputPage); } /** @@ -48,22 +47,24 @@ public void close() { */ @Override public Page buildOutput() { - int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1); - Block[] blocks = new Block[blockCount]; + Page outputPage = inputPage.appendBlock(scoreBlockBuilder.build()); + + if (scoreChannel == inputPage.getBlockCount()) { + // Just need to append the block at the end + // We can just return the output page we have just created + return outputPage; + } try { - for (int b = 0; b < blockCount; b++) { - if (b == scoreChannel) { - blocks[b] = scoreBlockBuilder.build(); - } else { - blocks[b] = inputPage.getBlock(b); - blocks[b].incRef(); - } - } - return new Page(blocks); - } catch (Exception e) { - Releasables.close(blocks); - throw (e); + // We need to project the last column to the score channel. + int[] blockNapping = IntStream.range(0, inputPage.getBlockCount()) + .map(channel -> channel == scoreChannel ? inputPage.getBlockCount() : channel) + .toArray(); + + return outputPage.projectBlocks(blockNapping); + } finally { + // Releasing the output page since projection is incrementing block references. + releasePageOnAnyThread(outputPage); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java index 3e73bcc8bea1f..81911d5d089df 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.core.Releasables; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; @@ -80,6 +81,7 @@ private InferenceAction.Request inferenceRequest(List inputs) { @Override public void close() { - + inputBlock.allowPassingToDifferentDriver(); + Releasables.close(inputBlock); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 28204e2572842..fb7b6ccab5d5e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -86,7 +86,7 @@ import org.elasticsearch.xpack.esql.evaluator.EvalMapper; import org.elasticsearch.xpack.esql.evaluator.command.GrokEvaluatorExtracter; import org.elasticsearch.xpack.esql.expression.Order; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceService; import org.elasticsearch.xpack.esql.inference.XContentRowEncoder; import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator; import org.elasticsearch.xpack.esql.inference.rerank.RerankOperator; @@ -159,7 +159,7 @@ public class LocalExecutionPlanner { private final Supplier exchangeSinkSupplier; private final EnrichLookupService enrichLookupService; private final LookupFromIndexService lookupFromIndexService; - private final InferenceRunner inferenceRunner; + private final InferenceService inferenceService; private final PhysicalOperationProviders physicalOperationProviders; private final List shardContexts; @@ -175,7 +175,7 @@ public LocalExecutionPlanner( Supplier exchangeSinkSupplier, EnrichLookupService enrichLookupService, LookupFromIndexService lookupFromIndexService, - InferenceRunner inferenceRunner, + InferenceService inferenceService, PhysicalOperationProviders physicalOperationProviders, List shardContexts ) { @@ -191,7 +191,7 @@ public LocalExecutionPlanner( this.exchangeSinkSupplier = exchangeSinkSupplier; this.enrichLookupService = enrichLookupService; this.lookupFromIndexService = lookupFromIndexService; - this.inferenceRunner = inferenceRunner; + this.inferenceService = inferenceService; this.physicalOperationProviders = physicalOperationProviders; this.shardContexts = shardContexts; } @@ -318,7 +318,7 @@ private PhysicalOperation planCompletion(CompletionExec completion, LocalExecuti source.layout ); - return source.with(new CompletionOperator.Factory(inferenceRunner, inferenceId, promptEvaluatorFactory), outputLayout); + return source.with(new CompletionOperator.Factory(inferenceService, inferenceId, promptEvaluatorFactory), outputLayout); } private PhysicalOperation planRrfScoreEvalExec(RrfScoreEvalExec rrf, LocalExecutionPlannerContext context) { @@ -654,7 +654,7 @@ private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerCon int scoreChannel = outputLayout.get(rerank.scoreAttribute().id()).channel(); return source.with( - new RerankOperator.Factory(inferenceRunner, inferenceId, queryText, rowEncoderFactory, scoreChannel), + new RerankOperator.Factory(inferenceService, inferenceId, queryText, rowEncoderFactory, scoreChannel), outputLayout ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index ae107d63bd51d..2cb7943231938 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -53,7 +53,7 @@ import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.enrich.EnrichLookupService; import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceService; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec; import org.elasticsearch.xpack.esql.plan.physical.OutputExec; @@ -132,7 +132,7 @@ public class ComputeService { private final DriverTaskRunner driverRunner; private final EnrichLookupService enrichLookupService; private final LookupFromIndexService lookupFromIndexService; - private final InferenceRunner inferenceRunner; + private final InferenceService inferenceService; private final ClusterService clusterService; private final ProjectResolver projectResolver; private final AtomicLong childSessionIdGenerator = new AtomicLong(); @@ -159,7 +159,7 @@ public ComputeService( this.driverRunner = new DriverTaskRunner(transportService, esqlExecutor); this.enrichLookupService = enrichLookupService; this.lookupFromIndexService = lookupFromIndexService; - this.inferenceRunner = transportActionServices.inferenceRunner(); + this.inferenceService = transportActionServices.inferenceService(); this.clusterService = transportActionServices.clusterService(); this.projectResolver = transportActionServices.projectResolver(); this.dataNodeComputeHandler = new DataNodeComputeHandler( @@ -630,7 +630,7 @@ public SourceProvider createSourceProvider() { context.exchangeSinkSupplier(), enrichLookupService, lookupFromIndexService, - inferenceRunner, + inferenceService, physicalOperationProviders, contexts ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java index ccabe09fd466c..0255fad60811e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java @@ -14,7 +14,7 @@ import org.elasticsearch.search.SearchService; import org.elasticsearch.transport.TransportService; import org.elasticsearch.usage.UsageService; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceService; public record TransportActionServices( TransportService transportService, @@ -24,5 +24,5 @@ public record TransportActionServices( ProjectResolver projectResolver, IndexNameExpressionResolver indexNameExpressionResolver, UsageService usageService, - InferenceRunner inferenceRunner + InferenceService inferenceService ) {} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index 4be7b31bb96c0..a23154c218a61 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -51,7 +51,7 @@ import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService; import org.elasticsearch.xpack.esql.execution.PlanExecutor; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceService; import org.elasticsearch.xpack.esql.session.Configuration; import org.elasticsearch.xpack.esql.session.EsqlSession.PlanRunner; import org.elasticsearch.xpack.esql.session.Result; @@ -166,7 +166,7 @@ public TransportEsqlQueryAction( projectResolver, indexNameExpressionResolver, usageService, - new InferenceRunner(client, threadPool) + new InferenceService(client) ); this.computeService = new ComputeService( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index 63b7173f54432..34ece5661c8d7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -56,7 +56,7 @@ import org.elasticsearch.xpack.esql.index.IndexResolution; import org.elasticsearch.xpack.esql.index.MappingException; import org.elasticsearch.xpack.esql.inference.InferenceResolution; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceService; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanPreOptimizer; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext; @@ -120,7 +120,7 @@ public interface PlanRunner { private final PhysicalPlanOptimizer physicalPlanOptimizer; private final PlanTelemetry planTelemetry; private final IndicesExpressionGrouper indicesExpressionGrouper; - private final InferenceRunner inferenceRunner; + private final InferenceService inferenceService; private final RemoteClusterService remoteClusterService; private boolean explainMode; @@ -155,7 +155,7 @@ public EsqlSession( this.physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration)); this.planTelemetry = planTelemetry; this.indicesExpressionGrouper = indicesExpressionGrouper; - this.inferenceRunner = services.inferenceRunner(); + this.inferenceService = services.inferenceService(); this.preMapper = new PreMapper(services); this.remoteClusterService = services.transportService().getRemoteClusterService(); } @@ -397,12 +397,7 @@ public void analyzedPlan( l -> enrichPolicyResolver.resolvePolicies(unresolvedPolicies, executionInfo, l) ) .andThenApply(enrichResolution -> FieldNameUtils.resolveFieldNames(parsed, enrichResolution)) - .andThen( - (l, preAnalysisResult) -> inferenceRunner.resolveInferenceIds( - preAnalysis.inferencePlans, - l.map(preAnalysisResult::withInferenceResolution) - ) - ); + .andThen((l, preAnalysisResult) -> resolveInferences(parsed, preAnalysisResult, l)); // first resolve the lookup indices, then the main indices for (var index : preAnalysis.lookupIndices) { listener = listener.andThen((l, preAnalysisResult) -> preAnalyzeLookupIndex(index, preAnalysisResult, executionInfo, l)); @@ -755,6 +750,10 @@ private static void analyzeAndMaybeRetry( logicalPlanListener.onResponse(plan); } + private void resolveInferences(LogicalPlan plan, PreAnalysisResult preAnalysisResult, ActionListener l) { + inferenceService.inferenceResolver().resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution)); + } + private PhysicalPlan logicalPlanToPhysicalPlan(LogicalPlan optimizedPlan, EsqlQueryRequest request) { PhysicalPlan physicalPlan = optimizedPhysicalPlan(optimizedPlan); physicalPlan = physicalPlan.transformUp(FragmentExec.class, f -> { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index b38b3089823d5..2c3a9a2fe9679 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -67,7 +67,7 @@ import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.IndexResolution; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceService; import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer; import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; @@ -100,7 +100,6 @@ import org.elasticsearch.xpack.esql.telemetry.PlanTelemetry; import org.junit.After; import org.junit.Before; -import org.mockito.Mockito; import java.io.IOException; import java.net.URL; @@ -133,6 +132,7 @@ import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Mockito.mock; /** * CSV-based unit testing. @@ -705,9 +705,9 @@ void executeSubPlan( configuration, exchangeSource::createExchangeSource, () -> exchangeSink.createExchangeSink(() -> {}), - Mockito.mock(EnrichLookupService.class), - Mockito.mock(LookupFromIndexService.class), - Mockito.mock(InferenceRunner.class), + mock(EnrichLookupService.class), + mock(LookupFromIndexService.class), + mock(InferenceService.class), physicalOperationProviders, List.of() ); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java index c49e301968aa0..e72eecccf5ab8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java @@ -18,7 +18,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.DoubleBlock; @@ -32,6 +31,7 @@ import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.test.AbstractBlockSourceOperator; import org.elasticsearch.compute.test.OperatorTestCase; +import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.client.NoOpClient; @@ -51,7 +51,8 @@ import static org.hamcrest.Matchers.notNullValue; public abstract class InferenceOperatorTestCase extends OperatorTestCase { - private ThreadPool threadPool; + protected ThreadPool threadPool; + protected int inputsCount; @Before public void setThreadPool() { @@ -68,6 +69,11 @@ public void setThreadPool() { ); } + @Before + public void initChannels() { + inputsCount = randomIntBetween(1, 10); + } + @After public void shutdownThreadPool() { terminate(threadPool); @@ -88,18 +94,28 @@ protected int remaining() { @Override protected Page createPage(int positionOffset, int length) { length = Integer.min(length, remaining()); - try (var builder = blockFactory.newBytesRefBlockBuilder(length)) { - for (int i = 0; i < length; i++) { - if (randomInt() % 100 == 0) { - builder.appendNull(); - } else { - builder.appendBytesRef(new BytesRef(randomAlphaOfLength(10))); + Block[] blocks = new Block[inputsCount]; + try { + for (int b = 0; b < inputsCount; b++) { + try (var builder = blockFactory.newBytesRefBlockBuilder(length)) { + for (int i = 0; i < length; i++) { + if (randomInt() % 100 == 0) { + builder.appendNull(); + } else { + builder.appendBytesRef(new BytesRef(randomAlphaOfLength(10))); + } + } + blocks[b] = builder.build(); } - } - currentPosition += length; - return new Page(builder.build()); + } catch (Exception e) { + Releasables.closeExpectNoException(blocks); + throw e; } + + currentPosition += length; + return new Page(blocks); + } }; } @@ -118,33 +134,26 @@ public void testOperatorStatus() { } @SuppressWarnings("unchecked") - protected InferenceRunner mockedSimpleInferenceRunner() { - Client client = new NoOpClient(threadPool) { + protected InferenceService mockedInferenceService() { + Client mockClient = new NoOpClient(threadPool) { @Override protected void doExecute( ActionType action, Request request, ActionListener listener ) { - Runnable runnable = () -> { - if (action == InferenceAction.INSTANCE && request instanceof InferenceAction.Request inferenceRequest) { - InferenceAction.Response inferenceResponse = new InferenceAction.Response(mockInferenceResult(inferenceRequest)); - listener.onResponse((Response) inferenceResponse); + runWithRandomDelay(() -> { + if (action instanceof InferenceAction && request instanceof InferenceAction.Request inferenceRequest) { + listener.onResponse((Response) new InferenceAction.Response(mockInferenceResult(inferenceRequest))); return; } - fail("Unexpected call to action [" + action.name() + "]"); - }; - - if (randomBoolean()) { - runnable.run(); - } else { - threadPool.schedule(runnable, TimeValue.timeValueNanos(between(1, 100)), threadPool.executor(ThreadPool.Names.SEARCH)); - } + listener.onFailure(new UnsupportedOperationException("Unexpected action: " + action)); + }); } }; - return new InferenceRunner(client, threadPool); + return new InferenceService(mockClient); } protected abstract InferenceResultsType mockInferenceResult(InferenceAction.Request request); @@ -201,7 +210,9 @@ protected EvalOperator.ExpressionEvaluator.Factory evaluatorFactory(int channel) return context -> new EvalOperator.ExpressionEvaluator() { @Override public Block eval(Page page) { - return BlockUtils.deepCopyOf(page.getBlock(channel), blockFactory()); + Block b = page.getBlock(channel); + b.incRef(); + return b; } @Override @@ -210,4 +221,12 @@ public void close() { } }; } + + private void runWithRandomDelay(Runnable runnable) { + if (randomBoolean()) { + runnable.run(); + } else { + threadPool.schedule(runnable, TimeValue.timeValueNanos(between(1, 1_000)), threadPool.generic()); + } + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java similarity index 71% rename from x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java index ef7b3984bd532..8666eedbaeaaa 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java @@ -22,16 +22,18 @@ import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; +import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.junit.After; import org.junit.Before; +import java.util.HashSet; import java.util.List; +import java.util.Set; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.any; @@ -40,7 +42,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class InferenceRunnerTests extends ESTestCase { +public class InferenceResolverTests extends ESTestCase { private TestThreadPool threadPool; @Before @@ -63,12 +65,43 @@ public void shutdownThreadPool() { terminate(threadPool); } + public void testCollectInferenceIds() { + // Rerank inference plan + assertCollectInferenceIds( + "FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH { \"inference_id\": \"rerank-inference-id\" }", + List.of("rerank-inference-id") + ); + + // Completion inference plan + assertCollectInferenceIds( + "FROM books METADATA _score | COMPLETION \"italian food recipe\" WITH { \"inference_id\": \"completion-inference-id\" }", + List.of("completion-inference-id") + ); + + // Multiple inference plans + assertCollectInferenceIds(""" + FROM books METADATA _score + | RERANK "italian food recipe" ON title WITH { "inference_id": "rerank-inference-id" } + | COMPLETION "italian food recipe" WITH { "inference_id": "completion-inference-id" } + """, List.of("rerank-inference-id", "completion-inference-id")); + + // No inference operations + assertCollectInferenceIds("FROM books | WHERE title:\"test\"", List.of()); + } + + private void assertCollectInferenceIds(String query, List expectedInferenceIds) { + Set inferenceIds = new HashSet<>(); + InferenceResolver inferenceResolver = inferenceResolver(); + inferenceResolver.collectInferenceIds(new EsqlParser().createStatement(query, configuration(query)), inferenceIds::add); + assertThat(inferenceIds, containsInAnyOrder(expectedInferenceIds.toArray(new String[0]))); + } + public void testResolveInferenceIds() throws Exception { - InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); - List> inferencePlans = List.of(mockInferencePlan("rerank-plan")); + InferenceResolver inferenceResolver = inferenceResolver(); + List inferenceIds = List.of("rerank-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { throw new RuntimeException(e); })); @@ -81,15 +114,11 @@ public void testResolveInferenceIds() throws Exception { } public void testResolveMultipleInferenceIds() throws Exception { - InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); - List> inferencePlans = List.of( - mockInferencePlan("rerank-plan"), - mockInferencePlan("rerank-plan"), - mockInferencePlan("completion-plan") - ); + InferenceResolver inferenceResolver = inferenceResolver(); + List inferenceIds = List.of("rerank-plan", "rerank-plan", "completion-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { throw new RuntimeException(e); })); @@ -109,12 +138,12 @@ public void testResolveMultipleInferenceIds() throws Exception { } public void testResolveMissingInferenceIds() throws Exception { - InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool); - List> inferencePlans = List.of(mockInferencePlan("missing-plan")); + InferenceResolver inferenceResolver = inferenceResolver(); + List inferenceIds = List.of("missing-plan"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); - inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { + inferenceResolver.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> { throw new RuntimeException(e); })); @@ -175,13 +204,11 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction. return null; } - private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) { - return new ModelConfigurations(inferenceId, taskType, randomIdentifier(), mock(ServiceSettings.class)); + private InferenceResolver inferenceResolver() { + return new InferenceResolver(mockClient()); } - private static InferencePlan mockInferencePlan(String inferenceId) { - InferencePlan plan = mock(InferencePlan.class); - when(plan.inferenceId()).thenReturn(Literal.keyword(Source.EMPTY, inferenceId)); - return plan; + private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) { + return new ModelConfigurations(inferenceId, taskType, randomIdentifier(), mock(ServiceSettings.class)); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java similarity index 60% rename from x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java index 7e44c681c6fc4..dedbf895860b9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunnerTests.java @@ -8,16 +8,17 @@ package org.elasticsearch.xpack.esql.inference.bulk; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.junit.After; import org.junit.Before; @@ -26,6 +27,8 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.allOf; @@ -33,11 +36,12 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.notNullValue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class BulkInferenceExecutorTests extends ESTestCase { +public class BulkInferenceRunnerTests extends ESTestCase { private ThreadPool threadPool; @Before @@ -60,53 +64,53 @@ public void shutdownThreadPool() { terminate(threadPool); } - public void testSuccessfulExecution() throws Exception { - List requests = randomInferenceRequestList(between(1, 1000)); + public void testSuccessfulBulkExecution() throws Exception { + List requests = randomInferenceRequestList(between(1, 1_000)); List responses = randomInferenceResponseList(requests.size()); - InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> { + Client client = mockClient(invocation -> { runWithRandomDelay(() -> { - ActionListener l = invocation.getArgument(1); - l.onResponse(responses.get(requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)))); + ActionListener l = invocation.getArgument(2); + l.onResponse(responses.get(requests.indexOf(invocation.getArgument(1, InferenceAction.Request.class)))); }); return null; }); AtomicReference> output = new AtomicReference<>(); - ActionListener> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception")); + ActionListener> listener = ActionListener.wrap(output::set, ESTestCase::fail); - bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); + inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), equalTo(responses)))); } - public void testSuccessfulExecutionOnEmptyRequest() throws Exception { + public void testSuccessfulBulkExecutionOnEmptyRequest() throws Exception { BulkInferenceRequestIterator requestIterator = mock(BulkInferenceRequestIterator.class); when(requestIterator.hasNext()).thenReturn(false); AtomicReference> output = new AtomicReference<>(); - ActionListener> listener = ActionListener.wrap(output::set, r -> fail("Unexpected exception")); + ActionListener> listener = ActionListener.wrap(output::set, ESTestCase::fail); - bulkExecutor(mock(InferenceRunner.class)).execute(requestIterator, listener); + inferenceRunnerFactory(new NoOpClient(threadPool)).create(randomBulkExecutionConfig()).executeBulk(requestIterator, listener); assertBusy(() -> assertThat(output.get(), allOf(notNullValue(), empty()))); } - public void testInferenceRunnerAlwaysFails() throws Exception { + public void testBulkExecutionWhenInferenceRunnerAlwaysFails() throws Exception { List requests = randomInferenceRequestList(between(1, 1000)); - InferenceRunner inferenceRunner = mock(invocation -> { + Client client = mockClient(invocation -> { runWithRandomDelay(() -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("inference failure")); }); return null; }); AtomicReference exception = new AtomicReference<>(); - ActionListener> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set); + ActionListener> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set); - bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); + inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); assertBusy(() -> { assertThat(exception.get(), notNullValue()); @@ -114,13 +118,13 @@ public void testInferenceRunnerAlwaysFails() throws Exception { }); } - public void testInferenceRunnerSometimesFails() throws Exception { - List requests = randomInferenceRequestList(between(1, 1000)); + public void testBulkExecutionWhenInferenceRunnerSometimesFails() throws Exception { + List requests = randomInferenceRequestList(between(1, 1_000)); - InferenceRunner inferenceRunner = mockInferenceRunner(invocation -> { - ActionListener listener = invocation.getArgument(1); + Client client = mockClient(invocation -> { + ActionListener listener = invocation.getArgument(2); runWithRandomDelay(() -> { - if ((requests.indexOf(invocation.getArgument(0, InferenceAction.Request.class)) % requests.size()) == 0) { + if ((requests.indexOf(invocation.getArgument(1, InferenceAction.Request.class)) % requests.size()) == 0) { listener.onFailure(new RuntimeException("inference failure")); } else { listener.onResponse(mockInferenceResponse()); @@ -131,9 +135,9 @@ public void testInferenceRunnerSometimesFails() throws Exception { }); AtomicReference exception = new AtomicReference<>(); - ActionListener> listener = ActionListener.wrap(r -> fail("Expceted exception"), exception::set); + ActionListener> listener = ActionListener.wrap(r -> fail("Expected an exception"), exception::set); - bulkExecutor(inferenceRunner).execute(requestIterator(requests), listener); + inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); assertBusy(() -> { assertThat(exception.get(), notNullValue()); @@ -141,8 +145,37 @@ public void testInferenceRunnerSometimesFails() throws Exception { }); } - private BulkInferenceExecutor bulkExecutor(InferenceRunner inferenceRunner) { - return new BulkInferenceExecutor(inferenceRunner, threadPool, randomBulkExecutionConfig()); + public void testParallelBulkExecution() throws Exception { + int batches = between(50, 100); + CountDownLatch latch = new CountDownLatch(batches); + + for (int i = 0; i < batches; i++) { + runWithRandomDelay(() -> { + List requests = randomInferenceRequestList(between(1, 1_000)); + List responses = randomInferenceResponseList(requests.size()); + + Client client = mockClient(invocation -> { + runWithRandomDelay(() -> { + ActionListener l = invocation.getArgument(2); + l.onResponse(responses.get(requests.indexOf(invocation.getArgument(1, InferenceAction.Request.class)))); + }); + return null; + }); + + ActionListener> listener = ActionListener.wrap(r -> { + assertThat(r, equalTo(responses)); + latch.countDown(); + }, ESTestCase::fail); + + inferenceRunnerFactory(client).create(randomBulkExecutionConfig()).executeBulk(requestIterator(requests), listener); + }); + } + + latch.await(10, TimeUnit.SECONDS); + } + + private BulkInferenceRunner.Factory inferenceRunnerFactory(Client client) { + return BulkInferenceRunner.factory(client); } private InferenceAction.Request mockInferenceRequest() { @@ -155,8 +188,8 @@ private InferenceAction.Response mockInferenceResponse() { return response; } - private BulkInferenceExecutionConfig randomBulkExecutionConfig() { - return new BulkInferenceExecutionConfig(between(1, 100), between(1, 100)); + private BulkInferenceRunnerConfig randomBulkExecutionConfig() { + return new BulkInferenceRunnerConfig(between(1, 100), between(1, 100)); } private BulkInferenceRequestIterator requestIterator(List requests) { @@ -171,7 +204,7 @@ private BulkInferenceRequestIterator requestIterator(List randomInferenceRequestList(int size) { List requests = new ArrayList<>(size); while (requests.size() < size) { - requests.add(this.mockInferenceRequest()); + requests.add(mockInferenceRequest()); } return requests; @@ -180,26 +213,23 @@ private List randomInferenceRequestList(int size) { private List randomInferenceResponseList(int size) { List response = new ArrayList<>(size); while (response.size() < size) { - response.add(mock(InferenceAction.Response.class)); + response.add(mockInferenceResponse()); } return response; } - private InferenceRunner mockInferenceRunner(Answer doInferenceAnswer) { - InferenceRunner inferenceRunner = mock(InferenceRunner.class); - doAnswer(doInferenceAnswer).when(inferenceRunner).doInference(any(), any()); - return inferenceRunner; + private Client mockClient(Answer doInferenceAnswer) { + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + doAnswer(doInferenceAnswer).when(client).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + return client; } private void runWithRandomDelay(Runnable runnable) { if (randomBoolean()) { runnable.run(); } else { - threadPool.schedule( - runnable, - TimeValue.timeValueNanos(between(1, 1_000)), - threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME) - ); + threadPool.schedule(runnable, TimeValue.timeValueNanos(between(1, 100_000)), threadPool.generic()); } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilderTests.java index 77d50ca5ee981..09aabe5fef6e9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilderTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilderTests.java @@ -23,15 +23,15 @@ public class CompletionOperatorOutputBuilderTests extends ComputeTestCase { - public void testBuildSmallOutput() { + public void testBuildSmallOutput() throws Exception { assertBuildOutput(between(1, 100)); } - public void testBuildLargeOutput() { + public void testBuildLargeOutput() throws Exception { assertBuildOutput(between(10_000, 100_000)); } - private void assertBuildOutput(int size) { + private void assertBuildOutput(int size) throws Exception { final Page inputPage = randomInputPage(size, between(1, 20)); try ( CompletionOperatorOutputBuilder outputBuilder = new CompletionOperatorOutputBuilder( @@ -50,11 +50,9 @@ private void assertBuildOutput(int size) { assertOutputContent(outputPage.getBlock(outputPage.getBlockCount() - 1)); outputPage.releaseBlocks(); - - } finally { - inputPage.releaseBlocks(); } + allBreakersEmpty(); } private void assertOutputContent(BytesRefBlock block) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java index 5c0253e508553..86592256d26bc 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIteratorTests.java @@ -16,21 +16,19 @@ public class CompletionOperatorRequestIteratorTests extends ComputeTestCase { - public void testIterateSmallInput() { + public void testIterateSmallInput() throws Exception { assertIterate(between(1, 100)); } - public void testIterateLargeInput() { + public void testIterateLargeInput() throws Exception { assertIterate(between(10_000, 100_000)); } - private void assertIterate(int size) { + private void assertIterate(int size) throws Exception { final String inferenceId = randomIdentifier(); + final BytesRefBlock inputBlock = randomInputBlock(size); - try ( - BytesRefBlock inputBlock = randomInputBlock(size); - CompletionOperatorRequestIterator requestIterator = new CompletionOperatorRequestIterator(inputBlock, inferenceId) - ) { + try (CompletionOperatorRequestIterator requestIterator = new CompletionOperatorRequestIterator(inputBlock, inferenceId)) { BytesRef scratch = new BytesRef(); for (int currentPos = 0; requestIterator.hasNext(); currentPos++) { @@ -40,6 +38,8 @@ private void assertIterate(int size) { assertThat(request.getInput().getFirst(), equalTo(scratch.utf8ToString())); } } + + allBreakersEmpty(); } private BytesRefBlock randomInputBlock(int size) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java index add8155240ad1..b40f58c9724b5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase; import org.hamcrest.Matcher; +import org.junit.Before; import java.util.ArrayList; import java.util.List; @@ -27,9 +28,16 @@ public class CompletionOperatorTests extends InferenceOperatorTestCase { private static final String SIMPLE_INFERENCE_ID = "test_completion"; + private int inputChannel; + + @Before + public void initCompletionChannels() { + inputChannel = between(0, inputsCount - 1); + } + @Override protected Operator.OperatorFactory simple(SimpleOptions options) { - return new CompletionOperator.Factory(mockedSimpleInferenceRunner(), SIMPLE_INFERENCE_ID, evaluatorFactory(0)); + return new CompletionOperator.Factory(mockedInferenceService(), SIMPLE_INFERENCE_ID, evaluatorFactory(inputChannel)); } @Override @@ -54,7 +62,7 @@ protected void assertSimpleOutput(List input, List results) { } private void assertCompletionResults(Page inputPage, Page resultPage) { - BytesRefBlock inputBlock = resultPage.getBlock(0); + BytesRefBlock inputBlock = resultPage.getBlock(inputChannel); BytesRefBlock resultBlock = resultPage.getBlock(inputPage.getBlockCount()); BytesRef scratch = new BytesRef(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilderTests.java index 7117ccc19005e..bc204728ff96d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilderTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilderTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.compute.test.ComputeTestCase; import org.elasticsearch.compute.test.RandomBlock; import org.elasticsearch.core.Releasables; -import org.elasticsearch.logging.LogManager; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; @@ -24,15 +23,15 @@ public class RerankOperatorOutputBuilderTests extends ComputeTestCase { - public void testBuildSmallOutput() { + public void testBuildSmallOutput() throws Exception { assertBuildOutput(between(1, 100)); } - public void testBuildLargeOutput() { + public void testBuildLargeOutput() throws Exception { assertBuildOutput(between(10_000, 100_000)); } - private void assertBuildOutput(int size) { + private void assertBuildOutput(int size) throws Exception { final Page inputPage = randomInputPage(size, between(1, 20)); final int scoreChannel = randomIntBetween(0, inputPage.getBlockCount()); try ( @@ -56,23 +55,15 @@ private void assertBuildOutput(int size) { final Page outputPage = outputBuilder.buildOutput(); try { assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount())); - LogManager.getLogger(RerankOperatorOutputBuilderTests.class) - .info( - "{} , {}, {}, {}", - scoreChannel, - inputPage.getBlockCount(), - outputPage.getBlockCount(), - Math.max(scoreChannel + 1, inputPage.getBlockCount()) - ); assertThat(outputPage.getBlockCount(), equalTo(Integer.max(scoreChannel + 1, inputPage.getBlockCount()))); assertOutputContent(outputPage.getBlock(scoreChannel)); } finally { outputPage.releaseBlocks(); } - } finally { - inputPage.releaseBlocks(); } + + allBreakersEmpty(); } private float relevanceScore(int position) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java index 133bfeaaf02ad..72397efcf1be3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIteratorTests.java @@ -18,20 +18,20 @@ public class RerankOperatorRequestIteratorTests extends ComputeTestCase { - public void testIterateSmallInput() { + public void testIterateSmallInput() throws Exception { assertIterate(between(1, 100), randomIntBetween(1, 1_000)); } - public void testIterateLargeInput() { + public void testIterateLargeInput() throws Exception { assertIterate(between(10_000, 100_000), randomIntBetween(1, 1_000)); } - private void assertIterate(int size, int batchSize) { + private void assertIterate(int size, int batchSize) throws Exception { final String inferenceId = randomIdentifier(); final String queryText = randomIdentifier(); + final BytesRefBlock inputBlock = randomInputBlock(size); try ( - BytesRefBlock inputBlock = randomInputBlock(size); RerankOperatorRequestIterator requestIterator = new RerankOperatorRequestIterator(inputBlock, inferenceId, queryText, batchSize) ) { BytesRef scratch = new BytesRef(); @@ -49,6 +49,8 @@ private void assertIterate(int size, int batchSize) { } } } + + allBreakersEmpty(); } private BytesRefBlock randomInputBlock(int size) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorTests.java index f5dc1b3c05fd9..d10540dda7c32 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase; import org.hamcrest.Matcher; +import org.junit.Before; import java.util.ArrayList; import java.util.List; @@ -27,10 +28,27 @@ public class RerankOperatorTests extends InferenceOperatorTestCase inputPages, List resultPages) Page resultPage = resultPages.get(pageId); assertThat(resultPage.getPositionCount(), equalTo(inputPage.getPositionCount())); - assertThat(resultPage.getBlockCount(), equalTo(Integer.max(2, inputPage.getBlockCount()))); + assertThat(resultPage.getBlockCount(), equalTo(Integer.max(scoreChannel + 1, inputPage.getBlockCount()))); for (int channel = 0; channel < inputPage.getBlockCount(); channel++) { - Block inputBlock = inputPage.getBlock(channel); Block resultBlock = resultPage.getBlock(channel); - - assertThat(resultBlock.getPositionCount(), equalTo(resultPage.getPositionCount())); - assertThat(resultBlock.elementType(), equalTo(inputBlock.elementType())); - - if (channel != 1) { + if (channel == scoreChannel) { + assertExpectedScore(inputPage.getBlock(inputChannel), (DoubleBlock) resultBlock); + } else { + Block inputBlock = inputPage.getBlock(channel); + assertThat(resultBlock.getPositionCount(), equalTo(resultPage.getPositionCount())); + assertThat(resultBlock.elementType(), equalTo(inputBlock.elementType())); assertBlockContentEquals(inputBlock, resultBlock); } - - if (channel == 0) { - assertExpectedScore((BytesRefBlock) inputBlock, resultPage.getBlock(1)); - } } } } @@ -79,7 +93,7 @@ protected Matcher expectedDescriptionOfSimple() { @Override protected Matcher expectedToStringOfSimple() { return equalTo( - "RerankOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "], query=[" + SIMPLE_QUERY + "], score_channel=[" + 1 + "]]" + "RerankOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "], query=[" + SIMPLE_QUERY + "], score_channel=[" + scoreChannel + "]]" ); }