Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,20 @@ public static List<Page> deepCopyOf(BlockFactory blockFactory, List<Page> pages)
try {
for (Page p : pages) {
Block[] blocks = new Block[p.getBlockCount()];
for (int b = 0; b < blocks.length; b++) {
Block orig = p.getBlock(b);
try (Block.Builder builder = orig.elementType().newBlockBuilder(p.getPositionCount(), blockFactory)) {
builder.copyFrom(orig, 0, p.getPositionCount());
blocks[b] = builder.build();
try {
for (int b = 0; b < blocks.length; b++) {
Block orig = p.getBlock(b);
try (Block.Builder builder = orig.elementType().newBlockBuilder(p.getPositionCount(), blockFactory)) {
builder.copyFrom(orig, 0, p.getPositionCount());
blocks[b] = builder.build();
}
}
out.add(new Page(blocks));
} catch (Exception e) {
// Something went wrong, release the blocks.
Releasables.closeExpectNoException(blocks);
throw e;
}
out.add(new Page(blocks));
}
} finally {
if (pages.size() != out.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.apache.lucene.sandbox.document.HalfFloatPoint;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.RemoteException;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.project.ProjectResolver;
Expand Down Expand Up @@ -76,7 +76,7 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.inference.InferenceService;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.parser.QueryParam;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
Expand Down Expand Up @@ -161,8 +161,6 @@
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

public final class EsqlTestUtils {
Expand Down Expand Up @@ -422,20 +420,9 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() {
mock(ProjectResolver.class),
mock(IndexNameExpressionResolver.class),
null,
mockInferenceRunner()
new InferenceService(mock(Client.class))
);

/**
 * Creates a Mockito mock of {@link InferenceRunner} for tests that do not exercise inference.
 * <p>
 * The only stubbed call is {@code resolveInferenceIds}: whatever arguments it receives, the listener passed as the
 * second argument is completed with {@code emptyInferenceResolution()}.
 *
 * @return the stubbed mock
 */
@SuppressWarnings("unchecked") // onResponse is invoked through the raw ActionListener type
private static InferenceRunner mockInferenceRunner() {
    InferenceRunner runner = mock(InferenceRunner.class);

    doAnswer(invocation -> {
        // Argument index 1 is the listener; answer it right away with an empty resolution.
        invocation.getArgument(1, ActionListener.class).onResponse(emptyInferenceResolution());
        return null;
    }).when(runner).resolveInferenceIds(any(), any());

    return runner;
}

private EsqlTestUtils() {}

public static Configuration configuration(QueryPragmas pragmas, String query) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
Limiter.ONCE,
new ResolveTable(),
new ResolveEnrich(),
new ResolveInference(),
new ResolveLookupTables(),
new ResolveFunctions(),
new ResolveInference(),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ℹ️ Inference resolution is moved after function resolution, so we can resolve inference IDs used in functions like TEXT_EMBEDDING.

new DateMillisToNanosInEsRelation(IMPLICIT_CASTING_DATE_AND_DATE_NANOS.isEnabled())
),
new Batch<>(
Expand Down Expand Up @@ -414,34 +414,6 @@ private static NamedExpression createEnrichFieldExpression(
}
}

private static class ResolveInference extends ParameterizedAnalyzerRule<InferencePlan<?>, AnalyzerContext> {
Copy link
Contributor

@ioanatia ioanatia Jul 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any difference between this implementation and the new one you added?
I don't think there is - they both do the same thing, even if this one overrides the rule method and the new one overrides the apply one.

So this diff here looks like it was entirely unnecessary.
If we needed to make a change to ResolveInference, we could have made it here, not move it entirely later in the file. This forces the reviewers to check each line to figure out what actually changed.

Maybe the one thing that we needed in Analyzer.java was the order in which we apply rules. But even that is debatable, because it's not necessary yet, we could have switched the order when we actually add the text_embedding function. 🤷‍♀️

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no major difference, but the new method is easier to extend since I can now chain more transforms.

return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));

will become the following when implementing the resolution of InferenceFunction:

return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
    .transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));

The whole point of this PR is to anticipate changes that are required in the existing framework and to isolate these changes, so we can focus on testing for potential regressions.

BTW, I am 100% sure that we will have to change the order of the resolution and that I will use this change to finish the implementation of TEXT_EMBEDDING.

@Override
protected LogicalPlan rule(InferencePlan<?> plan, AnalyzerContext context) {
assert plan.inferenceId().resolved() && plan.inferenceId().foldable();

String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);

if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) {
return plan;
} else if (resolvedInference != null) {
String error = "cannot use inference endpoint ["
+ inferenceId
+ "] with task type ["
+ resolvedInference.taskType()
+ "] within a "
+ plan.nodeName()
+ " command. Only inference endpoints with the task type ["
+ plan.taskType()
+ "] are supported.";
return plan.withInferenceResolutionError(inferenceId, error);
} else {
String error = context.inferenceResolution().getError(inferenceId);
return plan.withInferenceResolutionError(inferenceId, error);
}
}
}

private static class ResolveLookupTables extends ParameterizedAnalyzerRule<Lookup, AnalyzerContext> {

@Override
Expand Down Expand Up @@ -1335,6 +1307,41 @@ public static org.elasticsearch.xpack.esql.core.expression.function.Function res
}
}

/**
 * Analyzer rule that checks the inference endpoint referenced by each {@link InferencePlan} found in the plan.
 * <p>
 * The inference id of every inference plan is folded to a string and looked up in the context's inference resolution.
 * A plan whose endpoint was resolved with the task type the plan expects is kept as is; any other plan is replaced by
 * the result of {@code withInferenceResolutionError}, carrying either the resolution error reported by the context or
 * a task type mismatch message.
 */
private static class ResolveInference extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {

    @Override
    public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
        return plan.transformDown(InferencePlan.class, inferencePlan -> resolveInferencePlan(inferencePlan, context));
    }

    /**
     * Validates a single inference plan against the inference resolution held by the context.
     *
     * @param plan    the inference plan to check; its inference id is asserted to be resolved and foldable
     * @param context the analyzer context providing the inference resolution
     * @return {@code plan} itself when the endpoint is resolved with a matching task type, otherwise the plan
     *         returned by {@code withInferenceResolutionError}
     */
    private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {
        assert plan.inferenceId().resolved() && plan.inferenceId().foldable();

        String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
        ResolvedInference resolved = context.inferenceResolution().getResolvedInference(inferenceId);

        // Happy path: the endpoint exists and serves the task type this plan needs.
        if (resolved != null && resolved.taskType() == plan.taskType()) {
            return plan;
        }

        // Unknown endpoint: surface the error recorded during resolution.
        // Known endpoint with the wrong task type: explain the mismatch.
        String error = resolved == null
            ? context.inferenceResolution().getError(inferenceId)
            : "cannot use inference endpoint ["
                + inferenceId
                + "] with task type ["
                + resolved.taskType()
                + "] within a "
                + plan.nodeName()
                + " command. Only inference endpoints with the task type ["
                + plan.taskType()
                + "] are supported.";

        return plan.withInferenceResolutionError(inferenceId, error);
    }
}

private static class AddImplicitLimit extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {
@Override
public LogicalPlan apply(LogicalPlan logicalPlan, AnalyzerContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;

import java.util.ArrayList;
import java.util.HashSet;
Expand All @@ -28,25 +27,17 @@
public class PreAnalyzer {

public static class PreAnalysis {
public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptyList(), emptyList());
public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptyList());

public final IndexMode indexMode;
public final List<IndexPattern> indices;
public final List<Enrich> enriches;
public final List<InferencePlan<?>> inferencePlans;
public final List<IndexPattern> lookupIndices;

public PreAnalysis(
IndexMode indexMode,
List<IndexPattern> indices,
List<Enrich> enriches,
List<InferencePlan<?>> inferencePlans,
List<IndexPattern> lookupIndices
) {
public PreAnalysis(IndexMode indexMode, List<IndexPattern> indices, List<Enrich> enriches, List<IndexPattern> lookupIndices) {
this.indexMode = indexMode;
this.indices = indices;
this.enriches = enriches;
this.inferencePlans = inferencePlans;
this.lookupIndices = lookupIndices;
}
}
Expand All @@ -64,7 +55,7 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {

List<Enrich> unresolvedEnriches = new ArrayList<>();
List<IndexPattern> lookupIndices = new ArrayList<>();
List<InferencePlan<?>> unresolvedInferencePlans = new ArrayList<>();

Holder<IndexMode> indexMode = new Holder<>();
plan.forEachUp(UnresolvedRelation.class, p -> {
if (p.indexMode() == IndexMode.LOOKUP) {
Expand All @@ -78,11 +69,11 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
});

plan.forEachUp(Enrich.class, unresolvedEnriches::add);
plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add);

// mark plan as preAnalyzed (if it were marked, there would be no analysis)
plan.forEachUp(LogicalPlan::setPreAnalyzed);

return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices);
return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, lookupIndices);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,16 @@
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;

import java.util.List;

import static org.elasticsearch.common.logging.LoggerMessageFormat.format;

/**
* An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceRunner}.
* An abstract asynchronous operator that performs throttled bulk inference execution using a {@link BulkInferenceRunner}.
* <p>
* The {@code InferenceOperator} integrates with the compute framework and supports throttled bulk execution of inference requests. It
* transforms input {@link Page} into inference requests, asynchronously executes them, and converts the responses into a new {@link Page}.
Expand All @@ -35,27 +33,25 @@
public abstract class InferenceOperator extends AsyncOperator<InferenceOperator.OngoingInferenceResult> {
private final String inferenceId;
private final BlockFactory blockFactory;
private final BulkInferenceExecutor bulkInferenceExecutor;
private final BulkInferenceRunner bulkInferenceRunner;

/**
* Constructs a new {@code InferenceOperator}.
*
* @param driverContext The driver context.
* @param inferenceRunner The runner used to execute inference requests.
* @param bulkExecutionConfig Configuration for inference execution.
* @param threadPool The thread pool used for executing async inference.
* @param bulkInferenceRunner Inference runner used to execute inference requests.
* @param inferenceId The ID of the inference model to use.
* @param maxOutstandingPages The number of concurrent pages to process in parallel.
*/
public InferenceOperator(
DriverContext driverContext,
InferenceRunner inferenceRunner,
BulkInferenceExecutionConfig bulkExecutionConfig,
ThreadPool threadPool,
String inferenceId
BulkInferenceRunner bulkInferenceRunner,
String inferenceId,
int maxOutstandingPages
) {
super(driverContext, inferenceRunner.threadPool().getThreadContext(), bulkExecutionConfig.workers());
super(driverContext, bulkInferenceRunner.threadPool().getThreadContext(), maxOutstandingPages);
this.blockFactory = driverContext.blockFactory();
this.bulkInferenceExecutor = new BulkInferenceExecutor(inferenceRunner, threadPool, bulkExecutionConfig);
this.bulkInferenceRunner = bulkInferenceRunner;
this.inferenceId = inferenceId;
}

Expand All @@ -81,7 +77,8 @@ protected void performAsync(Page input, ActionListener<OngoingInferenceResult> l
try {
BulkInferenceRequestIterator requests = requests(input);
listener = ActionListener.releaseBefore(requests, listener);
bulkInferenceExecutor.execute(requests, listener.map(responses -> new OngoingInferenceResult(input, responses)));

bulkInferenceRunner.executeBulk(requests, listener.map(responses -> new OngoingInferenceResult(input, responses)));
} catch (Exception e) {
listener.onFailure(e);
}
Expand Down Expand Up @@ -110,9 +107,9 @@ public Page getOutput() {
outputBuilder.addInferenceResponse(response);
}
return outputBuilder.buildOutput();

} finally {
} catch (Exception e) {
releaseFetchedOnAnyThread(ongoingInferenceResult);
throw e;
}
}

Expand Down
Loading