Skip to content

Commit b80dc31

Browse files
committed
Modify PreAnalyzer so it will be easier to implement inference function pre-analysis.
1 parent 49e1834 commit b80dc31

File tree

5 files changed

+43
-38
lines changed

5 files changed

+43
-38
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
package org.elasticsearch.xpack.esql.analysis;
99

10+
import org.elasticsearch.common.lucene.BytesRefs;
1011
import org.elasticsearch.index.IndexMode;
12+
import org.elasticsearch.xpack.esql.core.expression.Literal;
1113
import org.elasticsearch.xpack.esql.core.util.Holder;
1214
import org.elasticsearch.xpack.esql.plan.IndexPattern;
1315
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@@ -21,32 +23,33 @@
2123
import java.util.Set;
2224

2325
import static java.util.Collections.emptyList;
26+
import static java.util.Collections.emptySet;
2427

2528
/**
2629
* This class is part of the planner. Acts somewhat like a linker, to find the indices and enrich policies referenced by the query.
2730
*/
2831
public class PreAnalyzer {
2932

3033
public static class PreAnalysis {
31-
public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptyList(), emptyList());
34+
public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptySet(), emptyList());
3235

3336
public final IndexMode indexMode;
3437
public final List<IndexPattern> indices;
3538
public final List<Enrich> enriches;
36-
public final List<InferencePlan<?>> inferencePlans;
39+
public final Set<String> inferenceIds;
3740
public final List<IndexPattern> lookupIndices;
3841

3942
public PreAnalysis(
4043
IndexMode indexMode,
4144
List<IndexPattern> indices,
4245
List<Enrich> enriches,
43-
List<InferencePlan<?>> inferencePlans,
46+
Set<String> inferenceIds,
4447
List<IndexPattern> lookupIndices
4548
) {
4649
this.indexMode = indexMode;
4750
this.indices = indices;
4851
this.enriches = enriches;
49-
this.inferencePlans = inferencePlans;
52+
this.inferenceIds = inferenceIds;
5053
this.lookupIndices = lookupIndices;
5154
}
5255
}
@@ -64,7 +67,7 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
6467

6568
List<Enrich> unresolvedEnriches = new ArrayList<>();
6669
List<IndexPattern> lookupIndices = new ArrayList<>();
67-
List<InferencePlan<?>> unresolvedInferencePlans = new ArrayList<>();
70+
Set<String> unresolvedInferenceIds = new HashSet<>();
6871
Holder<IndexMode> indexMode = new Holder<>();
6972
plan.forEachUp(UnresolvedRelation.class, p -> {
7073
if (p.indexMode() == IndexMode.LOOKUP) {
@@ -78,11 +81,28 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
7881
});
7982

8083
plan.forEachUp(Enrich.class, unresolvedEnriches::add);
81-
plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add);
8284

8385
// mark plan as preAnalyzed (if it were marked, there would be no analysis)
8486
plan.forEachUp(LogicalPlan::setPreAnalyzed);
8587

86-
return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices);
88+
return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, inferenceIds(plan), lookupIndices);
89+
}
90+
91+
protected Set<String> inferenceIds(LogicalPlan plan) {
92+
Set<String> inferenceIds = new HashSet<>();
93+
94+
List<InferencePlan<?>> inferencePlans = new ArrayList<>();
95+
plan.forEachUp(InferencePlan.class, inferencePlans::add);
96+
inferencePlans.stream().map(this::inferenceId).forEach(inferenceIds::add);
97+
98+
return inferenceIds;
99+
}
100+
101+
private String inferenceId(InferencePlan<?> inferencePlan) {
102+
if (inferencePlan.inferenceId() instanceof Literal literal) {
103+
return BytesRefs.toString(literal.value());
104+
}
105+
106+
throw new IllegalStateException("inferenceId is not a literal");
87107
}
88108
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1919
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
2020

21-
import java.util.List;
2221
import java.util.Set;
23-
import java.util.stream.Collectors;
2422

2523
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
2624
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -39,12 +37,7 @@ public ThreadPool threadPool() {
3937
return threadPool;
4038
}
4139

42-
public void resolveInferenceIds(List<InferencePlan<?>> plans, ActionListener<InferenceResolution> listener) {
43-
resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener);
44-
45-
}
46-
47-
private void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResolution> listener) {
40+
public void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResolution> listener) {
4841

4942
if (inferenceIds.isEmpty()) {
5043
listener.onResponse(InferenceResolution.EMPTY);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@
8585
import org.elasticsearch.xpack.esql.plan.logical.TopN;
8686
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
8787
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
88-
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
8988
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
9089
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
9190
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
@@ -372,7 +371,7 @@ public void analyzedPlan(
372371
l -> enrichPolicyResolver.resolvePolicies(unresolvedPolicies, executionInfo, l)
373372
)
374373
.<PreAnalysisResult>andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l))
375-
.<PreAnalysisResult>andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferencePlans, preAnalysisResult, l));
374+
.<PreAnalysisResult>andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferenceIds, preAnalysisResult, l));
376375
// first resolve the lookup indices, then the main indices
377376
for (var index : preAnalysis.lookupIndices) {
378377
listener = listener.andThen((l, preAnalysisResult) -> { preAnalyzeLookupIndex(index, preAnalysisResult, l); });
@@ -588,12 +587,8 @@ private static void resolveFieldNames(LogicalPlan parsed, EnrichResolution enric
588587
}
589588
}
590589

591-
private void resolveInferences(
592-
List<InferencePlan<?>> inferencePlans,
593-
PreAnalysisResult preAnalysisResult,
594-
ActionListener<PreAnalysisResult> l
595-
) {
596-
inferenceRunner.resolveInferenceIds(inferencePlans, l.map(preAnalysisResult::withInferenceResolution));
590+
private void resolveInferences(Set<String> inferenceIds, PreAnalysisResult preAnalysisResult, ActionListener<PreAnalysisResult> l) {
591+
inferenceRunner.resolveInferenceIds(inferenceIds, l.map(preAnalysisResult::withInferenceResolution));
597592
}
598593

599594
static PreAnalysisResult fieldNames(LogicalPlan parsed, Set<String> enrichPolicyMatchFields, PreAnalysisResult result) {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.analysis;
99

1010
import org.elasticsearch.index.IndexMode;
11+
import org.elasticsearch.inference.ModelConfigurations;
1112
import org.elasticsearch.inference.TaskType;
1213
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
1314
import org.elasticsearch.xpack.esql.EsqlTestUtils;
@@ -197,11 +198,10 @@ public static InferenceResolution defaultInferenceResolution() {
197198
.build();
198199
}
199200

200-
private static ResolvedInference mockedResolvedInference(String id, TaskType taskType) {
201-
ResolvedInference resolvedInference = mock(ResolvedInference.class);
202-
when(resolvedInference.inferenceId()).thenReturn(id);
203-
when(resolvedInference.taskType()).thenReturn(taskType);
204-
return resolvedInference;
201+
private static ResolvedInference mockedResolvedInference(String inferenceId, TaskType taskType) {
202+
ModelConfigurations modelConfigurations = mock(ModelConfigurations.class);
203+
when(modelConfigurations.getTaskType()).thenReturn(taskType);
204+
return new ResolvedInference(inferenceId, modelConfigurations);
205205
}
206206

207207
public static void loadEnrichPolicyResolution(

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceRunnerTests.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.junit.Before;
3333

3434
import java.util.List;
35+
import java.util.Set;
3536

3637
import static org.hamcrest.Matchers.allOf;
3738
import static org.hamcrest.Matchers.contains;
@@ -68,10 +69,10 @@ public void shutdownThreadPool() {
6869

6970
public void testResolveInferenceIds() throws Exception {
7071
InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool);
71-
List<InferencePlan<?>> inferencePlans = List.of(mockInferencePlan("rerank-plan"));
72+
Set<String> inferenceIds = Set.of("rerank-plan");
7273
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
7374

74-
inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
75+
inferenceRunner.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
7576
throw new RuntimeException(e);
7677
}));
7778

@@ -88,14 +89,10 @@ public void testResolveInferenceIds() throws Exception {
8889

8990
public void testResolveMultipleInferenceIds() throws Exception {
9091
InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool);
91-
List<InferencePlan<?>> inferencePlans = List.of(
92-
mockInferencePlan("rerank-plan"),
93-
mockInferencePlan("rerank-plan"),
94-
mockInferencePlan("completion-plan")
95-
);
92+
Set<String> inferenceIds = Set.of("rerank-plan", "completion-plan");
9693
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
9794

98-
inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
95+
inferenceRunner.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
9996
throw new RuntimeException(e);
10097
}));
10198

@@ -116,11 +113,11 @@ public void testResolveMultipleInferenceIds() throws Exception {
116113

117114
public void testResolveMissingInferenceIds() throws Exception {
118115
InferenceRunner inferenceRunner = new InferenceRunner(mockClient(), threadPool);
119-
List<InferencePlan<?>> inferencePlans = List.of(mockInferencePlan("missing-plan"));
116+
Set<String> inferenceIds = Set.of("missing-plan");
120117

121118
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
122119

123-
inferenceRunner.resolveInferenceIds(inferencePlans, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
120+
inferenceRunner.resolveInferenceIds(inferenceIds, ActionListener.wrap(inferenceResolutionSetOnce::set, e -> {
124121
throw new RuntimeException(e);
125122
}));
126123

0 commit comments

Comments
 (0)