Skip to content

Commit cab92a1

Browse files
committed
Fix tests after rebasing.
1 parent c09800a commit cab92a1

File tree

12 files changed

+176
-73
lines changed

12 files changed

+176
-73
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
6767
import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime;
6868
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
69+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
6970
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
7071
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
7172
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
@@ -1311,7 +1312,35 @@ private static class ResolveInference extends ParameterizedRule<LogicalPlan, Log
13111312

13121313
@Override
13131314
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
1314-
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));
1315+
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
1316+
.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
1317+
}
1318+
1319+
private InferenceFunction<?> resolveInferenceFunction(InferenceFunction<?> inferenceFunction, AnalyzerContext context) {
1320+
assert inferenceFunction.inferenceId().resolved() && inferenceFunction.inferenceId().foldable();
1321+
1322+
String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small()));
1323+
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
1324+
1325+
if (resolvedInference == null) {
1326+
String error = context.inferenceResolution().getError(inferenceId);
1327+
return inferenceFunction.withInferenceResolutionError(inferenceId, error);
1328+
}
1329+
1330+
if (resolvedInference.taskType() != inferenceFunction.taskType()) {
1331+
String error = "cannot use inference endpoint ["
1332+
+ inferenceId
1333+
+ "] with task type ["
1334+
+ resolvedInference.taskType()
1335+
+ "] within a "
1336+
+ context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass())
1337+
+ " function. Only inference endpoints with the task type ["
1338+
+ inferenceFunction.taskType()
1339+
+ "] are supported.";
1340+
return inferenceFunction.withInferenceResolutionError(inferenceId, error);
1341+
}
1342+
1343+
return inferenceFunction;
13151344
}
13161345

13171346
private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
import org.elasticsearch.xpack.esql.core.expression.Expression;
1717
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1818
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
19+
import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition;
20+
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
21+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
1922
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
2023
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
2124

@@ -31,13 +34,16 @@ public class InferenceResolver {
3134

3235
private final Client client;
3336

37+
private final EsqlFunctionRegistry functionRegistry;
38+
3439
/**
3540
* Constructs a new {@code InferenceResolver}.
3641
*
3742
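* @param client The Elasticsearch client for executing inference deployment lookups
* @param functionRegistry The function registry used to identify inference functions when collecting inference IDs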
3843
*/
39-
public InferenceResolver(Client client) {
44+
public InferenceResolver(Client client, EsqlFunctionRegistry functionRegistry) {
4045
this.client = client;
46+
this.functionRegistry = functionRegistry;
4147
}
4248

4349
/**
@@ -72,6 +78,7 @@ public void resolveInferenceIds(LogicalPlan plan, ActionListener<InferenceResolu
7278
*/
7379
void collectInferenceIds(LogicalPlan plan, Consumer<String> c) {
7480
collectInferenceIdsFromInferencePlans(plan, c);
81+
collectInferenceIdsFromInferenceFunctions(plan, c);
7582
}
7683

7784
/**
@@ -131,6 +138,38 @@ private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer<St
131138
plan.forEachUp(InferencePlan.class, inferencePlan -> c.accept(inferenceId(inferencePlan)));
132139
}
133140

141+
/**
142+
* Collects inference IDs from inference function calls within the logical plan.
143+
* <p>
144+
* This method scans the logical plan for {@link UnresolvedFunction} instances that represent
145+
* inference functions (e.g., TEXT_EMBEDDING). For each inference function found:
146+
* <ol>
147+
* <li>Resolves the function definition through the registry and checks if the function implements {@link InferenceFunction}</li>
148+
* <li>Extracts the inference deployment ID from the function arguments</li>
149+
* </ol>
150+
* <p>
151+
* This operates during pre-analysis when functions are still unresolved, allowing early
152+
* validation of inference deployments before query optimization.
153+
*
154+
* @param plan The logical plan to scan for inference function calls
155+
* @param c Consumer function to receive each discovered inference ID
156+
*/
157+
private void collectInferenceIdsFromInferenceFunctions(LogicalPlan plan, Consumer<String> c) {
158+
EsqlFunctionRegistry snapshotRegistry = functionRegistry.snapshotRegistry();
159+
plan.forEachExpressionUp(UnresolvedFunction.class, f -> {
160+
String functionName = snapshotRegistry.resolveAlias(f.name());
161+
if (snapshotRegistry.functionExists(functionName)) {
162+
FunctionDefinition def = snapshotRegistry.resolveFunction(functionName);
163+
if (InferenceFunction.class.isAssignableFrom(def.clazz())) {
164+
String inferenceId = inferenceId(f, def);
165+
if (inferenceId != null) {
166+
c.accept(inferenceId);
167+
}
168+
}
169+
}
170+
});
171+
}
172+
134173
/**
135174
* Extracts the inference ID from an InferencePlan object.
136175
*
@@ -141,10 +180,43 @@ private static String inferenceId(InferencePlan<?> plan) {
141180
return inferenceId(plan.inferenceId());
142181
}
143182

183+
/**
184+
* Extracts the inference ID from an Expression (the expression is expected to be a constant).
185+
*/
144186
private static String inferenceId(Expression e) {
145187
return BytesRefs.toString(e.fold(FoldContext.small()));
146188
}
147189

190+
/**
191+
* Extracts the inference ID from an {@link UnresolvedFunction} instance.
192+
* <p>
193+
* This method inspects the function's arguments to find the inference ID.
194+
* Currently, it only supports positional parameters named "inference_id".
195+
*
196+
* @param f The unresolved function to extract the ID from
197+
* @param def The function definition
198+
* @return The inference ID as a string, or null if not found
199+
*/
200+
public String inferenceId(UnresolvedFunction f, FunctionDefinition def) {
201+
EsqlFunctionRegistry.FunctionDescription functionDescription = EsqlFunctionRegistry.description(def);
202+
203+
for (int i = 0; i < functionDescription.args().size(); i++) {
204+
EsqlFunctionRegistry.ArgSignature arg = functionDescription.args().get(i);
205+
206+
if (arg.name().equals(InferenceFunction.INFERENCE_ID_PARAMETER_NAME)) {
207+
// Found a positional parameter named "inference_id", so use its value
208+
Expression argValue = f.arguments().get(i);
209+
if (argValue != null && argValue.foldable()) {
210+
return inferenceId(argValue);
211+
}
212+
}
213+
214+
// TODO: support inference ID as an optional named parameter
215+
}
216+
217+
return null;
218+
}
219+
148220
public static Factory factory(Client client) {
149221
return new Factory(client);
150222
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.inference;
99

1010
import org.elasticsearch.client.internal.Client;
11+
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
1112
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
1213
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
1314

@@ -35,8 +36,8 @@ private InferenceService(InferenceResolver.Factory inferenceResolverFactory, Bul
3536
*
3637
* @return a new inference resolver instance
3738
*/
38-
public InferenceResolver inferenceResolver() {
39-
return inferenceResolverFactory.create();
39+
public InferenceResolver inferenceResolver(EsqlFunctionRegistry functionRegistry) {
40+
return inferenceResolverFactory.create(functionRegistry);
4041
}
4142

4243
public BulkInferenceRunner bulkInferenceRunner() {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ public class LogicalPlanPreOptimizer {
2828
private final List<PreOptimizerRule> rules;
2929

3030
public LogicalPlanPreOptimizer(TransportActionServices services, LogicalPreOptimizerContext preOptimizerContext) {
31-
rules = List.of(new InferenceFunctionConstantFolding(services.bulkInferenceRunner(), preOptimizerContext.foldCtx()));
31+
rules = List.of(
32+
new InferenceFunctionConstantFolding(services.inferenceService().bulkInferenceRunner(), preOptimizerContext.foldCtx())
33+
);
3234
}
3335

3436
/**

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ private static void analyzeAndMaybeRetry(
751751
}
752752

753753
private void resolveInferences(LogicalPlan plan, PreAnalysisResult preAnalysisResult, ActionListener<PreAnalysisResult> l) {
754-
inferenceService.inferenceResolver().resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution));
754+
inferenceService.inferenceResolver(functionRegistry).resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution));
755755
}
756756

757757
private PhysicalPlan logicalPlanToPhysicalPlan(LogicalPlan optimizedPlan, EsqlQueryRequest request) {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch;
5151
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
5252
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
53-
import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText;
53+
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
5454
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos;
5555
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime;
5656
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
@@ -3843,13 +3843,13 @@ public void testResolveCompletionOutputFieldOverwriteInputField() {
38433843
public void testResolveEmbedTextInferenceId() {
38443844
LogicalPlan plan = analyze("""
38453845
FROM books METADATA _score
3846-
| EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id")
3846+
| EVAL embedding = TEXT_EMBEDDING("description", "text-embedding-inference-id")
38473847
""", "mapping-books.json");
38483848

38493849
var limit = as(plan, Limit.class);
38503850
var eval = as(limit.child(), Eval.class);
38513851
var embedTextAlias = as(eval.fields().get(0), Alias.class);
3852-
var embedText = as(embedTextAlias.child(), EmbedText.class);
3852+
var embedText = as(embedTextAlias.child(), TextEmbedding.class);
38533853

38543854
assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
38553855
assertThat(embedText.inputText(), equalTo(string("description")));
@@ -3859,40 +3859,40 @@ public void testResolveEmbedTextInferenceIdInvalidTaskType() {
38593859
assertError(
38603860
"""
38613861
FROM books METADATA _score
3862-
| EVAL embedding = EMBED_TEXT(description, "completion-inference-id")
3862+
| EVAL embedding = TEXT_EMBEDDING("blue", "completion-inference-id")
38633863
""",
38643864
"mapping-books.json",
38653865
new QueryParams(),
3866-
"cannot use inference endpoint [completion-inference-id] with task type [completion] within a embed_text function."
3866+
"cannot use inference endpoint [completion-inference-id] with task type [completion] within a text_embedding function."
38673867
+ " Only inference endpoints with the task type [text_embedding] are supported"
38683868
);
38693869
}
38703870

38713871
public void testResolveEmbedTextInferenceMissingInferenceId() {
38723872
assertError("""
38733873
FROM books METADATA _score
3874-
| EVAL embedding = EMBED_TEXT(description, "unknown-inference-id")
3874+
| EVAL embedding = TEXT_EMBEDDING("blue", "unknown-inference-id")
38753875
""", "mapping-books.json", new QueryParams(), "unresolved inference [unknown-inference-id]");
38763876
}
38773877

38783878
public void testResolveEmbedTextInferenceIdResolutionError() {
38793879
assertError("""
38803880
FROM books METADATA _score
3881-
| EVAL embedding = EMBED_TEXT(description, "error-inference-id")
3881+
| EVAL embedding = TEXT_EMBEDDING("blue", "error-inference-id")
38823882
""", "mapping-books.json", new QueryParams(), "error with inference resolution");
38833883
}
38843884

38853885
public void testResolveEmbedTextInNestedExpression() {
38863886
LogicalPlan plan = analyze("""
38873887
FROM colors METADATA _score
3888-
| WHERE KNN(rgb_vector, EMBED_TEXT("blue", "text-embedding-inference-id"), 10)
3888+
| WHERE KNN(rgb_vector, TEXT_EMBEDDING("blue", "text-embedding-inference-id"), 10)
38893889
""", "mapping-colors.json");
38903890

38913891
var limit = as(plan, Limit.class);
38923892
var filter = as(limit.child(), Filter.class);
38933893

3894-
// Navigate to the EMBED_TEXT function within the KNN function
3895-
filter.condition().forEachDown(EmbedText.class, embedText -> {
3894+
// Navigate to the TEXT_EMBEDDING function within the KNN function
3895+
filter.condition().forEachDown(TextEmbedding.class, embedText -> {
38963896
assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
38973897
assertThat(embedText.inputText(), equalTo(string("blue")));
38983898
});
@@ -3901,53 +3901,53 @@ public void testResolveEmbedTextInNestedExpression() {
39013901
public void testResolveEmbedTextDataType() {
39023902
LogicalPlan plan = analyze("""
39033903
FROM books METADATA _score
3904-
| EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id")
3904+
| EVAL embedding = TEXT_EMBEDDING("description", "text-embedding-inference-id")
39053905
""", "mapping-books.json");
39063906

39073907
var limit = as(plan, Limit.class);
39083908
var eval = as(limit.child(), Eval.class);
39093909
var embedTextAlias = as(eval.fields().get(0), Alias.class);
3910-
var embedText = as(embedTextAlias.child(), EmbedText.class);
3910+
var embedText = as(embedTextAlias.child(), TextEmbedding.class);
39113911

39123912
assertThat(embedText.dataType(), equalTo(DataType.DENSE_VECTOR));
39133913
}
39143914

39153915
public void testResolveEmbedTextInvalidParameters() {
39163916
assertError(
3917-
"FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description, \"text-embedding-inference-id\")",
3917+
"FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(description, \"text-embedding-inference-id\")",
39183918
"mapping-books.json",
39193919
new QueryParams(),
3920-
"first argument of [EMBED_TEXT(description, \"text-embedding-inference-id\")] must be a constant, received [description]"
3920+
"first argument of [TEXT_EMBEDDING(description, \"text-embedding-inference-id\")] must be a constant, received [description]"
39213921
);
39223922

39233923
assertError(
3924-
"FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description)",
3924+
"FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(description)",
39253925
"mapping-books.json",
39263926
new QueryParams(),
3927-
"error building [embed_text]: function [embed_text] expects exactly two arguments, it received 1",
3927+
"error building [text_embedding]: function [text_embedding] expects exactly two arguments, it received 1",
39283928
ParsingException.class
39293929
);
39303930
}
39313931

39323932
public void testResolveEmbedTextWithPositionalQueryParams() {
39333933
LogicalPlan plan = analyze(
3934-
"FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?, ?)",
3934+
"FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(?, ?)",
39353935
"mapping-books.json",
39363936
new QueryParams(List.of(paramAsConstant(null, "description"), paramAsConstant(null, "text-embedding-inference-id")))
39373937
);
39383938

39393939
var limit = as(plan, Limit.class);
39403940
var eval = as(limit.child(), Eval.class);
39413941
var embedTextAlias = as(eval.fields().get(0), Alias.class);
3942-
var embedText = as(embedTextAlias.child(), EmbedText.class);
3942+
var embedText = as(embedTextAlias.child(), TextEmbedding.class);
39433943

39443944
assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
39453945
assertThat(embedText.inputText(), equalTo(string("description")));
39463946
}
39473947

39483948
public void testResolveEmbedTextWithNamedQueryParams() {
39493949
LogicalPlan plan = analyze(
3950-
"FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?inputText, ?inferenceId)",
3950+
"FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(?inputText, ?inferenceId)",
39513951
"mapping-books.json",
39523952
new QueryParams(
39533953
List.of(paramAsConstant("inputText", "description"), paramAsConstant("inferenceId", "text-embedding-inference-id"))
@@ -3957,7 +3957,7 @@ public void testResolveEmbedTextWithNamedQueryParams() {
39573957
var limit = as(plan, Limit.class);
39583958
var eval = as(limit.child(), Eval.class);
39593959
var embedTextAlias = as(eval.fields().get(0), Alias.class);
3960-
var embedText = as(embedTextAlias.child(), EmbedText.class);
3960+
var embedText = as(embedTextAlias.child(), TextEmbedding.class);
39613961

39623962
assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id")));
39633963
assertThat(embedText.inputText(), equalTo(string("description")));

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import org.elasticsearch.inference.InferenceServiceResults;
3737
import org.elasticsearch.test.client.NoOpClient;
3838
import org.elasticsearch.threadpool.FixedExecutorBuilder;
39-
import org.elasticsearch.threadpool.TestThreadPool;
4039
import org.elasticsearch.threadpool.ThreadPool;
4140
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
4241
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
@@ -56,8 +55,7 @@ public abstract class InferenceOperatorTestCase<InferenceResultsType extends Inf
5655

5756
@Before
5857
public void setThreadPool() {
59-
threadPool = new TestThreadPool(
60-
getTestClass().getSimpleName(),
58+
threadPool = createThreadPool(
6159
new FixedExecutorBuilder(
6260
Settings.EMPTY,
6361
EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME,

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.threadpool.FixedExecutorBuilder;
2323
import org.elasticsearch.threadpool.TestThreadPool;
2424
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
25+
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
2526
import org.elasticsearch.xpack.esql.parser.EsqlParser;
2627
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
2728
import org.junit.After;
@@ -205,7 +206,7 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.
205206
}
206207

207208
private InferenceResolver inferenceResolver() {
208-
return new InferenceResolver(mockClient());
209+
return new InferenceResolver(mockClient(), new EsqlFunctionRegistry());
209210
}
210211

211212
private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) {

0 commit comments

Comments
 (0)