Skip to content

Commit 59c3601

Browse files
authored
[ES|QL] TEXT_EMBEDDING function definition (#135059)
1 parent bce397f commit 59c3601

File tree

19 files changed

+571
-28
lines changed

19 files changed

+571
-28
lines changed

docs/reference/query-languages/esql/_snippets/functions/parameters/text_embedding.md

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/images/functions/text_embedding.svg

Lines changed: 1 addition & 0 deletions
Loading

docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
placeholder
2+
required_capability: text_embedding_function
3+
required_capability: not_existing_capability
4+
5+
// tag::embedding-eval[]
6+
ROW input="Who is Victor Hugo?"
7+
| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference")
8+
;
9+
// end::embedding-eval[]
10+
11+
12+
input:keyword | embedding:dense_vector
13+
Who is Victor Hugo? | [56.0, 50.0, 48.0]
14+
;
15+

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,6 +1319,11 @@ public enum Cap {
13191319
*/
13201320
KNN_FUNCTION_V5(Build.current().isSnapshot()),
13211321

1322+
/**
1323+
* Support for the {@code TEXT_EMBEDDING} function for generating dense vector embeddings.
1324+
*/
1325+
TEXT_EMBEDDING_FUNCTION(Build.current().isSnapshot()),
1326+
13221327
/**
13231328
* Support for the LIKE operator with a list of wildcards.
13241329
*/

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode;
7474
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
7575
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
76+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
7677
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
7778
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
7879
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
@@ -1419,7 +1420,8 @@ private static class ResolveInference extends ParameterizedRule<LogicalPlan, Log
14191420

14201421
@Override
14211422
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
1422-
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));
1423+
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
1424+
.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
14231425
}
14241426

14251427
private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {
@@ -1448,6 +1450,36 @@ private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext
14481450

14491451
return plan;
14501452
}
1453+
1454+
private InferenceFunction<?> resolveInferenceFunction(InferenceFunction<?> inferenceFunction, AnalyzerContext context) {
1455+
if (inferenceFunction.inferenceId().resolved()
1456+
&& inferenceFunction.inferenceId().foldable()
1457+
&& DataType.isString(inferenceFunction.inferenceId().dataType())) {
1458+
1459+
String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small()));
1460+
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
1461+
1462+
if (resolvedInference == null) {
1463+
String error = context.inferenceResolution().getError(inferenceId);
1464+
return inferenceFunction.withInferenceResolutionError(inferenceId, error);
1465+
}
1466+
1467+
if (resolvedInference.taskType() != inferenceFunction.taskType()) {
1468+
String error = "cannot use inference endpoint ["
1469+
+ inferenceId
1470+
+ "] with task type ["
1471+
+ resolvedInference.taskType()
1472+
+ "] within a "
1473+
+ context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass())
1474+
+ " function. Only inference endpoints with the task type ["
1475+
+ inferenceFunction.taskType()
1476+
+ "] are supported.";
1477+
return inferenceFunction.withInferenceResolutionError(inferenceId, error);
1478+
}
1479+
}
1480+
1481+
return inferenceFunction;
1482+
}
14511483
}
14521484

14531485
private static class AddImplicitLimit extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
6464
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
6565
import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket;
66+
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
6667
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
6768
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
6869
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
@@ -543,7 +544,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
543544
def(Hamming.class, Hamming::new, "v_hamming"),
544545
def(UrlEncode.class, UrlEncode::new, "url_encode"),
545546
def(UrlEncodeComponent.class, UrlEncodeComponent::new, "url_encode_component"),
546-
def(UrlDecode.class, UrlDecode::new, "url_decode") } };
547+
def(UrlDecode.class, UrlDecode::new, "url_decode"),
548+
def(TextEmbedding.class, bi(TextEmbedding::new), "text_embedding") } };
547549
}
548550

549551
public EsqlFunctionRegistry snapshotRegistry() {
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.inference;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.core.expression.function.Function;
13+
import org.elasticsearch.xpack.esql.core.tree.Source;
14+
15+
import java.util.List;
16+
17+
/**
18+
* Base class for ESQL functions that use inference endpoints (e.g., TEXT_EMBEDDING).
19+
*/
20+
public abstract class InferenceFunction<PlanType extends InferenceFunction<PlanType>> extends Function {
21+
22+
public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id";
23+
24+
protected InferenceFunction(Source source, List<Expression> children) {
25+
super(source, children);
26+
}
27+
28+
/** The inference endpoint identifier expression. */
29+
public abstract Expression inferenceId();
30+
31+
/** The task type required by this function (e.g., TEXT_EMBEDDING). */
32+
public abstract TaskType taskType();
33+
34+
/** Returns a copy with inference resolution error for display to user. */
35+
public abstract PlanType withInferenceResolutionError(String inferenceId, String error);
36+
37+
/** True if this function contains nested inference function calls. */
38+
public boolean hasNestedInferenceFunction() {
39+
return anyMatch(e -> e instanceof InferenceFunction && e != this);
40+
}
41+
}
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.inference;
9+
10+
import org.elasticsearch.common.io.stream.StreamOutput;
11+
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.xpack.esql.core.expression.Expression;
13+
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
14+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
15+
import org.elasticsearch.xpack.esql.core.tree.Source;
16+
import org.elasticsearch.xpack.esql.core.type.DataType;
17+
import org.elasticsearch.xpack.esql.expression.function.Example;
18+
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
19+
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
20+
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
21+
import org.elasticsearch.xpack.esql.expression.function.Param;
22+
23+
import java.io.IOException;
24+
import java.util.List;
25+
import java.util.Objects;
26+
27+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
28+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
29+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
30+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
31+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
32+
33+
/**
34+
* TEXT_EMBEDDING function converts text to dense vector embeddings using an inference endpoint.
35+
*/
36+
public class TextEmbedding extends InferenceFunction<TextEmbedding> {
37+
38+
private final Expression inferenceId;
39+
private final Expression inputText;
40+
41+
@FunctionInfo(
42+
returnType = "dense_vector",
43+
description = "Generates dense vector embeddings for text using a specified inference endpoint.",
44+
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) },
45+
preview = true,
46+
examples = {
47+
@Example(
48+
description = "Generate text embeddings using the 'test_dense_inference' inference endpoint.",
49+
file = "text-embedding",
50+
tag = "embedding-eval"
51+
) }
52+
)
53+
public TextEmbedding(
54+
Source source,
55+
@Param(name = "text", type = { "keyword" }, description = "Text to generate embeddings from") Expression inputText,
56+
@Param(
57+
name = InferenceFunction.INFERENCE_ID_PARAMETER_NAME,
58+
type = { "keyword" },
59+
description = "Identifier of the inference endpoint"
60+
) Expression inferenceId
61+
) {
62+
super(source, List.of(inputText, inferenceId));
63+
this.inferenceId = inferenceId;
64+
this.inputText = inputText;
65+
}
66+
67+
@Override
68+
public void writeTo(StreamOutput out) throws IOException {
69+
throw new UnsupportedOperationException("doesn't escape the node");
70+
}
71+
72+
@Override
73+
public String getWriteableName() {
74+
throw new UnsupportedOperationException("doesn't escape the node");
75+
}
76+
77+
public Expression inputText() {
78+
return inputText;
79+
}
80+
81+
@Override
82+
public Expression inferenceId() {
83+
return inferenceId;
84+
}
85+
86+
@Override
87+
public boolean foldable() {
88+
return inferenceId.foldable() && inputText.foldable();
89+
}
90+
91+
@Override
92+
public DataType dataType() {
93+
return DataType.DENSE_VECTOR;
94+
}
95+
96+
@Override
97+
protected TypeResolution resolveType() {
98+
if (childrenResolved() == false) {
99+
return new TypeResolution("Unresolved children");
100+
}
101+
102+
TypeResolution textResolution = isNotNull(inputText, sourceText(), FIRST).and(isFoldable(inputText, sourceText(), FIRST))
103+
.and(isType(inputText, DataType.KEYWORD::equals, sourceText(), FIRST, "string"));
104+
105+
if (textResolution.unresolved()) {
106+
return textResolution;
107+
}
108+
109+
TypeResolution inferenceIdResolution = isNotNull(inferenceId, sourceText(), SECOND).and(
110+
isType(inferenceId, DataType.KEYWORD::equals, sourceText(), SECOND, "string")
111+
).and(isFoldable(inferenceId, sourceText(), SECOND));
112+
113+
if (inferenceIdResolution.unresolved()) {
114+
return inferenceIdResolution;
115+
}
116+
117+
return TypeResolution.TYPE_RESOLVED;
118+
}
119+
120+
@Override
121+
public TaskType taskType() {
122+
return TaskType.TEXT_EMBEDDING;
123+
}
124+
125+
@Override
126+
public TextEmbedding withInferenceResolutionError(String inferenceId, String error) {
127+
return new TextEmbedding(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
128+
}
129+
130+
@Override
131+
public Expression replaceChildren(List<Expression> newChildren) {
132+
return new TextEmbedding(source(), newChildren.get(0), newChildren.get(1));
133+
}
134+
135+
@Override
136+
protected NodeInfo<? extends Expression> info() {
137+
return NodeInfo.create(this, TextEmbedding::new, inputText, inferenceId);
138+
}
139+
140+
@Override
141+
public String toString() {
142+
return "TEXT_EMBEDDING(" + inputText + ", " + inferenceId + ")";
143+
}
144+
145+
@Override
146+
public boolean equals(Object o) {
147+
if (o == null || getClass() != o.getClass()) return false;
148+
if (super.equals(o) == false) return false;
149+
TextEmbedding textEmbedding = (TextEmbedding) o;
150+
return Objects.equals(inferenceId, textEmbedding.inferenceId) && Objects.equals(inputText, textEmbedding.inputText);
151+
}
152+
153+
@Override
154+
public int hashCode() {
155+
return Objects.hash(super.hashCode(), inferenceId, inputText);
156+
}
157+
}

0 commit comments

Comments
 (0)