Skip to content

Commit 847c998

Browse files
committed
InferenceResolution for text embedding function.
1 parent 8f4c409 commit 847c998

File tree

5 files changed

+79
-10
lines changed

5 files changed

+79
-10
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
5757
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
5858
import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket;
59+
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
5960
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
6061
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
6162
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
@@ -193,7 +194,6 @@
193194
import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
194195
import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm;
195196
import org.elasticsearch.xpack.esql.expression.function.vector.Magnitude;
196-
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
197197
import org.elasticsearch.xpack.esql.parser.ParsingException;
198198
import org.elasticsearch.xpack.esql.session.Configuration;
199199

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
1616
import org.elasticsearch.xpack.esql.core.expression.Expression;
1717
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
18+
import org.elasticsearch.xpack.esql.core.type.DataType;
19+
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
20+
import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition;
21+
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
22+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
1823
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
1924
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
2025

@@ -29,14 +34,16 @@
2934
public class InferenceResolver {
3035

3136
private final Client client;
37+
private final EsqlFunctionRegistry functionRegistry;
3238

3339
/**
3440
* Constructs a new {@code InferenceResolver}.
3541
*
3642
* @param client The Elasticsearch client for executing inference deployment lookups
3743
*/
38-
public InferenceResolver(Client client) {
44+
public InferenceResolver(Client client, EsqlFunctionRegistry functionRegistry) {
3945
this.client = client;
46+
this.functionRegistry = functionRegistry;
4047
}
4148

4249
/**
@@ -71,6 +78,7 @@ public void resolveInferenceIds(LogicalPlan plan, ActionListener<InferenceResolu
7178
*/
7279
void collectInferenceIds(LogicalPlan plan, Consumer<String> c) {
7380
// First pass: IDs carried by InferencePlan nodes of the logical plan.
collectInferenceIdsFromInferencePlans(plan, c);
81+
// Second pass: IDs passed as arguments to inference functions that are still unresolved in the plan's expressions.
collectInferenceIdsFromInferenceFunctions(plan, c);
7482
}
7583

7684
/**
@@ -130,6 +138,28 @@ private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer<St
130138
plan.forEachUp(InferencePlan.class, inferencePlan -> c.accept(inferenceId(inferencePlan)));
131139
}
132140

141+
/**
142+
* Collects inference IDs from function expressions within the logical plan.
143+
*
144+
* @param plan The logical plan to scan for function expressions
145+
* @param c Consumer function to receive each discovered inference ID
146+
*/
147+
private void collectInferenceIdsFromInferenceFunctions(LogicalPlan plan, Consumer<String> c) {
148+
EsqlFunctionRegistry snapshotRegistry = functionRegistry.snapshotRegistry();
149+
plan.forEachExpressionUp(UnresolvedFunction.class, f -> {
150+
String functionName = snapshotRegistry.resolveAlias(f.name());
151+
if (snapshotRegistry.functionExists(functionName)) {
152+
FunctionDefinition def = snapshotRegistry.resolveFunction(functionName);
153+
if (InferenceFunction.class.isAssignableFrom(def.clazz())) {
154+
String inferenceId = inferenceId(f, def);
155+
if (inferenceId != null) {
156+
c.accept(inferenceId);
157+
}
158+
}
159+
}
160+
});
161+
}
162+
133163
/**
134164
* Extracts the inference ID from an InferencePlan object.
135165
*
@@ -144,6 +174,23 @@ private static String inferenceId(Expression e) {
144174
return BytesRefs.toString(e.fold(FoldContext.small()));
145175
}
146176

177+
public String inferenceId(UnresolvedFunction f, FunctionDefinition def) {
178+
EsqlFunctionRegistry.FunctionDescription functionDescription = EsqlFunctionRegistry.description(def);
179+
180+
for (int i = 0; i < functionDescription.args().size(); i++) {
181+
EsqlFunctionRegistry.ArgSignature arg = functionDescription.args().get(i);
182+
183+
if (arg.name().equals(InferenceFunction.INFERENCE_ID_PARAMETER_NAME)) {
184+
Expression argValue = f.arguments().get(i);
185+
if (argValue != null && argValue.foldable() && DataType.isString(argValue.dataType())) {
186+
return inferenceId(argValue);
187+
}
188+
}
189+
}
190+
191+
return null;
192+
}
193+
147194
/**
 * Creates a {@link Factory} holding the given client; the factory passes that client to every
 * {@code InferenceResolver} it creates.
 *
 * @param client The Elasticsearch client handed to the created resolvers
 * @return a new factory instance
 */
public static Factory factory(Client client) {
148195
return new Factory(client);
149196
}
@@ -155,8 +202,8 @@ private Factory(Client client) {
155202
this.client = client;
156203
}
157204

158-
public InferenceResolver create() {
159-
return new InferenceResolver(client);
205+
public InferenceResolver create(EsqlFunctionRegistry functionRegistry) {
206+
return new InferenceResolver(client, functionRegistry);
160207
}
161208
}
162209
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.esql.inference;
99

1010
import org.elasticsearch.client.internal.Client;
11+
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
1112
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
1213
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
1314

@@ -33,10 +34,12 @@ private InferenceService(InferenceResolver.Factory inferenceResolverFactory, Bul
3334
/**
3435
* Creates an inference resolver for resolving inference IDs in logical plans.
3536
*
37+
* @param functionRegistry the function registry to resolve functions
38+
*
3639
* @return a new inference resolver instance
3740
*/
38-
public InferenceResolver inferenceResolver() {
39-
return inferenceResolverFactory.create();
41+
public InferenceResolver inferenceResolver(EsqlFunctionRegistry functionRegistry) {
42+
return inferenceResolverFactory.create(functionRegistry);
4043
}
4144

4245
public BulkInferenceRunner bulkInferenceRunner() {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ private void analyzeWithRetry(
731731
}
732732

733733
private void resolveInferences(LogicalPlan plan, PreAnalysisResult preAnalysisResult, ActionListener<PreAnalysisResult> l) {
734-
inferenceService.inferenceResolver().resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution));
734+
inferenceService.inferenceResolver(functionRegistry).resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution));
735735
}
736736

737737
private PhysicalPlan logicalPlanToPhysicalPlan(LogicalPlan optimizedPlan, EsqlQueryRequest request) {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.threadpool.FixedExecutorBuilder;
2323
import org.elasticsearch.threadpool.TestThreadPool;
2424
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
25+
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
2526
import org.elasticsearch.xpack.esql.parser.EsqlParser;
2627
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
2728
import org.junit.After;
@@ -44,6 +45,7 @@
4445

4546
public class InferenceResolverTests extends ESTestCase {
4647
private TestThreadPool threadPool;
48+
private EsqlFunctionRegistry functionRegistry;
4749

4850
@Before
4951
public void setThreadPool() {
@@ -60,6 +62,11 @@ public void setThreadPool() {
6062
);
6163
}
6264

65+
// Test fixture: assigns a new EsqlFunctionRegistry to the field that inferenceResolver() passes to the
// InferenceResolver under test.
@Before
66+
public void setUpFunctionRegistry() {
67+
functionRegistry = new EsqlFunctionRegistry();
68+
}
69+
6370
@After
6471
public void shutdownThreadPool() {
6572
terminate(threadPool);
@@ -78,6 +85,18 @@ public void testCollectInferenceIds() {
7885
List.of("completion-inference-id")
7986
);
8087

88+
// Test inference ID collection from an inference function
89+
assertCollectInferenceIds(
90+
"FROM books METADATA _score | EVAL embedding = TEXT_EMBEDDING(\"description\", \"text-embedding-inference-id\")",
91+
List.of("text-embedding-inference-id")
92+
);
93+
94+
// Test inference ID collection with nested functions
95+
assertCollectInferenceIds(
96+
"FROM books METADATA _score | EVAL embedding = TEXT_EMBEDDING(TEXT_EMBEDDING(\"nested\", \"nested-id\"), \"outer-id\")",
97+
List.of("nested-id", "outer-id")
98+
);
99+
81100
// Multiple inference plans
82101
assertCollectInferenceIds("""
83102
FROM books METADATA _score
@@ -139,7 +158,7 @@ public void testResolveMultipleInferenceIds() throws Exception {
139158

140159
public void testResolveMissingInferenceIds() throws Exception {
141160
InferenceResolver inferenceResolver = inferenceResolver();
142-
List<String> inferenceIds = List.of("missing-plan");
161+
List<String> inferenceIds = List.of("missing-inference-id");
143162

144163
SetOnce<InferenceResolution> inferenceResolutionSetOnce = new SetOnce<>();
145164

@@ -153,7 +172,7 @@ public void testResolveMissingInferenceIds() throws Exception {
153172

154173
assertThat(inferenceResolution.resolvedInferences(), empty());
155174
assertThat(inferenceResolution.hasError(), equalTo(true));
156-
assertThat(inferenceResolution.getError("missing-plan"), equalTo("inference endpoint not found"));
175+
assertThat(inferenceResolution.getError("missing-inference-id"), equalTo("inference endpoint not found"));
157176
});
158177
}
159178

@@ -205,7 +224,7 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction.
205224
}
206225

207226
private InferenceResolver inferenceResolver() {
208-
return new InferenceResolver(mockClient());
227+
return new InferenceResolver(mockClient(), functionRegistry);
209228
}
210229

211230
private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) {

0 commit comments

Comments
 (0)