Skip to content

Commit 59087fb

Browse files
committed
Add text embedding function definition.
1 parent 5e74ecd commit 59087fb

File tree

12 files changed

+461
-1
lines changed

12 files changed

+461
-1
lines changed

docs/reference/query-languages/esql/images/functions/text_embedding.svg

Lines changed: 1 addition & 0 deletions
Loading

docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json

Lines changed: 9 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,6 +1303,11 @@ public enum Cap {
13031303
*/
13041304
KNN_FUNCTION_V5(Build.current().isSnapshot()),
13051305

1306+
/**
1307+
* Support for the {@code TEXT_EMBEDDING} function for generating dense vector embeddings.
1308+
*/
1309+
TEXT_EMBEDDING_FUNCTION(Build.current().isSnapshot()),
1310+
13061311
/**
13071312
* Support for the LIKE operator with a list of wildcards.
13081313
*/

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
1313
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
1414
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables;
15+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceWritables;
1516
import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables;
1617
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64;
1718
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
@@ -118,6 +119,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
118119
entries.addAll(fullText());
119120
entries.addAll(unaryScalars());
120121
entries.addAll(vector());
122+
entries.addAll(inference());
121123
return entries;
122124
}
123125

@@ -260,4 +262,8 @@ private static List<NamedWriteableRegistry.Entry> fullText() {
260262
private static List<NamedWriteableRegistry.Entry> vector() {
261263
return VectorWritables.getNamedWritables();
262264
}
265+
266+
private static List<NamedWriteableRegistry.Entry> inference() {
267+
return InferenceWritables.getNamedWritables();
268+
}
263269
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@
193193
import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
194194
import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm;
195195
import org.elasticsearch.xpack.esql.expression.function.vector.Magnitude;
196+
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
196197
import org.elasticsearch.xpack.esql.parser.ParsingException;
197198
import org.elasticsearch.xpack.esql.session.Configuration;
198199

@@ -519,7 +520,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
519520
def(Magnitude.class, Magnitude::new, "v_magnitude"),
520521
def(Hamming.class, Hamming::new, "v_hamming"),
521522
def(UrlEncode.class, UrlEncode::new, "url_encode"),
522-
def(UrlDecode.class, UrlDecode::new, "url_decode") } };
523+
def(UrlDecode.class, UrlDecode::new, "url_decode"),
524+
def(TextEmbedding.class, bi(TextEmbedding::new), "text_embedding") } };
523525
}
524526

525527
public EsqlFunctionRegistry snapshotRegistry() {
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.inference;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.core.expression.function.Function;
13+
import org.elasticsearch.xpack.esql.core.tree.Source;
14+
15+
import java.util.List;
16+
17+
/**
18+
* Base class for ESQL functions that use inference endpoints (e.g., TEXT_EMBEDDING).
19+
*/
20+
public abstract class InferenceFunction<PlanType extends InferenceFunction<PlanType>> extends Function {
21+
22+
public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id";
23+
24+
protected InferenceFunction(Source source, List<Expression> children) {
25+
super(source, children);
26+
}
27+
28+
/** The inference endpoint identifier expression. */
29+
public abstract Expression inferenceId();
30+
31+
/** The task type required by this function (e.g., TEXT_EMBEDDING). */
32+
public abstract TaskType taskType();
33+
34+
/** Returns a copy with inference resolution error for display to user. */
35+
public abstract PlanType withInferenceResolutionError(String inferenceId, String error);
36+
37+
/** True if this function contains nested inference function calls. */
38+
public boolean hasNestedInferenceFunction() {
39+
return anyMatch(e -> e instanceof InferenceFunction && e != this);
40+
}
41+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.inference;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
12+
13+
import java.util.ArrayList;
14+
import java.util.Collections;
15+
import java.util.List;
16+
17+
/**
18+
* Defines the named writables for inference functions in ESQL.
19+
*/
20+
public final class InferenceWritables {
21+
22+
private InferenceWritables() {
23+
// Utility class
24+
throw new UnsupportedOperationException();
25+
}
26+
27+
public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
28+
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
29+
30+
if (EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()) {
31+
entries.add(TextEmbedding.ENTRY);
32+
}
33+
34+
return Collections.unmodifiableList(entries);
35+
}
36+
}
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.inference;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xpack.esql.core.expression.Expression;
15+
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
16+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
17+
import org.elasticsearch.xpack.esql.core.tree.Source;
18+
import org.elasticsearch.xpack.esql.core.type.DataType;
19+
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
20+
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
21+
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
22+
import org.elasticsearch.xpack.esql.expression.function.Param;
23+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
24+
25+
import java.io.IOException;
26+
import java.util.List;
27+
import java.util.Objects;
28+
29+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
30+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
31+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
32+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
33+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;
34+
35+
/**
36+
* TEXT_EMBEDDING function converts text to dense vector embeddings using an inference endpoint.
37+
*/
38+
public class TextEmbedding extends InferenceFunction<TextEmbedding> {
39+
40+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
41+
Expression.class,
42+
"TextEmbedding",
43+
TextEmbedding::new
44+
);
45+
46+
private final Expression inferenceId;
47+
private final Expression inputText;
48+
49+
@FunctionInfo(
50+
returnType = "dense_vector",
51+
description = "Generates dense vector embeddings for text using a specified inference endpoint.",
52+
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) },
53+
preview = true
54+
)
55+
public TextEmbedding(
56+
Source source,
57+
@Param(name = "text", type = { "keyword", "text" }, description = "Text to embed") Expression inputText,
58+
@Param(
59+
name = InferenceFunction.INFERENCE_ID_PARAMETER_NAME,
60+
type = { "keyword", "text" },
61+
description = "Identifier of the inference endpoint"
62+
) Expression inferenceId
63+
) {
64+
super(source, List.of(inputText, inferenceId));
65+
this.inferenceId = inferenceId;
66+
this.inputText = inputText;
67+
}
68+
69+
private TextEmbedding(StreamInput in) throws IOException {
70+
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
71+
}
72+
73+
@Override
74+
public void writeTo(StreamOutput out) throws IOException {
75+
source().writeTo(out);
76+
out.writeNamedWriteable(inputText);
77+
out.writeNamedWriteable(inferenceId);
78+
}
79+
80+
@Override
81+
public String getWriteableName() {
82+
return ENTRY.name;
83+
}
84+
85+
public Expression inputText() {
86+
return inputText;
87+
}
88+
89+
@Override
90+
public Expression inferenceId() {
91+
return inferenceId;
92+
}
93+
94+
@Override
95+
public boolean foldable() {
96+
return inferenceId.foldable() && inputText.foldable();
97+
}
98+
99+
@Override
100+
public DataType dataType() {
101+
return DataType.DENSE_VECTOR;
102+
}
103+
104+
@Override
105+
protected TypeResolution resolveType() {
106+
if (childrenResolved() == false) {
107+
return new TypeResolution("Unresolved children");
108+
}
109+
110+
TypeResolution textResolution = isNotNull(inputText, sourceText(), FIRST).and(isFoldable(inputText, sourceText(), FIRST))
111+
.and(isString(inputText, sourceText(), FIRST));
112+
113+
if (textResolution.unresolved()) {
114+
return textResolution;
115+
}
116+
117+
TypeResolution inferenceIdResolution = isNotNull(inferenceId, sourceText(), SECOND).and(isString(inferenceId, sourceText(), SECOND))
118+
.and(isFoldable(inferenceId, sourceText(), SECOND));
119+
120+
if (inferenceIdResolution.unresolved()) {
121+
return inferenceIdResolution;
122+
}
123+
124+
return TypeResolution.TYPE_RESOLVED;
125+
}
126+
127+
@Override
128+
public TaskType taskType() {
129+
return TaskType.TEXT_EMBEDDING;
130+
}
131+
132+
@Override
133+
public TextEmbedding withInferenceResolutionError(String inferenceId, String error) {
134+
return new TextEmbedding(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
135+
}
136+
137+
@Override
138+
public Expression replaceChildren(List<Expression> newChildren) {
139+
return new TextEmbedding(source(), newChildren.get(0), newChildren.get(1));
140+
}
141+
142+
@Override
143+
protected NodeInfo<? extends Expression> info() {
144+
return NodeInfo.create(this, TextEmbedding::new, inputText, inferenceId);
145+
}
146+
147+
@Override
148+
public String toString() {
149+
return "TEXT_EMBEDDING(" + inputText + ", " + inferenceId + ")";
150+
}
151+
152+
@Override
153+
public boolean equals(Object o) {
154+
if (o == null || getClass() != o.getClass()) return false;
155+
if (super.equals(o) == false) return false;
156+
TextEmbedding textEmbedding = (TextEmbedding) o;
157+
return Objects.equals(inferenceId, textEmbedding.inferenceId) && Objects.equals(inputText, textEmbedding.inputText);
158+
}
159+
160+
@Override
161+
public int hashCode() {
162+
return Objects.hash(super.hashCode(), inferenceId, inputText);
163+
}
164+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.inference;
9+
10+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
13+
import org.elasticsearch.xpack.esql.core.tree.Source;
14+
import org.elasticsearch.xpack.esql.core.type.DataType;
15+
import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
16+
import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase;
17+
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
18+
import org.hamcrest.Matcher;
19+
import org.junit.Before;
20+
21+
import java.util.List;
22+
import java.util.Locale;
23+
import java.util.Set;
24+
25+
import static org.hamcrest.Matchers.equalTo;
26+
27+
/** Tests error conditions and type validation for TEXT_EMBEDDING function. */
28+
public class TextEmbeddingErrorTests extends ErrorsForCasesWithoutExamplesTestCase {
29+
30+
@Before
31+
public void checkCapability() {
32+
assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
33+
}
34+
35+
@Override
36+
protected List<TestCaseSupplier> cases() {
37+
return paramsToSuppliers(TextEmbeddingTests.parameters());
38+
}
39+
40+
@Override
41+
protected Expression build(Source source, List<Expression> args) {
42+
return new TextEmbedding(source, args.get(0), args.get(1));
43+
}
44+
45+
@Override
46+
protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
47+
return equalTo(typeErrorMessage(true, validPerPosition, signature, (v, p) -> "string"));
48+
}
49+
50+
protected static String typeErrorMessage(
51+
boolean includeOrdinal,
52+
List<Set<DataType>> validPerPosition,
53+
List<DataType> signature,
54+
AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier
55+
) {
56+
for (int i = 0; i < signature.size(); i++) {
57+
if (signature.get(i) == DataType.NULL) {
58+
String ordinal = includeOrdinal ? TypeResolutions.ParamOrdinal.fromIndex(i).name().toLowerCase(Locale.ROOT) + " " : "";
59+
return ordinal + "argument of [" + sourceForSignature(signature) + "] cannot be null, received []";
60+
}
61+
62+
if (validPerPosition.get(i).contains(signature.get(i)) == false) {
63+
break;
64+
}
65+
}
66+
67+
return ErrorsForCasesWithoutExamplesTestCase.typeErrorMessage(
68+
includeOrdinal,
69+
validPerPosition,
70+
signature,
71+
positionalErrorMessageSupplier
72+
);
73+
}
74+
}

0 commit comments

Comments
 (0)