Skip to content

Commit dd3e3c9

Browse files
authored
ES|QL Inference runner refactoring (#131986)
1 parent 55f6078 commit dd3e3c9

34 files changed

+1028
-693
lines changed

x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/CannedSourceOperator.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,20 @@ public static List<Page> deepCopyOf(BlockFactory blockFactory, List<Page> pages)
8484
try {
8585
for (Page p : pages) {
8686
Block[] blocks = new Block[p.getBlockCount()];
87-
for (int b = 0; b < blocks.length; b++) {
88-
Block orig = p.getBlock(b);
89-
try (Block.Builder builder = orig.elementType().newBlockBuilder(p.getPositionCount(), blockFactory)) {
90-
builder.copyFrom(orig, 0, p.getPositionCount());
91-
blocks[b] = builder.build();
87+
try {
88+
for (int b = 0; b < blocks.length; b++) {
89+
Block orig = p.getBlock(b);
90+
try (Block.Builder builder = orig.elementType().newBlockBuilder(p.getPositionCount(), blockFactory)) {
91+
builder.copyFrom(orig, 0, p.getPositionCount());
92+
blocks[b] = builder.build();
93+
}
9294
}
95+
out.add(new Page(blocks));
96+
} catch (Exception e) {
97+
// Something went wrong, release the blocks.
98+
Releasables.closeExpectNoException(blocks);
99+
throw e;
93100
}
94-
out.add(new Page(blocks));
95101
}
96102
} finally {
97103
if (pages.size() != out.size()) {

x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import org.apache.lucene.sandbox.document.HalfFloatPoint;
1313
import org.apache.lucene.util.BytesRef;
1414
import org.elasticsearch.ExceptionsHelper;
15-
import org.elasticsearch.action.ActionListener;
15+
import org.elasticsearch.client.internal.Client;
1616
import org.elasticsearch.cluster.RemoteException;
1717
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
1818
import org.elasticsearch.cluster.project.ProjectResolver;
@@ -76,7 +76,7 @@
7676
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
7777
import org.elasticsearch.xpack.esql.index.EsIndex;
7878
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
79-
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
79+
import org.elasticsearch.xpack.esql.inference.InferenceService;
8080
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
8181
import org.elasticsearch.xpack.esql.parser.QueryParam;
8282
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@@ -161,8 +161,6 @@
161161
import static org.hamcrest.Matchers.instanceOf;
162162
import static org.junit.Assert.assertNotNull;
163163
import static org.junit.Assert.assertNull;
164-
import static org.mockito.ArgumentMatchers.any;
165-
import static org.mockito.Mockito.doAnswer;
166164
import static org.mockito.Mockito.mock;
167165

168166
public final class EsqlTestUtils {
@@ -422,20 +420,9 @@ public static LogicalOptimizerContext unboundLogicalOptimizerContext() {
422420
mock(ProjectResolver.class),
423421
mock(IndexNameExpressionResolver.class),
424422
null,
425-
mockInferenceRunner()
423+
new InferenceService(mock(Client.class))
426424
);
427425

428-
@SuppressWarnings("unchecked")
429-
private static InferenceRunner mockInferenceRunner() {
430-
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
431-
doAnswer(i -> {
432-
i.getArgument(1, ActionListener.class).onResponse(emptyInferenceResolution());
433-
return null;
434-
}).when(inferenceRunner).resolveInferenceIds(any(), any());
435-
436-
return inferenceRunner;
437-
}
438-
439426
private EsqlTestUtils() {}
440427

441428
public static Configuration configuration(QueryPragmas pragmas, String query) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,9 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
188188
Limiter.ONCE,
189189
new ResolveTable(),
190190
new ResolveEnrich(),
191-
new ResolveInference(),
192191
new ResolveLookupTables(),
193192
new ResolveFunctions(),
193+
new ResolveInference(),
194194
new DateMillisToNanosInEsRelation(IMPLICIT_CASTING_DATE_AND_DATE_NANOS.isEnabled())
195195
),
196196
new Batch<>(
@@ -414,34 +414,6 @@ private static NamedExpression createEnrichFieldExpression(
414414
}
415415
}
416416

417-
private static class ResolveInference extends ParameterizedAnalyzerRule<InferencePlan<?>, AnalyzerContext> {
418-
@Override
419-
protected LogicalPlan rule(InferencePlan<?> plan, AnalyzerContext context) {
420-
assert plan.inferenceId().resolved() && plan.inferenceId().foldable();
421-
422-
String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
423-
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
424-
425-
if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) {
426-
return plan;
427-
} else if (resolvedInference != null) {
428-
String error = "cannot use inference endpoint ["
429-
+ inferenceId
430-
+ "] with task type ["
431-
+ resolvedInference.taskType()
432-
+ "] within a "
433-
+ plan.nodeName()
434-
+ " command. Only inference endpoints with the task type ["
435-
+ plan.taskType()
436-
+ "] are supported.";
437-
return plan.withInferenceResolutionError(inferenceId, error);
438-
} else {
439-
String error = context.inferenceResolution().getError(inferenceId);
440-
return plan.withInferenceResolutionError(inferenceId, error);
441-
}
442-
}
443-
}
444-
445417
private static class ResolveLookupTables extends ParameterizedAnalyzerRule<Lookup, AnalyzerContext> {
446418

447419
@Override
@@ -1335,6 +1307,41 @@ public static org.elasticsearch.xpack.esql.core.expression.function.Function res
13351307
}
13361308
}
13371309

1310+
private static class ResolveInference extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {
1311+
1312+
@Override
1313+
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
1314+
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));
1315+
}
1316+
1317+
private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {
1318+
assert plan.inferenceId().resolved() && plan.inferenceId().foldable();
1319+
1320+
String inferenceId = BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
1321+
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
1322+
1323+
if (resolvedInference == null) {
1324+
String error = context.inferenceResolution().getError(inferenceId);
1325+
return plan.withInferenceResolutionError(inferenceId, error);
1326+
}
1327+
1328+
if (resolvedInference.taskType() != plan.taskType()) {
1329+
String error = "cannot use inference endpoint ["
1330+
+ inferenceId
1331+
+ "] with task type ["
1332+
+ resolvedInference.taskType()
1333+
+ "] within a "
1334+
+ plan.nodeName()
1335+
+ " command. Only inference endpoints with the task type ["
1336+
+ plan.taskType()
1337+
+ "] are supported.";
1338+
return plan.withInferenceResolutionError(inferenceId, error);
1339+
}
1340+
1341+
return plan;
1342+
}
1343+
}
1344+
13381345
private static class AddImplicitLimit extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {
13391346
@Override
13401347
public LogicalPlan apply(LogicalPlan logicalPlan, AnalyzerContext context) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
1414
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
1515
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
16-
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
1716

1817
import java.util.ArrayList;
1918
import java.util.HashSet;
@@ -28,25 +27,17 @@
2827
public class PreAnalyzer {
2928

3029
public static class PreAnalysis {
31-
public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptyList(), emptyList());
30+
public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptyList());
3231

3332
public final IndexMode indexMode;
3433
public final List<IndexPattern> indices;
3534
public final List<Enrich> enriches;
36-
public final List<InferencePlan<?>> inferencePlans;
3735
public final List<IndexPattern> lookupIndices;
3836

39-
public PreAnalysis(
40-
IndexMode indexMode,
41-
List<IndexPattern> indices,
42-
List<Enrich> enriches,
43-
List<InferencePlan<?>> inferencePlans,
44-
List<IndexPattern> lookupIndices
45-
) {
37+
public PreAnalysis(IndexMode indexMode, List<IndexPattern> indices, List<Enrich> enriches, List<IndexPattern> lookupIndices) {
4638
this.indexMode = indexMode;
4739
this.indices = indices;
4840
this.enriches = enriches;
49-
this.inferencePlans = inferencePlans;
5041
this.lookupIndices = lookupIndices;
5142
}
5243
}
@@ -64,7 +55,7 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
6455

6556
List<Enrich> unresolvedEnriches = new ArrayList<>();
6657
List<IndexPattern> lookupIndices = new ArrayList<>();
67-
List<InferencePlan<?>> unresolvedInferencePlans = new ArrayList<>();
58+
6859
Holder<IndexMode> indexMode = new Holder<>();
6960
plan.forEachUp(UnresolvedRelation.class, p -> {
7061
if (p.indexMode() == IndexMode.LOOKUP) {
@@ -78,11 +69,11 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
7869
});
7970

8071
plan.forEachUp(Enrich.class, unresolvedEnriches::add);
81-
plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add);
8272

8373
// mark plan as preAnalyzed (if it were marked, there would be no analysis)
8474
plan.forEachUp(LogicalPlan::setPreAnalyzed);
8575

86-
return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices);
76+
return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, lookupIndices);
8777
}
78+
8879
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,16 @@
1515
import org.elasticsearch.core.Releasable;
1616
import org.elasticsearch.core.Releasables;
1717
import org.elasticsearch.inference.InferenceServiceResults;
18-
import org.elasticsearch.threadpool.ThreadPool;
1918
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
20-
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
21-
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
2219
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
20+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
2321

2422
import java.util.List;
2523

2624
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
2725

2826
/**
29-
* An abstract asynchronous operator that performs throttled bulk inference execution using an {@link InferenceRunner}.
27+
* An abstract asynchronous operator that performs throttled bulk inference execution using a {@link BulkInferenceRunner}.
3028
* <p>
3129
* The {@code InferenceOperator} integrates with the compute framework and supports throttled bulk execution of inference requests. It
3230
* transforms input {@link Page} into inference requests, asynchronously executes them, and converts the responses into a new {@link Page}.
@@ -35,27 +33,25 @@
3533
public abstract class InferenceOperator extends AsyncOperator<InferenceOperator.OngoingInferenceResult> {
3634
private final String inferenceId;
3735
private final BlockFactory blockFactory;
38-
private final BulkInferenceExecutor bulkInferenceExecutor;
36+
private final BulkInferenceRunner bulkInferenceRunner;
3937

4038
/**
4139
* Constructs a new {@code InferenceOperator}.
4240
*
4341
* @param driverContext The driver context.
44-
* @param inferenceRunner The runner used to execute inference requests.
45-
* @param bulkExecutionConfig Configuration for inference execution.
46-
* @param threadPool The thread pool used for executing async inference.
42+
* @param bulkInferenceRunner Inference runner used to execute inference requests.
4743
* @param inferenceId The ID of the inference model to use.
44+
* @param maxOutstandingPages The maximum number of pages that may be processed concurrently.
4845
*/
4946
public InferenceOperator(
5047
DriverContext driverContext,
51-
InferenceRunner inferenceRunner,
52-
BulkInferenceExecutionConfig bulkExecutionConfig,
53-
ThreadPool threadPool,
54-
String inferenceId
48+
BulkInferenceRunner bulkInferenceRunner,
49+
String inferenceId,
50+
int maxOutstandingPages
5551
) {
56-
super(driverContext, inferenceRunner.threadPool().getThreadContext(), bulkExecutionConfig.workers());
52+
super(driverContext, bulkInferenceRunner.threadPool().getThreadContext(), maxOutstandingPages);
5753
this.blockFactory = driverContext.blockFactory();
58-
this.bulkInferenceExecutor = new BulkInferenceExecutor(inferenceRunner, threadPool, bulkExecutionConfig);
54+
this.bulkInferenceRunner = bulkInferenceRunner;
5955
this.inferenceId = inferenceId;
6056
}
6157

@@ -81,7 +77,8 @@ protected void performAsync(Page input, ActionListener<OngoingInferenceResult> l
8177
try {
8278
BulkInferenceRequestIterator requests = requests(input);
8379
listener = ActionListener.releaseBefore(requests, listener);
84-
bulkInferenceExecutor.execute(requests, listener.map(responses -> new OngoingInferenceResult(input, responses)));
80+
81+
bulkInferenceRunner.executeBulk(requests, listener.map(responses -> new OngoingInferenceResult(input, responses)));
8582
} catch (Exception e) {
8683
listener.onFailure(e);
8784
}
@@ -110,9 +107,9 @@ public Page getOutput() {
110107
outputBuilder.addInferenceResponse(response);
111108
}
112109
return outputBuilder.buildOutput();
113-
114-
} finally {
110+
} catch (Exception e) {
115111
releaseFetchedOnAnyThread(ongoingInferenceResult);
112+
throw e;
116113
}
117114
}
118115

0 commit comments

Comments
 (0)