diff --git a/docs/changelog/134573.yaml b/docs/changelog/134573.yaml
new file mode 100644
index 0000000000000..64e79561678a2
--- /dev/null
+++ b/docs/changelog/134573.yaml
@@ -0,0 +1,5 @@
+pr: 134573
+summary: Esql text embedding function
+area: ES|QL
+type: feature
+issues: []
diff --git a/docs/reference/query-languages/esql/images/functions/text_embedding.svg b/docs/reference/query-languages/esql/images/functions/text_embedding.svg
new file mode 100644
index 0000000000000..dab58c5e5bda0
--- /dev/null
+++ b/docs/reference/query-languages/esql/images/functions/text_embedding.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json
new file mode 100644
index 0000000000000..9e4967b92c367
--- /dev/null
+++ b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json
@@ -0,0 +1,9 @@
+{
+ "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.",
+ "type" : "scalar",
+ "name" : "text_embedding",
+ "description" : "Generates dense vector embeddings for text using a specified inference endpoint.",
+ "signatures" : [ ],
+ "preview" : true,
+ "snapshot_only" : true
+}
diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md b/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md
new file mode 100644
index 0000000000000..bb3e74fc116cd
--- /dev/null
+++ b/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md
@@ -0,0 +1,4 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+### TEXT EMBEDDING
+Generates dense vector embeddings for text using a specified inference endpoint.
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java
index 54f858cb20ae0..d8b3fb2b9c544 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java
@@ -155,7 +155,7 @@ public String toString() {
return Strings.toString(this);
}
- float[] toFloatArray() {
+ public float[] toFloatArray() {
float[] floatArray = new float[values.length];
for (int i = 0; i < values.length; i++) {
floatArray[i] = ((Byte) values[i]).floatValue();
diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
index 33ff75519348c..15b79183cac4a 100644
--- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
+++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
@@ -77,6 +77,7 @@
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SOURCE_FIELD_MAPPING;
+import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION;
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.assertNotPartial;
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.hasCapabilities;
@@ -205,7 +206,8 @@ protected boolean requiresInferenceEndpoint() {
SEMANTIC_TEXT_FIELD_CAPS.capabilityName(),
RERANK.capabilityName(),
COMPLETION.capabilityName(),
- KNN_FUNCTION_V5.capabilityName()
+ KNN_FUNCTION_V5.capabilityName(),
+ TEXT_EMBEDDING_FUNCTION.capabilityName()
).anyMatch(testCase.requiredCapabilities::contains);
}
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec
new file mode 100644
index 0000000000000..4be9bacab399d
--- /dev/null
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec
@@ -0,0 +1,66 @@
+text_embedding using a row source operator
+required_capability: text_embedding_function
+required_capability: dense_vector_field_type
+
+ROW input="Who is Victor Hugo?"
+| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference")
+;
+
+input:keyword | embedding:dense_vector
+Who is Victor Hugo? | [56.0, 50.0, 48.0]
+;
+
+
+text_embedding using a row source operator with query build using CONCAT
+required_capability: text_embedding_function
+required_capability: dense_vector_field_type
+
+ROW input="Who is Victor Hugo?"
+| EVAL embedding = TEXT_EMBEDDING(CONCAT("Who is ", "Victor Hugo?"), "test_dense_inference")
+;
+
+input:keyword | embedding:dense_vector
+Who is Victor Hugo? | [56.0, 50.0, 48.0]
+;
+
+
+text_embedding with knn on semantic_text_dense_field
+required_capability: text_embedding_function
+required_capability: dense_vector_field_type
+required_capability: knn_function_v5
+required_capability: semantic_text_field_caps
+
+FROM semantic_text METADATA _score
+| EVAL query_embedding = TEXT_EMBEDDING("be excellent to each other", "test_dense_inference")
+| WHERE KNN(semantic_text_dense_field, query_embedding)
+| KEEP semantic_text_field, query_embedding, _score
+| EVAL _score = ROUND(_score, 4)
+| SORT _score DESC
+| LIMIT 10
+;
+
+semantic_text_field:text | query_embedding:dense_vector | _score:double
+be excellent to each other | [45.0, 55.0, 54.0] | 1.0
+live long and prosper | [45.0, 55.0, 54.0] | 0.0295
+all we have to decide is what to do with the time that is given to us | [45.0, 55.0, 54.0] | 0.0214
+;
+
+text_embedding with knn (inline) on semantic_text_dense_field
+required_capability: text_embedding_function
+required_capability: dense_vector_field_type
+required_capability: knn_function_v5
+required_capability: semantic_text_field_caps
+
+FROM semantic_text METADATA _score
+| WHERE KNN(semantic_text_dense_field, TEXT_EMBEDDING("be excellent to each other", "test_dense_inference"))
+| KEEP semantic_text_field, _score
+| EVAL _score = ROUND(_score, 4)
+| SORT _score DESC
+| LIMIT 10
+;
+
+semantic_text_field:text | _score:double
+be excellent to each other | 1.0
+live long and prosper | 0.0295
+all we have to decide is what to do with the time that is given to us | 0.0214
+;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index 607ff3faeb6d9..9710c94ef2d9c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -1309,6 +1309,11 @@ public enum Cap {
*/
KNN_FUNCTION_V5(Build.current().isSnapshot()),
+ /**
+ * Support for the {@code TEXT_EMBEDDING} function for generating dense vector embeddings.
+ */
+ TEXT_EMBEDDING_FUNCTION(Build.current().isSnapshot()),
+
/**
* Support for the LIKE operator with a list of wildcards.
*/
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index f7bd49b75b4c5..a02a5565c46d0 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
@@ -73,6 +73,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
+import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
@@ -1414,7 +1415,8 @@ private static class ResolveInference extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {
- return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));
+ return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
+ .transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
}
private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {
@@ -1443,6 +1445,36 @@ private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext
return plan;
}
+
+ private InferenceFunction<?> resolveInferenceFunction(InferenceFunction<?> inferenceFunction, AnalyzerContext context) {
+ if (inferenceFunction.inferenceId().resolved()
+ && inferenceFunction.inferenceId().foldable()
+ && DataType.isString(inferenceFunction.inferenceId().dataType())) {
+
+ String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small()));
+ ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
+
+ if (resolvedInference == null) {
+ String error = context.inferenceResolution().getError(inferenceId);
+ return inferenceFunction.withInferenceResolutionError(inferenceId, error);
+ }
+
+ if (resolvedInference.taskType() != inferenceFunction.taskType()) {
+ String error = "cannot use inference endpoint ["
+ + inferenceId
+ + "] with task type ["
+ + resolvedInference.taskType()
+ + "] within a "
+ + context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass())
+ + " function. Only inference endpoints with the task type ["
+ + inferenceFunction.taskType()
+ + "] are supported.";
+ return inferenceFunction.withInferenceResolutionError(inferenceId, error);
+ }
+ }
+
+ return inferenceFunction;
+ }
}
private static class AddImplicitLimit extends ParameterizedRule {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java
index 2043176f24a29..974b73718ff0b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java
@@ -88,7 +88,7 @@ public void esql(
indexResolver,
enrichPolicyResolver,
preAnalyzer,
- new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext)),
+ new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext, services.inferenceService())),
functionRegistry,
new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)),
mapper,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
index 2e06db66a85e5..07e3c73e666d7 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
@@ -13,6 +13,7 @@
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables;
+import org.elasticsearch.xpack.esql.expression.function.inference.InferenceWritables;
import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
@@ -120,6 +121,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.addAll(fullText());
entries.addAll(unaryScalars());
entries.addAll(vector());
+ entries.addAll(inference());
return entries;
}
@@ -265,4 +267,8 @@ private static List<NamedWriteableRegistry.Entry> fullText() {
private static List<NamedWriteableRegistry.Entry> vector() {
return VectorWritables.getNamedWritables();
}
+
+ private static List<NamedWriteableRegistry.Entry> inference() {
+ return InferenceWritables.getNamedWritables();
+ }
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
index 4f98a523910d5..f7da7380015bb 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -62,6 +62,7 @@
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket;
+import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
@@ -539,7 +540,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(Magnitude.class, Magnitude::new, "v_magnitude"),
def(Hamming.class, Hamming::new, "v_hamming"),
def(UrlEncode.class, UrlEncode::new, "url_encode"),
- def(UrlDecode.class, UrlDecode::new, "url_decode") } };
+ def(UrlDecode.class, UrlDecode::new, "url_decode"),
+ def(TextEmbedding.class, bi(TextEmbedding::new), "text_embedding") } };
}
public EsqlFunctionRegistry snapshotRegistry() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java
new file mode 100644
index 0000000000000..d2d6d9b6e2af7
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.function.Function;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+
+import java.util.List;
+
+/**
+ * Base class for ESQL functions that use inference endpoints (e.g., TEXT_EMBEDDING).
+ */
+public abstract class InferenceFunction<PlanType extends InferenceFunction<PlanType>> extends Function {
+
+ public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id";
+
+ protected InferenceFunction(Source source, List<Expression> children) {
+ super(source, children);
+ }
+
+ /** The inference endpoint identifier expression. */
+ public abstract Expression inferenceId();
+
+ /** The task type required by this function (e.g., TEXT_EMBEDDING). */
+ public abstract TaskType taskType();
+
+ /** Returns a copy with inference resolution error for display to user. */
+ public abstract PlanType withInferenceResolutionError(String inferenceId, String error);
+
+ /** True if this function contains nested inference function calls. */
+ public boolean hasNestedInferenceFunction() {
+ return anyMatch(e -> e instanceof InferenceFunction && e != this);
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceWritables.java
new file mode 100644
index 0000000000000..9809ef0d46b66
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceWritables.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Defines the named writables for inference functions in ESQL.
+ */
+public final class InferenceWritables {
+
+ private InferenceWritables() {
+ // Utility class
+ throw new UnsupportedOperationException();
+ }
+
+ public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
+ List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
+
+ if (EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()) {
+ entries.add(TextEmbedding.ENTRY);
+ }
+
+ return Collections.unmodifiableList(entries);
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java
new file mode 100644
index 0000000000000..a5ef509df1dff
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java
@@ -0,0 +1,164 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;
+
+/**
+ * TEXT_EMBEDDING function converts text to dense vector embeddings using an inference endpoint.
+ */
+public class TextEmbedding extends InferenceFunction<TextEmbedding> {
+
+ public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+ Expression.class,
+ "TextEmbedding",
+ TextEmbedding::new
+ );
+
+ private final Expression inferenceId;
+ private final Expression inputText;
+
+ @FunctionInfo(
+ returnType = "dense_vector",
+ description = "Generates dense vector embeddings for text using a specified inference endpoint.",
+ appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) },
+ preview = true
+ )
+ public TextEmbedding(
+ Source source,
+ @Param(name = "text", type = { "keyword", "text" }, description = "Text to embed") Expression inputText,
+ @Param(
+ name = InferenceFunction.INFERENCE_ID_PARAMETER_NAME,
+ type = { "keyword", "text" },
+ description = "Identifier of the inference endpoint"
+ ) Expression inferenceId
+ ) {
+ super(source, List.of(inputText, inferenceId));
+ this.inferenceId = inferenceId;
+ this.inputText = inputText;
+ }
+
+ private TextEmbedding(StreamInput in) throws IOException {
+ this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ source().writeTo(out);
+ out.writeNamedWriteable(inputText);
+ out.writeNamedWriteable(inferenceId);
+ }
+
+ @Override
+ public String getWriteableName() {
+ return ENTRY.name;
+ }
+
+ public Expression inputText() {
+ return inputText;
+ }
+
+ @Override
+ public Expression inferenceId() {
+ return inferenceId;
+ }
+
+ @Override
+ public boolean foldable() {
+ return inferenceId.foldable() && inputText.foldable();
+ }
+
+ @Override
+ public DataType dataType() {
+ return DataType.DENSE_VECTOR;
+ }
+
+ @Override
+ protected TypeResolution resolveType() {
+ if (childrenResolved() == false) {
+ return new TypeResolution("Unresolved children");
+ }
+
+ TypeResolution textResolution = isNotNull(inputText, sourceText(), FIRST).and(isFoldable(inputText, sourceText(), FIRST))
+ .and(isString(inputText, sourceText(), FIRST));
+
+ if (textResolution.unresolved()) {
+ return textResolution;
+ }
+
+ TypeResolution inferenceIdResolution = isNotNull(inferenceId, sourceText(), SECOND).and(isString(inferenceId, sourceText(), SECOND))
+ .and(isFoldable(inferenceId, sourceText(), SECOND));
+
+ if (inferenceIdResolution.unresolved()) {
+ return inferenceIdResolution;
+ }
+
+ return TypeResolution.TYPE_RESOLVED;
+ }
+
+ @Override
+ public TaskType taskType() {
+ return TaskType.TEXT_EMBEDDING;
+ }
+
+ @Override
+ public TextEmbedding withInferenceResolutionError(String inferenceId, String error) {
+ return new TextEmbedding(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
+ }
+
+ @Override
+ public Expression replaceChildren(List<Expression> newChildren) {
+ return new TextEmbedding(source(), newChildren.get(0), newChildren.get(1));
+ }
+
+ @Override
+ protected NodeInfo<? extends Expression> info() {
+ return NodeInfo.create(this, TextEmbedding::new, inputText, inferenceId);
+ }
+
+ @Override
+ public String toString() {
+ return "TEXT_EMBEDDING(" + inputText + ", " + inferenceId + ")";
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ if (super.equals(o) == false) return false;
+ TextEmbedding textEmbedding = (TextEmbedding) o;
+ return Objects.equals(inferenceId, textEmbedding.inferenceId) && Objects.equals(inputText, textEmbedding.inputText);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), inferenceId, inputText);
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java
new file mode 100644
index 0000000000000..eba1e11cb296f
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java
@@ -0,0 +1,236 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.lucene.BytesRefs;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.indices.breaker.AllCircuitBreakerStats;
+import org.elasticsearch.indices.breaker.CircuitBreakerService;
+import org.elasticsearch.indices.breaker.CircuitBreakerStats;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
+import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
+import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
+import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingOperator;
+
+/**
+ * Evaluator for inference functions that performs constant folding by executing inference operations
+ * at optimization time and replacing them with their computed results.
+ *
+ * This class is responsible for:
+ * <ul>
+ *     <li>Setting up the necessary execution context (DriverContext, CircuitBreaker, etc.)</li>
+ *     <li>Creating and configuring appropriate inference operators for different function types</li>
+ *     <li>Executing inference operations asynchronously</li>
+ *     <li>Converting operator results back to ESQL expressions</li>
+ * </ul>
+ *
+ */
+public class InferenceFunctionEvaluator {
+
+ private static final Factory FACTORY = new Factory();
+
+ public static InferenceFunctionEvaluator.Factory factory() {
+ return FACTORY;
+ }
+
+ private final FoldContext foldContext;
+ private final InferenceOperatorProvider inferenceOperatorProvider;
+
+ /**
+ * Creates a new inference function evaluator with a custom operator provider.
+ * This constructor is primarily used for testing to inject mock operator providers.
+ *
+ * @param foldContext the fold context containing circuit breakers and evaluation settings
+ * @param inferenceOperatorProvider custom provider for creating inference operators
+ */
+ InferenceFunctionEvaluator(FoldContext foldContext, InferenceOperatorProvider inferenceOperatorProvider) {
+ this.foldContext = foldContext;
+ this.inferenceOperatorProvider = inferenceOperatorProvider;
+ }
+
+ /**
+ * Folds an inference function by executing it and replacing it with its computed result.
+ *
+ * This method performs the following steps:
+ * <ol>
+ *     <li>Validates that the function is foldable (has constant parameters)</li>
+ *     <li>Sets up a minimal execution context with appropriate circuit breakers</li>
+ *     <li>Creates and configures the appropriate inference operator</li>
+ *     <li>Executes the inference operation asynchronously</li>
+ *     <li>Converts the result to a {@link Literal} expression</li>
+ * </ol>
+ *
+ * @param f the inference function to fold - must be foldable (have constant parameters)
+ * @param listener the listener to notify when folding completes successfully or fails
+ */
+ public void fold(InferenceFunction<?> f, ActionListener<Expression> listener) {
+ if (f.foldable() == false) {
+ listener.onFailure(new IllegalArgumentException("Inference function must be foldable"));
+ return;
+ }
+
+ // Set up a DriverContext for executing the inference operator.
+ // This follows the same pattern as EvaluatorMapper but in a simplified context
+ // suitable for constant folding during optimization.
+ CircuitBreaker breaker = foldContext.circuitBreakerView(f.source());
+ BigArrays bigArrays = new BigArrays(null, new CircuitBreakerService() {
+ @Override
+ public CircuitBreaker getBreaker(String name) {
+ if (name.equals(CircuitBreaker.REQUEST) == false) {
+ throw new UnsupportedOperationException("Only REQUEST circuit breaker is supported");
+ }
+ return breaker;
+ }
+
+ @Override
+ public AllCircuitBreakerStats stats() {
+ throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context");
+ }
+
+ @Override
+ public CircuitBreakerStats stats(String name) {
+ throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context");
+ }
+ }, CircuitBreaker.REQUEST).withCircuitBreaking();
+
+ DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
+
+ // Create the inference operator for the specific function type using the provider
+ try {
+ Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext);
+
+ try {
+ // Feed the operator with a single page to trigger execution
+ // The actual input data is already bound in the operator through expression evaluators
+ inferenceOperator.addInput(new Page(1));
+
+ // Execute the inference operation asynchronously and handle the result
+ // The operator will perform the actual inference call and return a page with the result
+ driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> {
+ Page output = inferenceOperator.getOutput();
+
+ try {
+ if (output == null) {
+ l.onFailure(new IllegalStateException("Expected output page from inference operator"));
+ return;
+ }
+
+ if (output.getPositionCount() != 1 || output.getBlockCount() != 1) {
+ l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator"));
+ return;
+ }
+
+ // Convert the operator result back to an ESQL expression (Literal)
+ l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0)));
+ } finally {
+ Releasables.close(inferenceOperator);
+ if (output != null) {
+ output.releaseBlocks();
+ }
+ }
+ }));
+ } catch (Exception e) {
+ Releasables.close(inferenceOperator);
+ listener.onFailure(e);
+ }
+ } catch (Exception e) {
+ listener.onFailure(e);
+ } finally {
+ driverContext.finish();
+ }
+ }
+
+ /**
+ * Functional interface for providing inference operators.
+ *
+ * This interface abstracts the creation of inference operators for different function types,
+ * allowing for easier testing and potential future extensibility. The provider is responsible
+ * for creating an appropriate operator instance given an inference function and driver context.
+ */
+ interface InferenceOperatorProvider {
+ /**
+ * Creates an inference operator for the given function and driver context.
+ *
+ * @param f the inference function to create an operator for
+ * @param driverContext the driver context to use for operator creation
+ * @return an operator instance configured for the given function
+ */
+ Operator getOperator(InferenceFunction<?> f, DriverContext driverContext);
+ }
+
+ /**
+ * Factory for creating {@link InferenceFunctionEvaluator} instances.
+ */
+ public static class Factory {
+ private Factory() {}
+
+ /**
+ * Creates a new inference function evaluator.
+ *
+ * @param foldContext the fold context
+ * @param inferenceService the inference service
+ * @return a new instance of {@link InferenceFunctionEvaluator}
+ */
+ public InferenceFunctionEvaluator create(FoldContext foldContext, InferenceService inferenceService) {
+ return new InferenceFunctionEvaluator(foldContext, createInferenceOperatorProvider(foldContext, inferenceService));
+ }
+
+ /**
+ * Creates an {@link InferenceOperatorProvider} that can produce operators for all supported inference functions.
+ */
+ private InferenceOperatorProvider createInferenceOperatorProvider(FoldContext foldContext, InferenceService inferenceService) {
+ return (inferenceFunction, driverContext) -> {
+ Operator.OperatorFactory factory = switch (inferenceFunction) {
+ case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory(
+ inferenceService,
+ inferenceId(inferenceFunction, foldContext),
+ expressionEvaluatorFactory(textEmbedding.inputText(), foldContext)
+ );
+ default -> throw new IllegalArgumentException("Unknown inference function: " + inferenceFunction.getClass().getName());
+ };
+
+ return factory.get(driverContext);
+ };
+ }
+
+ /**
+ * Extracts the inference endpoint ID from an inference function.
+ *
+ * @param f the inference function containing the inference ID
+ * @return the inference endpoint ID as a string
+ */
+ private String inferenceId(InferenceFunction<?> f, FoldContext foldContext) {
+ return BytesRefs.toString(f.inferenceId().fold(foldContext));
+ }
+
+ /**
+ * Creates an expression evaluator factory for a foldable expression.
+ *
+ * This method converts a foldable expression into an evaluator factory that can be used by inference
+ * operators. The expression is first folded to its constant value and then wrapped in a literal.
+ *
+ * @param e the foldable expression to create an evaluator factory for
+ * @return an expression evaluator factory for the given expression
+ */
+ private EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e, FoldContext foldContext) {
+ assert e.foldable() : "Input expression must be foldable";
+ return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null);
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java
index 93085969415a6..f93c024eeab73 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java
@@ -144,9 +144,9 @@ public interface OutputBuilder extends Releasable {
void addInferenceResponse(InferenceAction.Response inferenceResponse);
/**
- * Builds the final output page from accumulated inference responses.
+ * Builds the final output from accumulated inference responses.
*
- * @return The constructed output page.
+ * @return The constructed output block.
*/
Page buildOutput();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java
index abb4eef251374..bc492d64e5c5a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java
@@ -16,6 +16,11 @@
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
+import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition;
+import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
+import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
@@ -30,7 +35,7 @@
public class InferenceResolver {
private final Client client;
-
+ private final EsqlFunctionRegistry functionRegistry;
private final ThreadPool threadPool;
/**
@@ -38,8 +43,9 @@ public class InferenceResolver {
*
* @param client The Elasticsearch client for executing inference deployment lookups
*/
- public InferenceResolver(Client client, ThreadPool threadPool) {
+ public InferenceResolver(Client client, EsqlFunctionRegistry functionRegistry, ThreadPool threadPool) {
this.client = client;
+ this.functionRegistry = functionRegistry;
this.threadPool = threadPool;
}
@@ -75,6 +81,7 @@ public void resolveInferenceIds(LogicalPlan plan, ActionListener c) {
collectInferenceIdsFromInferencePlans(plan, c);
+ collectInferenceIdsFromInferenceFunctions(plan, c);
}
/**
@@ -134,6 +141,28 @@ private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer c.accept(inferenceId(inferencePlan)));
}
+    /**
+     * Collects inference IDs from function expressions within the logical plan.
+     *
+     * Scans unresolved function calls, resolves them against a snapshot of the function registry and, for
+     * functions implementing {@link InferenceFunction}, extracts the constant inference ID argument when present.
+     *
+     * @param plan The logical plan to scan for function expressions
+     * @param c Consumer function to receive each discovered inference ID
+     */
+    private void collectInferenceIdsFromInferenceFunctions(LogicalPlan plan, Consumer<String> c) {
+        // Snapshot the registry so snapshot-only functions are resolvable too.
+        EsqlFunctionRegistry snapshotRegistry = functionRegistry.snapshotRegistry();
+        plan.forEachExpressionUp(UnresolvedFunction.class, f -> {
+            String functionName = snapshotRegistry.resolveAlias(f.name());
+            if (snapshotRegistry.functionExists(functionName)) {
+                FunctionDefinition def = snapshotRegistry.resolveFunction(functionName);
+                if (InferenceFunction.class.isAssignableFrom(def.clazz())) {
+                    String inferenceId = inferenceId(f, def);
+                    // inferenceId is null when the argument is absent or not a foldable string constant.
+                    if (inferenceId != null) {
+                        c.accept(inferenceId);
+                    }
+                }
+            }
+        });
+    }
+
/**
* Extracts the inference ID from an InferencePlan object.
*
@@ -148,6 +177,23 @@ private static String inferenceId(Expression e) {
return BytesRefs.toString(e.fold(FoldContext.small()));
}
+    /**
+     * Extracts the inference endpoint ID argument from an unresolved inference function call.
+     *
+     * @param f the unresolved function call to inspect
+     * @param def the resolved function definition, used to locate the inference ID parameter by name
+     * @return the inference endpoint ID when it is a foldable string constant, or null otherwise
+     */
+    public String inferenceId(UnresolvedFunction f, FunctionDefinition def) {
+        EsqlFunctionRegistry.FunctionDescription functionDescription = EsqlFunctionRegistry.description(def);
+
+        for (int i = 0; i < functionDescription.args().size(); i++) {
+            EsqlFunctionRegistry.ArgSignature arg = functionDescription.args().get(i);
+
+            // Match the argument whose declared name is the inference ID parameter.
+            if (arg.name().equals(InferenceFunction.INFERENCE_ID_PARAMETER_NAME)) {
+                Expression argValue = f.arguments().get(i);
+                // Only constant (foldable) string values can be resolved ahead of execution.
+                if (argValue != null && argValue.foldable() && DataType.isString(argValue.dataType())) {
+                    return inferenceId(argValue);
+                }
+            }
+        }
+
+        return null;
+    }
+
public static Factory factory(Client client) {
return new Factory(client, client.threadPool());
}
@@ -161,8 +207,8 @@ private Factory(Client client, ThreadPool threadPool) {
this.threadPool = threadPool;
}
- public InferenceResolver create() {
- return new InferenceResolver(client, threadPool);
+ public InferenceResolver create(EsqlFunctionRegistry functionRegistry) {
+ return new InferenceResolver(client, functionRegistry, threadPool);
}
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java
index 37c163beaecda..630477a20f447 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.inference;
import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
@@ -33,10 +34,12 @@ private InferenceService(InferenceResolver.Factory inferenceResolverFactory, Bul
/**
* Creates an inference resolver for resolving inference IDs in logical plans.
*
+ * @param functionRegistry the function registry to resolve functions
+ *
* @return a new inference resolver instance
*/
- public InferenceResolver inferenceResolver() {
- return inferenceResolverFactory.create();
+ public InferenceResolver inferenceResolver(EsqlFunctionRegistry functionRegistry) {
+ return inferenceResolverFactory.create(functionRegistry);
}
public BulkInferenceRunner bulkInferenceRunner() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InputTextReader.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InputTextReader.java
new file mode 100644
index 0000000000000..d7bf43cdfbdd5
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InputTextReader.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+
+/**
+ * Helper class that reads text strings from a {@link BytesRefBlock}.
+ * This class is used by inference operators to extract text content from block data.
+ */
+public class InputTextReader implements Releasable {
+    // Block of UTF-8 text values; released (and marked as passable to another driver) on close.
+    private final BytesRefBlock textBlock;
+    // Reused across readText calls to avoid re-allocating a builder for every position.
+    private final StringBuilder strBuilder = new StringBuilder();
+    // Scratch buffer reused by BytesRefBlock#getBytesRef to avoid per-value allocations.
+    private BytesRef readBuffer = new BytesRef();
+
+    public InputTextReader(BytesRefBlock textBlock) {
+        this.textBlock = textBlock;
+    }
+
+    /**
+     * Reads the text string at the given position.
+     * Multiple values at the position are concatenated with newlines.
+     *
+     * @param pos the position index in the block
+     * @return the text string at the position, or null if the position contains a null value
+     */
+    public String readText(int pos) {
+        return readText(pos, Integer.MAX_VALUE);
+    }
+
+    /**
+     * Reads the text string at the given position, concatenating at most {@code limit} values with newlines.
+     *
+     * @param pos the position index in the block
+     * @param limit the maximum number of values to read from the position
+     * @return the text string at the position, or null if the position contains a null value
+     */
+    public String readText(int pos, int limit) {
+        if (textBlock.isNull(pos)) {
+            return null;
+        }
+
+        strBuilder.setLength(0);
+        int maxPos = Math.min(limit, textBlock.getValueCount(pos));
+        for (int valueIndex = 0; valueIndex < maxPos; valueIndex++) {
+            readBuffer = textBlock.getBytesRef(textBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
+            strBuilder.append(readBuffer.utf8ToString());
+            // Separate multi-values with newlines, but do not append a trailing one.
+            if (valueIndex != maxPos - 1) {
+                strBuilder.append("\n");
+            }
+        }
+
+        return strBuilder.toString();
+    }
+
+    /**
+     * Returns the total number of positions (text entries) in the block.
+     */
+    public int estimatedSize() {
+        return textBlock.getPositionCount();
+    }
+
+    @Override
+    public void close() {
+        // The block may be released on a different driver's thread than the one that created it.
+        textBlock.allowPassingToDifferentDriver();
+        Releasables.close(textBlock);
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java
index 203a3031bcad4..f220f1842953c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.inference.bulk;
import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.threadpool.ThreadPool;
@@ -25,7 +26,6 @@
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
-import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME;
/**
* Implementation of bulk inference execution with throttling and concurrency control.
@@ -88,7 +88,7 @@ public BulkInferenceRequest poll() {
public BulkInferenceRunner(Client client, int maxRunningTasks) {
this.permits = new Semaphore(maxRunningTasks);
this.client = client;
- this.executor = client.threadPool().executor(ESQL_WORKER_THREAD_POOL_NAME);
+ this.executor = client.threadPool().executor(ThreadPool.Names.SEARCH);
}
/**
@@ -253,48 +253,51 @@ private void executePendingRequests(int recursionDepth) {
executionState.finish();
}
- final ActionListener inferenceResponseListener = ActionListener.runAfter(
- ActionListener.wrap(
- r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r),
- e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e)
- ),
- () -> {
- // Release the permit we used
- permits.release();
-
- try {
- synchronized (executionState) {
- persistPendingResponses();
- }
+ final ActionListener inferenceResponseListener = new ThreadedActionListener<>(
+ executor,
+ ActionListener.runAfter(
+ ActionListener.wrap(
+ r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r),
+ e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e)
+ ),
+ () -> {
+ // Release the permit we used
+ permits.release();
+
+ try {
+ synchronized (executionState) {
+ persistPendingResponses();
+ }
- if (executionState.finished() && responseSent.compareAndSet(false, true)) {
- onBulkCompletion();
- }
+ if (executionState.finished() && responseSent.compareAndSet(false, true)) {
+ onBulkCompletion();
+ }
- if (responseSent.get()) {
- // Response has already been sent
- // No need to continue processing this bulk.
- // Check if another bulk request is pending for execution.
- BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
- if (nexBulkRequest != null) {
- executor.execute(nexBulkRequest::executePendingRequests);
+ if (responseSent.get()) {
+ // Response has already been sent
+ // No need to continue processing this bulk.
+ // Check if another bulk request is pending for execution.
+ BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll();
+ if (nexBulkRequest != null) {
+ executor.execute(nexBulkRequest::executePendingRequests);
+ }
+ return;
}
- return;
- }
- if (executionState.finished() == false) {
- // Execute any pending requests if any
- if (recursionDepth > 100) {
- executor.execute(this::executePendingRequests);
- } else {
- this.executePendingRequests(recursionDepth + 1);
+ if (executionState.finished() == false) {
+ // Execute any pending requests if any
+ if (recursionDepth > 100) {
+ executor.execute(this::executePendingRequests);
+ } else {
+ this.executePendingRequests(recursionDepth + 1);
+ }
+ }
+ } catch (Exception e) {
+ if (responseSent.compareAndSet(false, true)) {
+ completionListener.onFailure(e);
}
- }
- } catch (Exception e) {
- if (responseSent.compareAndSet(false, true)) {
- completionListener.onFailure(e);
}
}
- }
+ )
);
// Handle null requests (edge case in some iterators)
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java
index f526cd9edb077..509110e0ffe2b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java
@@ -7,12 +7,11 @@
package org.elasticsearch.xpack.esql.inference.completion;
-import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.BytesRefBlock;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.esql.inference.InputTextReader;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
import java.util.List;
@@ -24,7 +23,7 @@
*/
public class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator {
- private final PromptReader promptReader;
+ private final InputTextReader textReader;
private final String inferenceId;
private final int size;
private int currentPos = 0;
@@ -36,7 +35,7 @@ public class CompletionOperatorRequestIterator implements BulkInferenceRequestIt
* @param inferenceId The ID of the inference model to invoke.
*/
public CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) {
- this.promptReader = new PromptReader(promptBlock);
+ this.textReader = new InputTextReader(promptBlock);
this.size = promptBlock.getPositionCount();
this.inferenceId = inferenceId;
}
@@ -52,7 +51,7 @@ public InferenceAction.Request next() {
throw new NoSuchElementException();
}
- return inferenceRequest(promptReader.readPrompt(currentPos++));
+ return inferenceRequest(textReader.readText(currentPos++));
}
/**
@@ -68,60 +67,11 @@ private InferenceAction.Request inferenceRequest(String prompt) {
@Override
public int estimatedSize() {
- return promptReader.estimatedSize();
+ return textReader.estimatedSize();
}
@Override
public void close() {
- Releasables.close(promptReader);
- }
-
- /**
- * Helper class that reads prompts from a {@link BytesRefBlock}.
- */
- private static class PromptReader implements Releasable {
- private final BytesRefBlock promptBlock;
- private final StringBuilder strBuilder = new StringBuilder();
- private BytesRef readBuffer = new BytesRef();
-
- private PromptReader(BytesRefBlock promptBlock) {
- this.promptBlock = promptBlock;
- }
-
- /**
- * Reads the prompt string at the given position..
- *
- * @param pos the position index in the block
- */
- public String readPrompt(int pos) {
- if (promptBlock.isNull(pos)) {
- return null;
- }
-
- strBuilder.setLength(0);
-
- for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
- readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
- strBuilder.append(readBuffer.utf8ToString());
- if (valueIndex != promptBlock.getValueCount(pos) - 1) {
- strBuilder.append("\n");
- }
- }
-
- return strBuilder.toString();
- }
-
- /**
- * Returns the total number of positions (prompts) in the block.
- */
- public int estimatedSize() {
- return promptBlock.getPositionCount();
- }
-
- @Override
- public void close() {
- promptBlock.allowPassingToDifferentDriver();
- Releasables.close(promptBlock);
- }
+ Releasables.close(textReader);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperator.java
new file mode 100644
index 0000000000000..7817612007614
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperator.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.textembedding;
+
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
+import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.xpack.esql.inference.InferenceOperator;
+import org.elasticsearch.xpack.esql.inference.InferenceService;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
+
+/**
+ * {@link TextEmbeddingOperator} is an {@link InferenceOperator} that performs text embedding inference.
+ * It evaluates a text expression for each input row, constructs text embedding inference requests,
+ * and emits the dense vector embeddings as output.
+ */
+public class TextEmbeddingOperator extends InferenceOperator {
+
+    // Evaluates the input text expression against each page; closed with the operator.
+    private final ExpressionEvaluator textEvaluator;
+
+    /**
+     * @param driverContext the driver context this operator runs in
+     * @param bulkInferenceRunner the runner used to execute the bulk inference requests
+     * @param inferenceId the ID of the inference endpoint to invoke
+     * @param textEvaluator evaluator producing the text to embed for each row
+     * @param maxOutstandingPages the maximum number of pages being processed concurrently
+     */
+    public TextEmbeddingOperator(
+        DriverContext driverContext,
+        BulkInferenceRunner bulkInferenceRunner,
+        String inferenceId,
+        ExpressionEvaluator textEvaluator,
+        int maxOutstandingPages
+    ) {
+        super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages);
+        this.textEvaluator = textEvaluator;
+    }
+
+    @Override
+    protected void doClose() {
+        Releasables.close(textEvaluator);
+    }
+
+    @Override
+    public String toString() {
+        return "TextEmbeddingOperator[inference_id=[" + inferenceId() + "]]";
+    }
+
+    /**
+     * Constructs the text embedding inference requests iterator for the given input page by evaluating the text expression.
+     *
+     * @param inputPage The input data page.
+     */
+    @Override
+    protected BulkInferenceRequestIterator requests(Page inputPage) {
+        // The evaluator is expected to produce a BytesRefBlock; the iterator takes ownership of the block.
+        return new TextEmbeddingOperatorRequestIterator((BytesRefBlock) textEvaluator.eval(inputPage), inferenceId());
+    }
+
+    /**
+     * Creates a new {@link TextEmbeddingOperatorOutputBuilder} to collect and emit the text embedding results.
+     *
+     * @param input The input page for which results will be constructed.
+     */
+    @Override
+    protected TextEmbeddingOperatorOutputBuilder outputBuilder(Page input) {
+        FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(input.getPositionCount());
+        return new TextEmbeddingOperatorOutputBuilder(outputBlockBuilder, input);
+    }
+
+    /**
+     * Factory for creating {@link TextEmbeddingOperator} instances.
+     *
+     * @param inferenceService service providing the bulk inference runner
+     * @param inferenceId the ID of the inference endpoint to invoke
+     * @param textEvaluatorFactory factory for the per-driver text expression evaluator
+     */
+    public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory textEvaluatorFactory)
+        implements
+            OperatorFactory {
+        @Override
+        public String describe() {
+            return "TextEmbeddingOperator[inference_id=[" + inferenceId + "]]";
+        }
+
+        @Override
+        public Operator get(DriverContext driverContext) {
+            return new TextEmbeddingOperator(
+                driverContext,
+                inferenceService.bulkInferenceRunner(),
+                inferenceId,
+                textEvaluatorFactory.get(driverContext),
+                BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests()
+            );
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java
new file mode 100644
index 0000000000000..521bb508c30af
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.textembedding;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.esql.inference.InferenceOperator;
+
+/**
+ * {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting
+ * {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings.
+ */
+public class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
+ private final Page inputPage;
+ private final FloatBlock.Builder outputBlockBuilder;
+
+ public TextEmbeddingOperatorOutputBuilder(FloatBlock.Builder outputBlockBuilder, Page inputPage) {
+ this.inputPage = inputPage;
+ this.outputBlockBuilder = outputBlockBuilder;
+ }
+
+ @Override
+ public void close() {
+ Releasables.close(outputBlockBuilder);
+ }
+
+ /**
+ * Adds an inference response to the output builder.
+ *
+ *
+ * If the response is null or not of type {@link TextEmbeddingResults} an {@link IllegalStateException} is thrown.
+ * Else, the embedding vector is added to the output block as a multi-value position.
+ *
+ *
+ *
+ * The responses must be added in the same order as the corresponding inference requests were generated.
+ * Failing to preserve order may lead to incorrect or misaligned output rows.
+ *
+ */
+ @Override
+ public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
+ if (inferenceResponse == null) {
+ outputBlockBuilder.appendNull();
+ return;
+ }
+
+ TextEmbeddingResults> embeddingResults = inferenceResults(inferenceResponse);
+
+ var embeddings = embeddingResults.embeddings();
+ if (embeddings.isEmpty()) {
+ outputBlockBuilder.appendNull();
+ return;
+ }
+
+ float[] embeddingArray = getEmbeddingAsFloatArray(embeddingResults);
+
+ outputBlockBuilder.beginPositionEntry();
+ for (float component : embeddingArray) {
+ outputBlockBuilder.appendFloat(component);
+ }
+ outputBlockBuilder.endPositionEntry();
+ }
+
+ /**
+ * Builds the final output page by appending the embedding output block to the input page.
+ */
+ @Override
+ public Page buildOutput() {
+ Block outputBlock = outputBlockBuilder.build();
+ assert outputBlock.getPositionCount() == inputPage.getPositionCount();
+ return inputPage.appendBlock(outputBlock);
+ }
+
+ private TextEmbeddingResults> inferenceResults(InferenceAction.Response inferenceResponse) {
+ return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class);
+ }
+
+ /**
+ * Extracts the embedding as a float array from the embedding result.
+ */
+ private float[] getEmbeddingAsFloatArray(TextEmbeddingResults> embedding) {
+ return switch (embedding.embeddings().get(0)) {
+ case TextEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values();
+ case TextEmbeddingByteResults.Embedding byteEmbedding -> byteEmbedding.toFloatArray();
+ default -> throw new IllegalArgumentException(
+ "Unsupported embedding type: "
+ + embedding.embeddings().get(0).getClass().getName()
+ + ". Expected TextEmbeddingFloatResults.Embedding or TextEmbeddingByteResults.Embedding."
+ );
+ };
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIterator.java
new file mode 100644
index 0000000000000..4860f7fdcb6ec
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIterator.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.textembedding;
+
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.esql.inference.InputTextReader;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
+
+import java.util.List;
+import java.util.NoSuchElementException;
+
+/**
+ * This iterator reads text inputs from a {@link BytesRefBlock} and converts them into individual {@link InferenceAction.Request} instances
+ * of type {@link TaskType#TEXT_EMBEDDING}.
+ */
+public class TextEmbeddingOperatorRequestIterator implements BulkInferenceRequestIterator {
+
+    // Reads the text to embed from the underlying block; owns (and releases) the block.
+    private final InputTextReader textReader;
+    // ID of the inference endpoint every generated request targets.
+    private final String inferenceId;
+    // Total number of positions to iterate over.
+    private final int size;
+    // Next position to read; advanced by next().
+    private int currentPos = 0;
+
+    /**
+     * Constructs a new iterator from the given block of text inputs.
+     *
+     * @param textBlock The input block containing text to embed.
+     * @param inferenceId The ID of the inference model to invoke.
+     */
+    public TextEmbeddingOperatorRequestIterator(BytesRefBlock textBlock, String inferenceId) {
+        this.textReader = new InputTextReader(textBlock);
+        this.size = textBlock.getPositionCount();
+        this.inferenceId = inferenceId;
+    }
+
+    @Override
+    public boolean hasNext() {
+        return currentPos < size;
+    }
+
+    /**
+     * Returns the inference request for the next position, or null when that position's input is null.
+     *
+     * @throws NoSuchElementException if all positions have been consumed
+     */
+    @Override
+    public InferenceAction.Request next() {
+        if (hasNext() == false) {
+            throw new NoSuchElementException();
+        }
+
+        /*
+         * Keep only the first value in case of multi-valued fields.
+         * TODO: check if it is consistent with how the query vector builder is working.
+         */
+        return inferenceRequest(textReader.readText(currentPos++, 1));
+    }
+
+    /**
+     * Wraps a single text string into an {@link InferenceAction.Request} for text embedding.
+     * Returns null for null input text, signalling that no request should be sent for the row.
+     */
+    private InferenceAction.Request inferenceRequest(String text) {
+        if (text == null) {
+            return null;
+        }
+
+        return InferenceAction.Request.builder(inferenceId, TaskType.TEXT_EMBEDDING).setInput(List.of(text)).build();
+    }
+
+    @Override
+    public int estimatedSize() {
+        return textReader.estimatedSize();
+    }
+
+    @Override
+    public void close() {
+        Releasables.close(textReader);
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java
index fdd8e1318f636..11c57b38c2331 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java
@@ -8,8 +8,13 @@
package org.elasticsearch.xpack.esql.optimizer;
import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.SubscribableListener;
+import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.FoldInferenceFunctions;
+import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.LogicalPlanPreOptimizerRule;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+import java.util.List;
+
/**
* The class is responsible for invoking any steps that need to be applied to the logical plan,
* before this is being optimized.
@@ -18,11 +23,14 @@
*
*/
public class LogicalPlanPreOptimizer {
-
- private final LogicalPreOptimizerContext preOptimizerContext;
+ private final List preOptimizerRules;
public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) {
- this.preOptimizerContext = preOptimizerContext;
+ this(List.of(new FoldInferenceFunctions(preOptimizerContext)));
+ }
+
+ LogicalPlanPreOptimizer(List preOptimizerRules) {
+ this.preOptimizerRules = preOptimizerRules;
}
/**
@@ -43,8 +51,17 @@ public void preOptimize(LogicalPlan plan, ActionListener listener)
}));
}
+ /**
+ * Loop over the rules and apply them sequentially to the logical plan.
+ *
+ * @param plan the analyzed logical plan to pre-optimize
+ * @param listener the listener returning the pre-optimized plan when pre-optimization is complete
+ */
private void doPreOptimize(LogicalPlan plan, ActionListener listener) {
- // this is where we will be executing async tasks
- listener.onResponse(plan);
+ SubscribableListener rulesListener = SubscribableListener.newSucceeded(plan);
+ for (LogicalPlanPreOptimizerRule preOptimizerRule : preOptimizerRules) {
+ rulesListener = rulesListener.andThen((l, p) -> preOptimizerRule.apply(p, l));
+ }
+ rulesListener.addListener(listener);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java
index d082bd56fc46d..233bf9c2945be 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java
@@ -8,36 +8,29 @@
package org.elasticsearch.xpack.esql.optimizer;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
-
-import java.util.Objects;
+import org.elasticsearch.xpack.esql.inference.InferenceService;
public class LogicalPreOptimizerContext {
private final FoldContext foldCtx;
- public LogicalPreOptimizerContext(FoldContext foldCtx) {
+ private final InferenceService inferenceService;
+
+ public LogicalPreOptimizerContext(FoldContext foldCtx, InferenceService inferenceService) {
this.foldCtx = foldCtx;
+ this.inferenceService = inferenceService;
}
public FoldContext foldCtx() {
return foldCtx;
}
- @Override
- public boolean equals(Object obj) {
- if (obj == this) return true;
- if (obj == null || obj.getClass() != this.getClass()) return false;
- var that = (LogicalPreOptimizerContext) obj;
- return this.foldCtx.equals(that.foldCtx);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(foldCtx);
- }
-
@Override
public String toString() {
return "LogicalPreOptimizerContext[foldCtx=" + foldCtx + ']';
}
+
+ public InferenceService inferenceService() {
+ return inferenceService;
+ }
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java
new file mode 100644
index 0000000000000..4a3fc87c43795
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java
@@ -0,0 +1,131 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.CountDownActionListener;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
+import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
+import org.elasticsearch.xpack.esql.optimizer.LogicalPreOptimizerContext;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Pre-optimizer rule that performs constant folding for inference functions in logical plans.
+ *
+ * This rule identifies inference functions with constant parameters and evaluates them at optimization time,
+ * replacing them with their computed results.
+ *
+ * The folding process is recursive and handles nested inference functions by processing them in multiple
+ * passes until no more foldable functions remain.
+ *
+ * This method collects all foldable inference functions, evaluates them in parallel,
+ * and then replaces them with their computed results. If new foldable inference functions are remaining
+ * after the first round of folding (due to nested function resolution), it recursively processes
+ * them until no more foldable functions remain.
+ *
+ *
+ * @param plan the logical plan to fold inference functions in
+ * @param listener the listener to notify when the folding is complete
+ */
+ private void foldInferenceFunctions(LogicalPlan plan, ActionListener listener) {
+ // Collect all foldable inference functions from the current plan
+ List<InferenceFunction<?>> inferenceFunctions = collectFoldableInferenceFunctions(plan);
+
+ if (inferenceFunctions.isEmpty()) {
+ // No foldable inference functions were found - return the original plan unchanged
+ listener.onResponse(plan);
+ return;
+ }
+
+ // Map to store the computed results for each inference function
+ Map<InferenceFunction<?>, Expression> inferenceFunctionsToResults = new HashMap<>();
+
+ // Create a countdown listener that will be triggered when all inference functions complete
+ // Once all are done, replace the functions in the plan with their results and recursively
+ // process any remaining foldable inference functions
+ CountDownActionListener completionListener = new CountDownActionListener(
+ inferenceFunctions.size(),
+ listener.delegateFailureIgnoreResponseAndWrap(l -> {
+ // Transform the plan by replacing inference functions with their computed results
+ LogicalPlan transformedPlan = plan.transformExpressionsUp(
+ InferenceFunction.class,
+ f -> inferenceFunctionsToResults.getOrDefault(f, f)
+ );
+
+ // Recursively process the transformed plan to handle any remaining inference functions
+ foldInferenceFunctions(transformedPlan, l);
+ })
+ );
+
+ // Evaluate each inference function asynchronously
+ for (InferenceFunction<?> inferenceFunction : inferenceFunctions) {
+ inferenceFunctionEvaluator.fold(inferenceFunction, completionListener.delegateFailureAndWrap((l, result) -> {
+ inferenceFunctionsToResults.put(inferenceFunction, result);
+ l.onResponse(null);
+ }));
+ }
+ }
+
+ /**
+ * Collects all foldable inference functions from the logical plan.
+ *
+ * A function is considered foldable if it meets all of the following criteria:
+ * <ul>
+ *   <li>It's an instance of {@link InferenceFunction}</li>
+ *   <li>It's marked as foldable (all parameters are constants)</li>
+ *   <li>It doesn't contain nested inference functions (to avoid dependency issues)</li>
+ * </ul>
+ *
+ * Functions with nested inference functions are excluded to ensure proper evaluation order.
+ * They will be considered for folding in subsequent recursive passes after their nested
+ * functions have been resolved.
+ *
+ * @param plan the logical plan to collect inference functions from
+ * @return a list of foldable inference functions, may be empty if none are found
+ */
+ private List<InferenceFunction<?>> collectFoldableInferenceFunctions(LogicalPlan plan) {
+ List<InferenceFunction<?>> inferenceFunctions = new ArrayList<>();
+
+ plan.forEachExpressionUp(InferenceFunction.class, f -> {
+ if (f.foldable() && f.hasNestedInferenceFunction() == false) {
+ inferenceFunctions.add(f);
+ }
+ });
+
+ return inferenceFunctions;
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/LogicalPlanPreOptimizerRule.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/LogicalPlanPreOptimizerRule.java
new file mode 100644
index 0000000000000..29d9be564f1bf
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/LogicalPlanPreOptimizerRule.java
@@ -0,0 +1,25 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+
+/**
+ * A rule that can be applied to an analyzed logical plan before it is optimized.
+ */
+public interface LogicalPlanPreOptimizerRule {
+
+ /**
+ * Apply the rule to the logical plan.
+ *
+ * @param plan the analyzed logical plan to pre-optimize
+ * @param listener the listener returning the pre-optimized plan when pre-optimization rule is applied
+ */
+ void apply(LogicalPlan plan, ActionListener listener);
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
index 1d3b710baac35..4d47df6789e8c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
@@ -754,7 +754,7 @@ private void analyzeWithRetry(
}
private void resolveInferences(LogicalPlan plan, PreAnalysisResult preAnalysisResult, ActionListener l) {
- inferenceService.inferenceResolver().resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution));
+ inferenceService.inferenceResolver(functionRegistry).resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution));
}
private PhysicalPlan logicalPlanToPhysicalPlan(LogicalPlan optimizedPlan, EsqlQueryRequest request) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
index 7a4de299caf79..4a63aea448f9e 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
@@ -303,6 +303,10 @@ public final void test() throws Throwable {
"can't use KQL function in csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KQL_FUNCTION.capabilityName())
);
+ assumeFalse(
+ "can't use TEXT_EMBEDDING function in csv tests",
+ testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.capabilityName())
+ );
assumeFalse(
"can't use KNN function in csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V5.capabilityName())
@@ -584,7 +588,7 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception {
null,
null,
null,
- new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldCtx)),
+ new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldCtx, null)),
functionRegistry,
new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldCtx)),
mapper,
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
index fbfa18dccc477..78a8ca5483246 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
@@ -9,6 +9,7 @@
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.core.type.EsField;
@@ -26,6 +27,7 @@
import org.elasticsearch.xpack.esql.session.Configuration;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -195,14 +197,39 @@ public static EnrichResolution defaultEnrichResolution() {
return enrichResolution;
}
+ public static final String RERANKING_INFERENCE_ID = "reranking-inference-id";
+ public static final String COMPLETION_INFERENCE_ID = "completion-inference-id";
+ public static final String TEXT_EMBEDDING_INFERENCE_ID = "text-embedding-inference-id";
+ public static final String CHAT_COMPLETION_INFERENCE_ID = "chat-completion-inference-id";
+ public static final String SPARSE_EMBEDDING_INFERENCE_ID = "sparse-embedding-inference-id";
+ public static final List<String> VALID_INFERENCE_IDS = List.of(
+ RERANKING_INFERENCE_ID,
+ COMPLETION_INFERENCE_ID,
+ TEXT_EMBEDDING_INFERENCE_ID,
+ CHAT_COMPLETION_INFERENCE_ID,
+ SPARSE_EMBEDDING_INFERENCE_ID
+ );
+ public static final String ERROR_INFERENCE_ID = "error-inference-id";
+
public static InferenceResolution defaultInferenceResolution() {
return InferenceResolution.builder()
- .withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK))
- .withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION))
- .withError("error-inference-id", "error with inference resolution")
+ .withResolvedInference(new ResolvedInference(RERANKING_INFERENCE_ID, TaskType.RERANK))
+ .withResolvedInference(new ResolvedInference(COMPLETION_INFERENCE_ID, TaskType.COMPLETION))
+ .withResolvedInference(new ResolvedInference(TEXT_EMBEDDING_INFERENCE_ID, TaskType.TEXT_EMBEDDING))
+ .withResolvedInference(new ResolvedInference(CHAT_COMPLETION_INFERENCE_ID, TaskType.CHAT_COMPLETION))
+ .withResolvedInference(new ResolvedInference(SPARSE_EMBEDDING_INFERENCE_ID, TaskType.SPARSE_EMBEDDING))
+ .withError(ERROR_INFERENCE_ID, "error with inference resolution")
.build();
}
+ public static String randomInferenceId() {
+ return ESTestCase.randomFrom(VALID_INFERENCE_IDS);
+ }
+
+ public static String randomInferenceId(String... excludes) {
+ return ESTestCase.randomValueOtherThanMany(Arrays.asList(excludes)::contains, AnalyzerTestUtils::randomInferenceId);
+ }
+
public static void loadEnrichPolicyResolution(
EnrichResolution enrich,
String policyType,
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index 662e1bdabcab2..79aa821ce4040 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -56,6 +56,7 @@
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket;
+import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDenseVector;
@@ -123,6 +124,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS;
+import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.TEXT_EMBEDDING_INFERENCE_ID;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzer;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzerDefaultMapping;
@@ -130,6 +132,7 @@
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultInferenceResolution;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.indexWithDateDateNanosUnionType;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping;
+import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.randomInferenceId;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.tsdbIndexResolution;
import static org.elasticsearch.xpack.esql.core.plugin.EsqlCorePlugin.DENSE_VECTOR_FEATURE_FLAG;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
@@ -3765,6 +3768,112 @@ private void assertEmptyEsRelation(LogicalPlan plan) {
assertThat(esRelation.output(), equalTo(NO_FIELDS));
}
+ public void testTextEmbeddingResolveInferenceId() {
+ assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+
+ LogicalPlan plan = analyze(
+ String.format(Locale.ROOT, """
+ FROM books METADATA _score | EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", TEXT_EMBEDDING_INFERENCE_ID),
+ "mapping-books.json"
+ );
+
+ Eval eval = as(as(plan, Limit.class).child(), Eval.class);
+ assertThat(eval.fields(), hasSize(1));
+ Alias alias = as(eval.fields().get(0), Alias.class);
+ assertThat(alias.name(), equalTo("embedding"));
+ TextEmbedding function = as(alias.child(), TextEmbedding.class);
+
+ assertThat(function.inputText(), equalTo(string("italian food recipe")));
+ assertThat(function.inferenceId(), equalTo(string(TEXT_EMBEDDING_INFERENCE_ID)));
+ }
+
+ public void testTextEmbeddingFunctionResolveType() {
+ assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+
+ LogicalPlan plan = analyze(
+ String.format(Locale.ROOT, """
+ FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", TEXT_EMBEDDING_INFERENCE_ID),
+ "mapping-books.json"
+ );
+
+ Eval eval = as(as(plan, Limit.class).child(), Eval.class);
+ assertThat(eval.fields(), hasSize(1));
+ Alias alias = as(eval.fields().get(0), Alias.class);
+ assertThat(alias.name(), equalTo("embedding"));
+
+ TextEmbedding function = as(alias.child(), TextEmbedding.class);
+
+ assertThat(function.foldable(), equalTo(true));
+ assertThat(function.dataType(), equalTo(DENSE_VECTOR));
+ }
+
+ public void testTextEmbeddingFunctionMissingInferenceIdError() {
+ assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+
+ VerificationException ve = expectThrows(
+ VerificationException.class,
+ () -> analyze(
+ String.format(Locale.ROOT, """
+ FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", "unknown-inference-id"),
+ "mapping-books.json"
+ )
+ );
+
+ assertThat(ve.getMessage(), containsString("unresolved inference [unknown-inference-id]"));
+ }
+
+ public void testTextEmbeddingFunctionInvalidInferenceIdError() {
+ assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+
+ String inferenceId = randomInferenceId(TEXT_EMBEDDING_INFERENCE_ID);
+ VerificationException ve = expectThrows(
+ VerificationException.class,
+ () -> analyze(
+ String.format(Locale.ROOT, """
+ FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", inferenceId),
+ "mapping-books.json"
+ )
+ );
+
+ assertThat(
+ ve.getMessage(),
+ containsString(String.format(Locale.ROOT, "cannot use inference endpoint [%s] with task type", inferenceId))
+ );
+ }
+
+ public void testTextEmbeddingFunctionWithoutModel() {
+ assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+
+ ParsingException ve = expectThrows(ParsingException.class, () -> analyze("""
+ FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe")""", "mapping-books.json"));
+
+ assertThat(
+ ve.getMessage(),
+ containsString(" error building [text_embedding]: function [text_embedding] expects exactly two arguments")
+ );
+ }
+
+ public void testKnnFunctionWithTextEmbedding() {
+ assumeTrue("KNN function capability required", EsqlCapabilities.Cap.KNN_FUNCTION_V5.isEnabled());
+ assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+
+ LogicalPlan plan = analyze(
+ String.format(Locale.ROOT, """
+ from test | where KNN(float_vector, TEXT_EMBEDDING("italian food recipe", "%s"))""", TEXT_EMBEDDING_INFERENCE_ID),
+ "mapping-dense_vector.json"
+ );
+
+ Limit limit = as(plan, Limit.class);
+ Filter filter = as(limit.child(), Filter.class);
+ Knn knn = as(filter.condition(), Knn.class);
+ assertThat(knn.field(), instanceOf(FieldAttribute.class));
+ assertThat(((FieldAttribute) knn.field()).name(), equalTo("float_vector"));
+
+ TextEmbedding textEmbedding = as(knn.query(), TextEmbedding.class);
+ assertThat(textEmbedding.inputText(), equalTo(string("italian food recipe")));
+ assertThat(textEmbedding.inferenceId(), equalTo(string(TEXT_EMBEDDING_INFERENCE_ID)));
+ }
+
public void testResolveRerankInferenceId() {
{
LogicalPlan plan = analyze("""
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
index ba2ef6d152e45..66c8b2d6e463c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
@@ -41,6 +41,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsConstant;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
+import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.TEXT_EMBEDDING_INFERENCE_ID;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT;
@@ -2675,6 +2676,53 @@ public void testSortInTimeSeries() {
and the first aggregation [STATS avg(network.connections)] is not allowed"""));
}
+ public void testTextEmbeddingFunctionInvalidQuery() {
+ assertThat(
+ error("from test | EVAL embedding = TEXT_EMBEDDING(null, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID),
+ equalTo("1:30: first argument of [TEXT_EMBEDDING(null, ?)] cannot be null, received [null]")
+ );
+
+ assertThat(
+ error("from test | EVAL embedding = TEXT_EMBEDDING(42, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID),
+ equalTo("1:30: first argument of [TEXT_EMBEDDING(42, ?)] must be [string], found value [42] type [integer]")
+ );
+
+ assertThat(
+ error("from test | EVAL embedding = TEXT_EMBEDDING(last_name, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID),
+ equalTo("1:30: first argument of [TEXT_EMBEDDING(last_name, ?)] must be a constant, received [last_name]")
+ );
+ }
+
+ public void testTextEmbeddingFunctionInvalidInferenceId() {
+ assertThat(
+ error("from test | EVAL embedding = TEXT_EMBEDDING(?, null)", defaultAnalyzer, "query text"),
+ equalTo("1:30: second argument of [TEXT_EMBEDDING(?, null)] cannot be null, received [null]")
+ );
+
+ assertThat(
+ error("from test | EVAL embedding = TEXT_EMBEDDING(?, 42)", defaultAnalyzer, "query text"),
+ equalTo("1:30: second argument of [TEXT_EMBEDDING(?, 42)] must be [string], found value [42] type [integer]")
+ );
+
+ assertThat(
+ error("from test | EVAL embedding = TEXT_EMBEDDING(?, last_name)", defaultAnalyzer, "query text"),
+ equalTo("1:30: second argument of [TEXT_EMBEDDING(?, last_name)] must be a constant, received [last_name]")
+ );
+ }
+
+
private void checkVectorFunctionsNullArgs(String functionInvocation) throws Exception {
query("from test | eval similarity = " + functionInvocation, fullTextAnalyzer);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java
new file mode 100644
index 0000000000000..9af017bd5207f
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matcher;
+import org.junit.Before;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Set;
+
+import static org.hamcrest.Matchers.equalTo;
+
+/** Tests error conditions and type validation for TEXT_EMBEDDING function. */
+public class TextEmbeddingErrorTests extends ErrorsForCasesWithoutExamplesTestCase {
+
+ @Before
+ public void checkCapability() {
+ assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+ }
+
+ @Override
+ protected List<TestCaseSupplier> cases() {
+ return paramsToSuppliers(TextEmbeddingTests.parameters());
+ }
+
+ @Override
+ protected Expression build(Source source, List<Expression> args) {
+ return new TextEmbedding(source, args.get(0), args.get(1));
+ }
+
+ @Override
+ protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
+ return equalTo(typeErrorMessage(true, validPerPosition, signature, (v, p) -> "string"));
+ }
+
+ protected static String typeErrorMessage(
+ boolean includeOrdinal,
+ List<Set<DataType>> validPerPosition,
+ List<DataType> signature,
+ AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier
+ ) {
+ for (int i = 0; i < signature.size(); i++) {
+ if (signature.get(i) == DataType.NULL) {
+ String ordinal = includeOrdinal ? TypeResolutions.ParamOrdinal.fromIndex(i).name().toLowerCase(Locale.ROOT) + " " : "";
+ return ordinal + "argument of [" + sourceForSignature(signature) + "] cannot be null, received []";
+ }
+
+ if (validPerPosition.get(i).contains(signature.get(i)) == false) {
+ break;
+ }
+ }
+
+ return ErrorsForCasesWithoutExamplesTestCase.typeErrorMessage(
+ includeOrdinal,
+ validPerPosition,
+ signature,
+ positionalErrorMessageSupplier
+ );
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingSerializationTests.java
new file mode 100644
index 0000000000000..5d7e1dfa4301a
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingSerializationTests.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests;
+import org.junit.Before;
+
+import java.io.IOException;
+
+/** Tests serialization/deserialization of TEXT_EMBEDDING function instances. */
+ public class TextEmbeddingSerializationTests extends AbstractExpressionSerializationTests<TextEmbedding> {
+
+ @Before
+ public void checkCapability() {
+ assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+ }
+
+ @Override
+ protected TextEmbedding createTestInstance() {
+ Source source = randomSource();
+ Expression inputText = randomChild();
+ Expression inferenceId = randomChild();
+ return new TextEmbedding(source, inputText, inferenceId);
+ }
+
+ @Override
+ protected TextEmbedding mutateInstance(TextEmbedding instance) throws IOException {
+ Source source = instance.source();
+ Expression inputText = instance.inputText();
+ Expression inferenceId = instance.inferenceId();
+ if (randomBoolean()) {
+ inputText = randomValueOtherThan(inputText, AbstractExpressionSerializationTests::randomChild);
+ } else {
+ inferenceId = randomValueOtherThan(inferenceId, AbstractExpressionSerializationTests::randomChild);
+ }
+ return new TextEmbedding(source, inputText, inferenceId);
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java
new file mode 100644
index 0000000000000..b6fdc7addf984
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.FunctionName;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Supplier;
+
+import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
+import static org.hamcrest.Matchers.equalTo;
+
+@FunctionName("text_embedding")
+public class TextEmbeddingTests extends AbstractFunctionTestCase {
+ @Before
+ public void checkCapability() {
+ assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+ }
+
+ public TextEmbeddingTests(@Name("TestCase") Supplier testCaseSupplier) {
+ this.testCase = testCaseSupplier.get();
+ }
+
+ @ParametersFactory
+ public static Iterable