Skip to content

Commit 94fdad1

Browse files
committed
Fix tests after rebasing.
1 parent 0859c7d commit 94fdad1

File tree

13 files changed

+177
-74
lines changed

13 files changed

+177
-74
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
6767
import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime;
6868
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
69+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
6970
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
7071
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
7172
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
@@ -1311,7 +1312,35 @@ private static class ResolveInference extends ParameterizedRule<LogicalPlan, Log
13111312

13121313
@Override
13131314
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
1314-
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));
1315+
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
1316+
.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
1317+
}
1318+
1319+
private InferenceFunction<?> resolveInferenceFunction(InferenceFunction<?> inferenceFunction, AnalyzerContext context) {
1320+
assert inferenceFunction.inferenceId().resolved() && inferenceFunction.inferenceId().foldable();
1321+
1322+
String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small()));
1323+
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
1324+
1325+
if (resolvedInference == null) {
1326+
String error = context.inferenceResolution().getError(inferenceId);
1327+
return inferenceFunction.withInferenceResolutionError(inferenceId, error);
1328+
}
1329+
1330+
if (resolvedInference.taskType() != inferenceFunction.taskType()) {
1331+
String error = "cannot use inference endpoint ["
1332+
+ inferenceId
1333+
+ "] with task type ["
1334+
+ resolvedInference.taskType()
1335+
+ "] within a "
1336+
+ context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass())
1337+
+ " function. Only inference endpoints with the task type ["
1338+
+ inferenceFunction.taskType()
1339+
+ "] are supported.";
1340+
return inferenceFunction.withInferenceResolutionError(inferenceId, error);
1341+
}
1342+
1343+
return inferenceFunction;
13151344
}
13161345

13171346
private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public void esql(
8585
indexResolver,
8686
enrichPolicyResolver,
8787
new PreAnalyzer(),
88-
new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext)),
88+
new LogicalPlanPreOptimizer(services, new LogicalPreOptimizerContext(foldContext)),
8989
functionRegistry,
9090
new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)),
9191
mapper,

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
import org.elasticsearch.xpack.esql.core.expression.Expression;
1717
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1818
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
19+
import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition;
20+
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
21+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
1922
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2023
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
2124

@@ -31,13 +34,16 @@ public class InferenceResolver {
3134

3235
private final Client client;
3336

37+
private final EsqlFunctionRegistry functionRegistry;
38+
3439
/**
3540
* Constructs a new {@code InferenceResolver}.
3641
*
3742
* @param client The Elasticsearch client for executing inference deployment lookups
3843
*/
39-
public InferenceResolver(Client client) {
44+
public InferenceResolver(Client client, EsqlFunctionRegistry functionRegistry) {
4045
this.client = client;
46+
this.functionRegistry = functionRegistry;
4147
}
4248

4349
/**
@@ -72,6 +78,7 @@ public void resolveInferenceIds(LogicalPlan plan, ActionListener<InferenceResolu
7278
*/
7379
void collectInferenceIds(LogicalPlan plan, Consumer<String> c) {
7480
collectInferenceIdsFromInferencePlans(plan, c);
81+
collectInferenceIdsFromInferenceFunctions(plan, c);
7582
}
7683

7784
/**
@@ -131,6 +138,38 @@ private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer<St
131138
plan.forEachUp(InferencePlan.class, inferencePlan -> c.accept(inferenceId(inferencePlan)));
132139
}
133140

141+
/**
142+
* Collects inference IDs from inference function calls within the logical plan.
143+
* <p>
144+
* This method scans the logical plan for {@link UnresolvedFunction} instances that represent
145+
* inference functions (e.g., TEXT_EMBEDDING). For each inference function found:
146+
* <ol>
147+
* <li>Resolves the function definition through the registry and checks if the function implements {@link InferenceFunction}</li>
148+
* <li>Extracts the inference deployment ID from the function arguments</li>
149+
* </ol>
150+
* <p>
151+
* This operates during pre-analysis when functions are still unresolved, allowing early
152+
* validation of inference deployments before query optimization.
153+
*
154+
* @param plan The logical plan to scan for inference function calls
155+
* @param c Consumer function to receive each discovered inference ID
156+
*/
157+
private void collectInferenceIdsFromInferenceFunctions(LogicalPlan plan, Consumer<String> c) {
158+
EsqlFunctionRegistry snapshotRegistry = functionRegistry.snapshotRegistry();
159+
plan.forEachExpressionUp(UnresolvedFunction.class, f -> {
160+
String functionName = snapshotRegistry.resolveAlias(f.name());
161+
if (snapshotRegistry.functionExists(functionName)) {
162+
FunctionDefinition def = snapshotRegistry.resolveFunction(functionName);
163+
if (InferenceFunction.class.isAssignableFrom(def.clazz())) {
164+
String inferenceId = inferenceId(f, def);
165+
if (inferenceId != null) {
166+
c.accept(inferenceId);
167+
}
168+
}
169+
}
170+
});
171+
}
172+
134173
/**
135174
* Extracts the inference ID from an InferencePlan object.
136175
*
@@ -141,10 +180,43 @@ private static String inferenceId(InferencePlan<?> plan) {
141180
return inferenceId(plan.inferenceId());
142181
}
143182

183+
/**
184+
* Extracts the inference ID from an Expression (expect the expression to be a constant).
185+
*/
144186
private static String inferenceId(Expression e) {
145187
return BytesRefs.toString(e.fold(FoldContext.small()));
146188
}
147189

190+
/**
191+
* Extracts the inference ID from an {@link UnresolvedFunction} instance.
192+
* <p>
193+
* This method inspects the function's arguments to find the inference ID.
194+
* Currently, it only supports positional parameters named "inference_id".
195+
*
196+
* @param f The unresolved function to extract the ID from
197+
* @param def The function definition
198+
* @return The inference ID as a string, or null if not found
199+
*/
200+
public String inferenceId(UnresolvedFunction f, FunctionDefinition def) {
201+
EsqlFunctionRegistry.FunctionDescription functionDescription = EsqlFunctionRegistry.description(def);
202+
203+
for (int i = 0; i < functionDescription.args().size(); i++) {
204+
EsqlFunctionRegistry.ArgSignature arg = functionDescription.args().get(i);
205+
206+
if (arg.name().equals(InferenceFunction.INFERENCE_ID_PARAMETER_NAME)) {
207+
// Found a positional parameter named "inference_id", so use its value
208+
Expression argValue = f.arguments().get(i);
209+
if (argValue != null && argValue.foldable()) {
210+
return inferenceId(argValue);
211+
}
212+
}
213+
214+
// TODO: support inference ID as an optional named parameter
215+
}
216+
217+
return null;
218+
}
219+
148220
public static Factory factory(Client client) {
149221
return new Factory(client);
150222
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.inference;
99

1010
import org.elasticsearch.client.internal.Client;
11+
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
1112
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
1213
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
1314

@@ -35,8 +36,8 @@ private InferenceService(InferenceResolver.Factory inferenceResolverFactory, Bul
3536
*
3637
* @return a new inference resolver instance
3738
*/
38-
public InferenceResolver inferenceResolver() {
39-
return inferenceResolverFactory.create();
39+
public InferenceResolver inferenceResolver(EsqlFunctionRegistry functionRegistry) {
40+
return inferenceResolverFactory.create(functionRegistry);
4041
}
4142

4243
public BulkInferenceRunner bulkInferenceRunner() {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ public class LogicalPlanPreOptimizer {
2828
private final List<PreOptimizerRule> rules;
2929

3030
public LogicalPlanPreOptimizer(TransportActionServices services, LogicalPreOptimizerContext preOptimizerContext) {
31-
rules = List.of(new InferenceFunctionConstantFolding(services.bulkInferenceRunner(), preOptimizerContext.foldCtx()));
31+
rules = List.of(
32+
new InferenceFunctionConstantFolding(services.inferenceService().bulkInferenceRunner(), preOptimizerContext.foldCtx())
33+
);
3234
}
3335

3436
/**

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ private static void analyzeAndMaybeRetry(
752752
}
753753

754754
private void resolveInferences(LogicalPlan plan, PreAnalysisResult preAnalysisResult, ActionListener<PreAnalysisResult> l) {
755-
inferenceService.inferenceResolver().resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution));
755+
inferenceService.inferenceResolver(functionRegistry).resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution));
756756
}
757757

758758
private PhysicalPlan logicalPlanToPhysicalPlan(LogicalPlan optimizedPlan, EsqlQueryRequest request) {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch;
5151
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
5252
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
53-
import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText;
53+
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
5454
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos;
5555
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime;
5656
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
@@ -3865,13 +3865,13 @@ public void testResolveCompletionOutputField() {
38653865
public void testResolveEmbedTextInferenceId() {
38663866
LogicalPlan plan = analyze("""
38673867
FROM books METADATA _score
3868-
| EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id")
3868+
| EVAL embedding = TEXT_EMBEDDING("description", "text-embedding-inference-id")
38693869
""", "mapping-books.json");
38703870

38713871
var limit = as(plan, Limit.class);
38723872
var eval = as(limit.child(), Eval.class);
38733873
var embedTextAlias = as(eval.fields().get(0), Alias.class);
3874-
var embedText = as(embedTextAlias.child(), EmbedText.class);
3874+
var embedText = as(embedTextAlias.child(), TextEmbedding.class);
38753875

38763876
assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
38773877
assertThat(embedText.inputText(), equalTo(string("description")));
@@ -3881,40 +3881,40 @@ public void testResolveEmbedTextInferenceIdInvalidTaskType() {
38813881
assertError(
38823882
"""
38833883
FROM books METADATA _score
3884-
| EVAL embedding = EMBED_TEXT(description, "completion-inference-id")
3884+
| EVAL embedding = TEXT_EMBEDDING("blue", "completion-inference-id")
38853885
""",
38863886
"mapping-books.json",
38873887
new QueryParams(),
3888-
"cannot use inference endpoint [completion-inference-id] with task type [completion] within a embed_text function."
3888+
"cannot use inference endpoint [completion-inference-id] with task type [completion] within a text_embedding function."
38893889
+ " Only inference endpoints with the task type [text_embedding] are supported"
38903890
);
38913891
}
38923892

38933893
public void testResolveEmbedTextInferenceMissingInferenceId() {
38943894
assertError("""
38953895
FROM books METADATA _score
3896-
| EVAL embedding = EMBED_TEXT(description, "unknown-inference-id")
3896+
| EVAL embedding = TEXT_EMBEDDING("blue", "unknown-inference-id")
38973897
""", "mapping-books.json", new QueryParams(), "unresolved inference [unknown-inference-id]");
38983898
}
38993899

39003900
public void testResolveEmbedTextInferenceIdResolutionError() {
39013901
assertError("""
39023902
FROM books METADATA _score
3903-
| EVAL embedding = EMBED_TEXT(description, "error-inference-id")
3903+
| EVAL embedding = TEXT_EMBEDDING("blue", "error-inference-id")
39043904
""", "mapping-books.json", new QueryParams(), "error with inference resolution");
39053905
}
39063906

39073907
public void testResolveEmbedTextInNestedExpression() {
39083908
LogicalPlan plan = analyze("""
39093909
FROM colors METADATA _score
3910-
| WHERE KNN(rgb_vector, EMBED_TEXT("blue", "text-embedding-inference-id"), 10)
3910+
| WHERE KNN(rgb_vector, TEXT_EMBEDDING("blue", "text-embedding-inference-id"), 10)
39113911
""", "mapping-colors.json");
39123912

39133913
var limit = as(plan, Limit.class);
39143914
var filter = as(limit.child(), Filter.class);
39153915

3916-
// Navigate to the EMBED_TEXT function within the KNN function
3917-
filter.condition().forEachDown(EmbedText.class, embedText -> {
3916+
// Navigate to the TEXT_EMBEDDING function within the KNN function
3917+
filter.condition().forEachDown(TextEmbedding.class, embedText -> {
39183918
assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
39193919
assertThat(embedText.inputText(), equalTo(string("blue")));
39203920
});
@@ -3923,53 +3923,53 @@ public void testResolveEmbedTextInNestedExpression() {
39233923
public void testResolveEmbedTextDataType() {
39243924
LogicalPlan plan = analyze("""
39253925
FROM books METADATA _score
3926-
| EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id")
3926+
| EVAL embedding = TEXT_EMBEDDING("description", "text-embedding-inference-id")
39273927
""", "mapping-books.json");
39283928

39293929
var limit = as(plan, Limit.class);
39303930
var eval = as(limit.child(), Eval.class);
39313931
var embedTextAlias = as(eval.fields().get(0), Alias.class);
3932-
var embedText = as(embedTextAlias.child(), EmbedText.class);
3932+
var embedText = as(embedTextAlias.child(), TextEmbedding.class);
39333933

39343934
assertThat(embedText.dataType(), equalTo(DataType.DENSE_VECTOR));
39353935
}
39363936

39373937
public void testResolveEmbedTextInvalidParameters() {
39383938
assertError(
3939-
"FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description, \"text-embedding-inference-id\")",
3939+
"FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(description, \"text-embedding-inference-id\")",
39403940
"mapping-books.json",
39413941
new QueryParams(),
3942-
"first argument of [EMBED_TEXT(description, \"text-embedding-inference-id\")] must be a constant, received [description]"
3942+
"first argument of [TEXT_EMBEDDING(description, \"text-embedding-inference-id\")] must be a constant, received [description]"
39433943
);
39443944

39453945
assertError(
3946-
"FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description)",
3946+
"FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(description)",
39473947
"mapping-books.json",
39483948
new QueryParams(),
3949-
"error building [embed_text]: function [embed_text] expects exactly two arguments, it received 1",
3949+
"error building [text_embedding]: function [text_embedding] expects exactly two arguments, it received 1",
39503950
ParsingException.class
39513951
);
39523952
}
39533953

39543954
public void testResolveEmbedTextWithPositionalQueryParams() {
39553955
LogicalPlan plan = analyze(
3956-
"FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?, ?)",
3956+
"FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(?, ?)",
39573957
"mapping-books.json",
39583958
new QueryParams(List.of(paramAsConstant(null, "description"), paramAsConstant(null, "text-embedding-inference-id")))
39593959
);
39603960

39613961
var limit = as(plan, Limit.class);
39623962
var eval = as(limit.child(), Eval.class);
39633963
var embedTextAlias = as(eval.fields().get(0), Alias.class);
3964-
var embedText = as(embedTextAlias.child(), EmbedText.class);
3964+
var embedText = as(embedTextAlias.child(), TextEmbedding.class);
39653965

39663966
assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
39673967
assertThat(embedText.inputText(), equalTo(string("description")));
39683968
}
39693969

39703970
public void testResolveEmbedTextWithNamedQueryParams() {
39713971
LogicalPlan plan = analyze(
3972-
"FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?inputText, ?inferenceId)",
3972+
"FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(?inputText, ?inferenceId)",
39733973
"mapping-books.json",
39743974
new QueryParams(
39753975
List.of(paramAsConstant("inputText", "description"), paramAsConstant("inferenceId", "text-embedding-inference-id"))
@@ -3979,7 +3979,7 @@ public void testResolveEmbedTextWithNamedQueryParams() {
39793979
var limit = as(plan, Limit.class);
39803980
var eval = as(limit.child(), Eval.class);
39813981
var embedTextAlias = as(eval.fields().get(0), Alias.class);
3982-
var embedText = as(embedTextAlias.child(), EmbedText.class);
3982+
var embedText = as(embedTextAlias.child(), TextEmbedding.class);
39833983

39843984
assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
39853985
assertThat(embedText.inputText(), equalTo(string("description")));

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import org.elasticsearch.inference.InferenceServiceResults;
3737
import org.elasticsearch.test.client.NoOpClient;
3838
import org.elasticsearch.threadpool.FixedExecutorBuilder;
39-
import org.elasticsearch.threadpool.TestThreadPool;
4039
import org.elasticsearch.threadpool.ThreadPool;
4140
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
4241
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
@@ -56,8 +55,7 @@ public abstract class InferenceOperatorTestCase<InferenceResultsType extends Inf
5655

5756
@Before
5857
public void setThreadPool() {
59-
threadPool = new TestThreadPool(
60-
getTestClass().getSimpleName(),
58+
threadPool = createThreadPool(
6159
new FixedExecutorBuilder(
6260
Settings.EMPTY,
6361
EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME,

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.threadpool.FixedExecutorBuilder;
2323
import org.elasticsearch.threadpool.TestThreadPool;
2424
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
25+
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
2526
import org.elasticsearch.xpack.esql.parser.EsqlParser;
2627
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
2728
import org.junit.After;
@@ -205,7 +206,7 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.
205206
}
206207

207208
private InferenceResolver inferenceResolver() {
208-
return new InferenceResolver(mockClient());
209+
return new InferenceResolver(mockClient(), new EsqlFunctionRegistry());
209210
}
210211

211212
private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) {

0 commit comments

Comments
 (0)