diff --git a/docs/reference/query-languages/esql/images/functions/text_dense_vector_embedding.svg b/docs/reference/query-languages/esql/images/functions/text_dense_vector_embedding.svg
new file mode 100644
index 0000000000000..c628f0137ff91
--- /dev/null
+++ b/docs/reference/query-languages/esql/images/functions/text_dense_vector_embedding.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/text_dense_vector_embedding.json b/docs/reference/query-languages/esql/kibana/definition/functions/text_dense_vector_embedding.json
new file mode 100644
index 0000000000000..c64df7894f836
--- /dev/null
+++ b/docs/reference/query-languages/esql/kibana/definition/functions/text_dense_vector_embedding.json
@@ -0,0 +1,9 @@
+{
+ "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.",
+ "type" : "scalar",
+ "name" : "text_dense_vector_embedding",
+ "description" : "Embed input text into a dense vector representation using an inference model.",
+ "signatures" : [ ],
+ "preview" : true,
+ "snapshot_only" : true
+}
diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/text_dense_vector_embedding.md b/docs/reference/query-languages/esql/kibana/docs/functions/text_dense_vector_embedding.md
new file mode 100644
index 0000000000000..7ee06e487fb0f
--- /dev/null
+++ b/docs/reference/query-languages/esql/kibana/docs/functions/text_dense_vector_embedding.md
@@ -0,0 +1,5 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+### TEXT DENSE VECTOR EMBEDDING
+Embed input text into a dense vector representation using an inference model.
+
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java
index 24736ac3a2514..6682515e062c4 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MapExpression.java
@@ -120,10 +120,18 @@ public Expression get(Object key) {
return map.get(key);
} else {
// the key(literal) could be converted to BytesRef by ConvertStringToByteRef
- return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(new BytesRef(key.toString()));
+ return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(getKeyAsBytesRef(key));
}
}
+ public Expression getOrDefault(Object key, Expression defaultValue) {
+ return containsKey(key) ? get(key) : defaultValue;
+ }
+
+ public boolean containsKey(Object key) {
+ return keyFoldedMap.containsKey(key) || keyFoldedMap.containsKey(getKeyAsBytesRef(key));
+ }
+
@Override
public boolean equals(Object obj) {
if (this == obj) {
@@ -142,4 +150,8 @@ public String toString() {
String str = entryExpressions.stream().map(String::valueOf).collect(Collectors.joining(", "));
return "{ " + str + " }";
}
+
+ private BytesRef getKeyAsBytesRef(Object key) {
+ return new BytesRef(key.toString());
+ }
}
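
The new containsKey and getOrDefault helpers account for string keys that were folded to BytesRef during planning, so a lookup on the raw key alone can miss an entry. Below is a minimal standalone sketch of that lookup pattern, using a plain HashMap and Lucene's BytesRef in place of the real MapExpression internals (the "my-endpoint" value is purely illustrative):

import org.apache.lucene.util.BytesRef;

import java.util.HashMap;
import java.util.Map;

// Sketch only: a map whose string keys may have been stored as BytesRef.
// Lookups try the key as given, then fall back to its BytesRef form,
// mirroring MapExpression#containsKey and MapExpression#getOrDefault above.
public class FoldedKeyLookupSketch {
    private final Map<Object, String> keyFoldedMap = new HashMap<>();

    public void put(Object key, String value) {
        keyFoldedMap.put(key, value);
    }

    public boolean containsKey(Object key) {
        return keyFoldedMap.containsKey(key) || keyFoldedMap.containsKey(new BytesRef(key.toString()));
    }

    public String getOrDefault(Object key, String defaultValue) {
        if (keyFoldedMap.containsKey(key)) {
            return keyFoldedMap.get(key);
        }
        return keyFoldedMap.getOrDefault(new BytesRef(key.toString()), defaultValue);
    }

    public static void main(String[] args) {
        FoldedKeyLookupSketch sketch = new FoldedKeyLookupSketch();
        sketch.put(new BytesRef("inference_id"), "my-endpoint");
        System.out.println(sketch.containsKey("inference_id"));          // true, via the BytesRef fallback
        System.out.println(sketch.getOrDefault("inference_id", "none")); // my-endpoint
    }
}
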
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index 4dcab5f0c9270..2855c2107b5de 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -1197,6 +1197,11 @@ public enum Cap {
*/
KNN_FUNCTION_V2(Build.current().isSnapshot()),
+ /**
+ * Support for dense vector embedding function
+ */
+ DENSE_VECTOR_EMBEDDING_FUNCTION(Build.current().isSnapshot()),
+
LIKE_WITH_LIST_OF_PATTERNS,
/**
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index f48c95397dcab..31c7efe719c2a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
@@ -97,6 +97,7 @@
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
+import org.elasticsearch.xpack.esql.plan.logical.inference.embedding.DenseVectorEmbedding;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinType;
@@ -138,6 +139,7 @@
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD;
+import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT;
@@ -406,7 +408,7 @@ protected LogicalPlan rule(InferencePlan<?> plan, AnalyzerContext context) {
ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId);
if (resolvedInference != null && resolvedInference.taskType() == plan.taskType()) {
- return plan;
+ return plan.withModelConfigurations(resolvedInference.modelConfigurations());
} else if (resolvedInference != null) {
String error = "cannot use inference endpoint ["
+ inferenceId
@@ -516,6 +518,10 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) {
return resolveEval(p, childrenOutput);
}
+ if (plan instanceof DenseVectorEmbedding dve) {
+ return resolveDenseVectorEmbedding(dve, childrenOutput);
+ }
+
if (plan instanceof Enrich p) {
return resolveEnrich(p, childrenOutput);
}
@@ -820,6 +826,28 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) {
return changed ? new Fork(fork.source(), newSubPlans, newOutput) : fork;
}
+ private LogicalPlan resolveDenseVectorEmbedding(DenseVectorEmbedding p, List<Attribute> childrenOutput) {
+ // Resolve the input expression
+ Expression input = p.input();
+ if (input.resolved() == false) {
+ input = input.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
+ }
+
+ // Resolve the target field (similar to Completion)
+ Attribute targetField = p.embeddingField();
+ if (targetField instanceof UnresolvedAttribute ua) {
+ targetField = new ReferenceAttribute(ua.source(), ua.name(), DENSE_VECTOR);
+ }
+
+ // Create a new DenseVectorEmbedding with resolved expressions
+ // Only create a new instance if something changed to avoid unnecessary object creation
+ if (input != p.input() || targetField != p.embeddingField()) {
+ return p.withTargetField(targetField);
+ }
+
+ return p;
+ }
+
private LogicalPlan resolveRerank(Rerank rerank, List<Attribute> childrenOutput) {
List<Alias> newFields = new ArrayList<>();
boolean changed = false;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java
index 5b9f41876d6e1..f50175be653a5 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java
@@ -7,7 +7,9 @@
package org.elasticsearch.xpack.esql.analysis;
+import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.index.IndexMode;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.plan.IndexPattern;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
@@ -21,6 +23,7 @@
import java.util.Set;
import static java.util.Collections.emptyList;
+import static java.util.Collections.emptySet;
/**
* This class is part of the planner. Acts somewhat like a linker, to find the indices and enrich policies referenced by the query.
@@ -28,25 +31,25 @@
public class PreAnalyzer {
public static class PreAnalysis {
- public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptyList(), emptyList());
+ public static final PreAnalysis EMPTY = new PreAnalysis(null, emptyList(), emptyList(), emptySet(), emptyList());
public final IndexMode indexMode;
public final List<IndexPattern> indices;
public final List<Enrich> enriches;
- public final List<InferencePlan<?>> inferencePlans;
+ public final Set<String> inferenceIds;
public final List<IndexPattern> lookupIndices;
public PreAnalysis(
IndexMode indexMode,
List<IndexPattern> indices,
List<Enrich> enriches,
- List<InferencePlan<?>> inferencePlans,
+ Set<String> inferenceIds,
List<IndexPattern> lookupIndices
) {
this.indexMode = indexMode;
this.indices = indices;
this.enriches = enriches;
- this.inferencePlans = inferencePlans;
+ this.inferenceIds = inferenceIds;
this.lookupIndices = lookupIndices;
}
}
@@ -64,7 +67,7 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
List<Enrich> unresolvedEnriches = new ArrayList<>();
List<IndexPattern> lookupIndices = new ArrayList<>();
- List<InferencePlan<?>> unresolvedInferencePlans = new ArrayList<>();
+ Set<String> unresolvedInferenceIds = new HashSet<>();
Holder<IndexMode> indexMode = new Holder<>();
plan.forEachUp(UnresolvedRelation.class, p -> {
if (p.indexMode() == IndexMode.LOOKUP) {
@@ -78,11 +81,28 @@ protected PreAnalysis doPreAnalyze(LogicalPlan plan) {
});
plan.forEachUp(Enrich.class, unresolvedEnriches::add);
- plan.forEachUp(InferencePlan.class, unresolvedInferencePlans::add);
// mark plan as preAnalyzed (if it were marked, there would be no analysis)
plan.forEachUp(LogicalPlan::setPreAnalyzed);
- return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, unresolvedInferencePlans, lookupIndices);
+ return new PreAnalysis(indexMode.get(), indices.stream().toList(), unresolvedEnriches, inferenceIds(plan), lookupIndices);
+ }
+
+ protected Set<String> inferenceIds(LogicalPlan plan) {
+ Set<String> inferenceIds = new HashSet<>();
+
+ List<InferencePlan<?>> inferencePlans = new ArrayList<>();
+ plan.forEachUp(InferencePlan.class, inferencePlans::add);
+ inferencePlans.stream().map(this::inferenceId).forEach(inferenceIds::add);
+
+ return inferenceIds;
+ }
+
+ private String inferenceId(InferencePlan<?> inferencePlan) {
+ if (inferencePlan.inferenceId() instanceof Literal literal) {
+ return BytesRefs.toString(literal.value());
+ }
+
+ throw new IllegalStateException("inferenceId is not a literal");
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
index a3f6d3a089d49..4b2e6224f3431 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
@@ -13,6 +13,7 @@
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables;
+import org.elasticsearch.xpack.esql.expression.function.inference.DenseVectorEmbeddingFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
@@ -119,6 +120,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.addAll(fullText());
entries.addAll(unaryScalars());
entries.addAll(vector());
+ entries.addAll(inference());
return entries;
}
@@ -264,4 +266,11 @@ private static List<NamedWriteableRegistry.Entry> vector() {
}
return List.of();
}
+
+ private static List<NamedWriteableRegistry.Entry> inference() {
+ if (EsqlCapabilities.Cap.DENSE_VECTOR_EMBEDDING_FUNCTION.isEnabled()) {
+ return List.of(DenseVectorEmbeddingFunction.ENTRY);
+ }
+ return List.of();
+ }
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
index 630c9c2008a13..d78fc713f7f95 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -52,6 +52,7 @@
import org.elasticsearch.xpack.esql.expression.function.fulltext.Term;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
+import org.elasticsearch.xpack.esql.expression.function.inference.DenseVectorEmbeddingFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
@@ -479,6 +480,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
def(Term.class, bi(Term::new), "term"),
def(Knn.class, Knn::new, "knn"),
+ def(DenseVectorEmbeddingFunction.class, bi(DenseVectorEmbeddingFunction::new), "text_dense_vector_embedding"),
def(StGeohash.class, StGeohash::new, "st_geohash"),
def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java
new file mode 100644
index 0000000000000..0f5ac162ac283
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunction.java
@@ -0,0 +1,152 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
+import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.MapParam;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+import java.util.UUID;
+
+/**
+ * A function that embeds input text into a dense vector representation using an inference model.
+ */
+public class DenseVectorEmbeddingFunction extends InferenceFunction {
+
+ public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+ Expression.class,
+ "TextDenseVectorEmbedding",
+ DenseVectorEmbeddingFunction::new
+ );
+
+ private final Expression inputText;
+ private final Attribute tmpAttribute;
+
+ @FunctionInfo(
+ returnType = "dense_vector",
+ preview = true,
+ description = "Embed input text into a dense vector representation using an inference model.",
+ appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
+ )
+ public DenseVectorEmbeddingFunction(
+ Source source,
+ @Param(name = "inputText", type = { "keyword", "text" }, description = "Input text") Expression inputText,
+ @MapParam(
+ name = "options",
+ params = { @MapParam.MapParamEntry(name = "inference_id", type = "keyword", description = "Inference endpoint to use.") },
+ optional = true
+ ) Expression options
+ ) {
+ this(source, inputText, options, new ReferenceAttribute(Source.EMPTY, ENTRY.name + "_" + UUID.randomUUID(), DataType.DOUBLE));
+ }
+
+ private DenseVectorEmbeddingFunction(Source source, Expression inputText, Expression options, Attribute tmpAttribute) {
+ super(source, List.of(inputText, tmpAttribute), options);
+ this.inputText = inputText;
+ this.tmpAttribute = tmpAttribute;
+ }
+
+ public DenseVectorEmbeddingFunction(StreamInput in) throws IOException {
+ this(
+ Source.readFrom((PlanStreamInput) in),
+ in.readNamedWriteable(Expression.class),
+ in.readNamedWriteable(Expression.class),
+ in.readNamedWriteable(Attribute.class)
+ );
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ source().writeTo(out);
+ out.writeNamedWriteable(inputText);
+ out.writeNamedWriteable(options());
+ out.writeNamedWriteable(tmpAttribute);
+ }
+
+ @Override
+ public String functionName() {
+ super.functionName();
+ return getWriteableName();
+ }
+
+ @Override
+ public DataType dataType() {
+ return DataType.DENSE_VECTOR;
+ }
+
+ @Override
+ public DenseVectorEmbeddingFunction replaceChildren(List<Expression> newChildren) {
+ return new DenseVectorEmbeddingFunction(
+ source(),
+ newChildren.get(0),
+ newChildren.size() > 1 ? newChildren.get(1) : null,
+ tmpAttribute
+ );
+ }
+
+ @Override
+ protected NodeInfo<? extends Expression> info() {
+ return NodeInfo.create(this, DenseVectorEmbeddingFunction::new, inputText, options(), tmpAttribute);
+ }
+
+ @Override
+ public String getWriteableName() {
+ return ENTRY.name;
+ }
+
+ @Override
+ protected Literal defaultInferenceId() {
+ return Literal.NULL;
+ }
+
+ @Override
+ public List<Attribute> temporaryAttributes() {
+ return List.of(tmpAttribute);
+ }
+
+ @Override
+ protected TypeResolution resolveParams() {
+ return TypeResolutions.isString(inputText, sourceText(), TypeResolutions.ParamOrdinal.FIRST);
+ }
+
+ @Override
+ protected TypeResolutions.ParamOrdinal optionsParamsOrdinal() {
+ return TypeResolutions.ParamOrdinal.SECOND;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ if (super.equals(o) == false) return false;
+ DenseVectorEmbeddingFunction that = (DenseVectorEmbeddingFunction) o;
+ return Objects.equals(inputText, that.inputText) && Objects.equals(tmpAttribute, that.tmpAttribute);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), inputText, tmpAttribute);
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java
new file mode 100644
index 0000000000000..4bbf32744bd7e
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java
@@ -0,0 +1,192 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.MapExpression;
+import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
+import org.elasticsearch.xpack.esql.core.expression.function.Function;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Supplier;
+import java.util.stream.Stream;
+
+/**
+ * Base class for ESQL functions that perform inference using an `inference_id` and optional parameters.
+ */
+public abstract class InferenceFunction extends Function implements OptionalArgument {
+ public static final String INFERENCE_ID_OPTION_NAME = "inference_id";
+
+ public static final List<OptionalArgumentsValidator> DEFAULT_OPTIONAL_ARGUMENTS_VALIDATORS = List.of(
+ new InferenceIdOptionalArgumentsValidator()
+ );
+
+ private final Expression inferenceId;
+ private final Expression options;
+
+ @SuppressWarnings("this-escape")
+ protected InferenceFunction(Source source, List<Expression> children, Expression options) {
+ super(source, Stream.concat(children.stream(), Stream.of(options)).toList());
+ this.inferenceId = parseInferenceId(options, this::defaultInferenceId);
+ this.options = options;
+ }
+
+ /**
+ * Returns the expression representing the {@code inference_id} used by the function.
+ *
+ * @return the inference ID expression
+ */
+ public Expression inferenceId() {
+ return inferenceId;
+ }
+
+ /**
+ * Returns the expression representing the options passed to the function.
+ *
+ * @return the options expression
+ */
+ public Expression options() {
+ return options;
+ }
+
+ @Override
+ protected TypeResolution resolveType() {
+ if (childrenResolved() == false) {
+ return new TypeResolution("Unresolved children");
+ }
+
+ return resolveParams().and(resolveOptions());
+ }
+
+ /**
+ * Returns the default inference ID expression to use when no {@code inference_id}
+ * is specified in the options.
+ *
+ * @return the default inference ID expression
+ */
+ protected abstract Expression defaultInferenceId();
+
+ /**
+ * When an inference function is resolved, it is replaced with temporary attributes that are computed in an ad-hoc inference command.
+ * These attributes need to be cleaned up once they are no longer used.
+ *
+ * @return the list of temporary attributes
+ */
+ public abstract List<Attribute> temporaryAttributes();
+
+ /**
+ * Resolves the types of the core parameters passed to this function.
+ *
+ * @return the result of parameter type resolution
+ */
+ protected abstract TypeResolution resolveParams();
+
+ /**
+ * Returns the parameter ordinal of the options argument.
+ *
+ * @return the parameter ordinal of the options argument
+ */
+ protected abstract TypeResolutions.ParamOrdinal optionsParamsOrdinal();
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ if (super.equals(o) == false) return false;
+ InferenceFunction that = (InferenceFunction) o;
+ return Objects.equals(options, that.options);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), options);
+ }
+
+ protected TypeResolution resolveOptions() {
+ TypeResolution resolution = TypeResolutions.isMapExpression(options(), sourceText(), optionsParamsOrdinal());
+ if (resolution.unresolved()) {
+ return resolution;
+ }
+
+ MapExpression options = (MapExpression) options();
+ for (Map.Entry optionEntry : options.keyFoldedMap().entrySet()) {
+ for (OptionalArgumentsValidator validator : optionalArgumentsValidators()) {
+ if (validator.applyTo(optionEntry.getKey(), optionEntry.getValue())) {
+ TypeResolution optionResolution = validator.resolveOptionValue(
+ optionEntry.getKey(),
+ optionEntry.getValue(),
+ optionsParamsOrdinal()
+ );
+ if (optionResolution.unresolved()) {
+ return optionResolution;
+ }
+ break;
+ }
+ }
+ }
+
+ return TypeResolution.TYPE_RESOLVED;
+ }
+
+ protected List<OptionalArgumentsValidator> optionalArgumentsValidators() {
+ return DEFAULT_OPTIONAL_ARGUMENTS_VALIDATORS;
+ }
+
+ /**
+ * Extracts the {@code inference_id} expression from the options.
+ * Fallback to the provided inference id if the option is missing.
+ *
+ * @param options the options map expression
+ * @param defaultInferenceIdSupplier the supplier for the default inference ID
+ * @return the resolved inference ID expression
+ */
+ private static Expression parseInferenceId(Expression options, Supplier<Expression> defaultInferenceIdSupplier) {
+ return readOption("inference_id", options, defaultInferenceIdSupplier);
+ }
+
+ /**
+ * Reads an option value from a map expression with a fallback to a default value.
+ *
+ * @param optionName the name of the option to retrieve
+ * @param options the map expression containing options
+ * @param defaultValueSupplier the supplier of the default value
+ * @return the option value as an expression or the default if not present
+ */
+ private static Expression readOption(String optionName, Expression options, Supplier<Expression> defaultValueSupplier) {
+ if (options != null && options.dataType() != DataType.NULL && options instanceof MapExpression mapOptions) {
+ return mapOptions.getOrDefault(optionName, defaultValueSupplier.get());
+ }
+
+ return defaultValueSupplier.get();
+ }
+
+ public interface OptionalArgumentsValidator {
+ boolean applyTo(String optionName, Expression optionValue);
+
+ TypeResolution resolveOptionValue(String optionName, Expression optionValue, TypeResolutions.ParamOrdinal paramOrdinal);
+ }
+
+ public static class InferenceIdOptionalArgumentsValidator implements OptionalArgumentsValidator {
+ private InferenceIdOptionalArgumentsValidator() {}
+
+ public boolean applyTo(String optionName, Expression optionValue) {
+ return optionName.equals(INFERENCE_ID_OPTION_NAME);
+ }
+
+ public TypeResolution resolveOptionValue(String optionName, Expression optionValue, TypeResolutions.ParamOrdinal paramOrdinal) {
+ return TypeResolutions.isString(optionValue, optionName, paramOrdinal)
+ .and(TypeResolutions.isNotNull(optionValue, optionName, paramOrdinal))
+ .and(TypeResolutions.isFoldable(optionValue, optionName, paramOrdinal));
+ }
+ }
+}
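
The OptionalArgumentsValidator interface is the extension point for accepting named options beyond inference_id. The sketch below is a hypothetical validator for an imagined "model_version" option (the option name and class are illustrative, not part of this change); it reuses the same string/non-null/foldable checks as InferenceIdOptionalArgumentsValidator, and a subclass would surface it by overriding optionalArgumentsValidators():

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;

// Hypothetical example, not part of this change: validates an imagined "model_version"
// option with the same checks used for inference_id (string, non-null, foldable).
public class ModelVersionOptionValidator implements InferenceFunction.OptionalArgumentsValidator {

    @Override
    public boolean applyTo(String optionName, Expression optionValue) {
        return optionName.equals("model_version");
    }

    @Override
    public Expression.TypeResolution resolveOptionValue(
        String optionName,
        Expression optionValue,
        TypeResolutions.ParamOrdinal paramOrdinal
    ) {
        return TypeResolutions.isString(optionValue, optionName, paramOrdinal)
            .and(TypeResolutions.isNotNull(optionValue, optionName, paramOrdinal))
            .and(TypeResolutions.isFoldable(optionValue, optionName, paramOrdinal));
    }
}
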
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java
index d67d6817742c0..c1696257edd00 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceRunner.java
@@ -18,9 +18,7 @@
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
-import java.util.List;
import java.util.Set;
-import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -39,12 +37,7 @@ public ThreadPool threadPool() {
return threadPool;
}
- public void resolveInferenceIds(List<InferencePlan<?>> plans, ActionListener<InferenceResolution> listener) {
- resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener);
-
- }
-
- private void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResolution> listener) {
+ public void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResolution> listener) {
if (inferenceIds.isEmpty()) {
listener.onResponse(InferenceResolution.EMPTY);
@@ -63,7 +56,7 @@ private void resolveInferenceIds(Set inferenceIds, ActionListener {
- ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst().getTaskType());
+ ResolvedInference resolvedInference = new ResolvedInference(inferenceId, r.getEndpoints().getFirst());
inferenceResolutionBuilder.withResolvedInference(resolvedInference);
countdownListener.onResponse(null);
}, e -> {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/ResolvedInference.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/ResolvedInference.java
index 455ed6488379a..0bf876d6ff8ed 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/ResolvedInference.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/ResolvedInference.java
@@ -7,22 +7,12 @@
package org.elasticsearch.xpack.esql.inference;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
-import java.io.IOException;
+public record ResolvedInference(String inferenceId, ModelConfigurations modelConfigurations) {
-public record ResolvedInference(String inferenceId, TaskType taskType) implements Writeable {
-
- public ResolvedInference(StreamInput in) throws IOException {
- this(in.readString(), TaskType.valueOf(in.readString()));
- }
-
- @Override
- public void writeTo(StreamOutput out) throws IOException {
- out.writeString(inferenceId);
- out.writeString(taskType.name());
+ public TaskType taskType() {
+ return modelConfigurations.getTaskType();
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java
new file mode 100644
index 0000000000000..d3c1e471891a1
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperator.java
@@ -0,0 +1,125 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.embedding;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
+import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.esql.inference.InferenceOperator;
+import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
+
+import java.util.stream.IntStream;
+
+/**
+ * {@link DenseEmbeddingOperator} is an inference operator that computes vector embeddings from textual data.
+ */
+public class DenseEmbeddingOperator extends InferenceOperator {
+
+ // Default number of rows to include per inference request
+ private static final int DEFAULT_BATCH_SIZE = 20;
+
+ // Encodes each input row into a string representation for the model
+ private final ExpressionEvaluator inputEvaluator;
+
+ // Number of dimensions of the embedding vector
+ private final int dimensions;
+
+ // Batch size used to group rows into a single inference request (currently fixed)
+ // TODO: make it configurable either in the command or as query pragmas
+ private final int batchSize = DEFAULT_BATCH_SIZE;
+
+ public DenseEmbeddingOperator(
+ DriverContext driverContext,
+ InferenceRunner inferenceRunner,
+ ThreadPool threadPool,
+ String inferenceId,
+ int dimensions,
+ ExpressionEvaluator inputEvaluator
+ ) {
+ super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId);
+ this.dimensions = dimensions;
+ this.inputEvaluator = inputEvaluator;
+ }
+
+ @Override
+ public void addInput(Page input) {
+ try {
+ Block inputBlock = inputEvaluator.eval(input);
+ super.addInput(input.appendBlock(inputBlock));
+ } catch (Exception e) {
+ releasePageOnAnyThread(input);
+ throw e;
+ }
+ }
+
+ @Override
+ protected void doClose() {
+ Releasables.close(inputEvaluator);
+ }
+
+ @Override
+ public String toString() {
+ return "DenseEmbeddingOperator[inference_id=[" + inferenceId() + "]]";
+ }
+
+ /**
+ * Returns the request iterator responsible for batching and converting input rows into inference requests.
+ */
+ @Override
+ protected DenseEmbeddingOperatorRequestIterator requests(Page inputPage) {
+ int inputBlockChannel = inputPage.getBlockCount() - 1;
+ return new DenseEmbeddingOperatorRequestIterator(inputPage.getBlock(inputBlockChannel), inferenceId(), batchSize);
+ }
+
+ /**
+ * Returns the output builder responsible for collecting inference responses and building the output page.
+ */
+ @Override
+ protected DenseEmbeddingOperatorOutputBuilder outputBuilder(Page input) {
+ FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(input.getPositionCount() * dimensions);
+ return new DenseEmbeddingOperatorOutputBuilder(
+ outputBlockBuilder,
+ input.projectBlocks(IntStream.range(0, input.getBlockCount() - 1).toArray()),
+ dimensions
+ );
+ }
+
+ /**
+ * Factory for creating {@link DenseEmbeddingOperator} instances
+ */
+ public record Factory(
+ InferenceRunner inferenceRunner,
+ String inferenceId,
+ int dimensions,
+ ExpressionEvaluator.Factory inputEvaluatorFactory
+ ) implements OperatorFactory {
+
+ @Override
+ public String describe() {
+ return "DenseEmbeddingOperator[inference_id=[" + inferenceId + "]]";
+ }
+
+ @Override
+ public Operator get(DriverContext driverContext) {
+ return new DenseEmbeddingOperator(
+ driverContext,
+ inferenceRunner,
+ inferenceRunner.threadPool(),
+ inferenceId,
+ dimensions,
+ inputEvaluatorFactory().get(driverContext)
+ );
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java
new file mode 100644
index 0000000000000..de262cb4155f6
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorOutputBuilder.java
@@ -0,0 +1,122 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.embedding;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.esql.inference.InferenceOperator;
+
+import java.util.Iterator;
+import java.util.stream.IntStream;
+
+/**
+ * Builds the output page for the {@link DenseEmbeddingOperator}.
+ */
+public class DenseEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder {
+
+ private final FloatBlock.Builder outputBlockBuilder;
+ private final Page inputPage;
+ private final int dimensions;
+
+ public DenseEmbeddingOperatorOutputBuilder(FloatBlock.Builder outputBlockBuilder, Page inputPage, int dimensions) {
+ this.outputBlockBuilder = outputBlockBuilder;
+ this.inputPage = inputPage;
+ this.dimensions = dimensions;
+ }
+
+ @Override
+ public void close() {
+ Releasables.close(outputBlockBuilder);
+ releasePageOnAnyThread(inputPage);
+ }
+
+ /**
+ * Constructs a new output {@link Page} with dense embedding in the last column.
+ */
+ @Override
+ public Page buildOutput() {
+ Block outputBlock = outputBlockBuilder.build();
+ assert outputBlock.getPositionCount() == inputPage.getPositionCount();
+ return inputPage.shallowCopy().appendBlock(outputBlock);
+ }
+
+ /**
+ * Extracts the embedding results from the inference response and appends them to the output block builder.
+ * <p>
+ * If the response is not of type {@link TextEmbeddingResults}, an {@link IllegalStateException} is thrown.
+ * <p>
+ * The responses must be added in the same order as the corresponding inference requests were generated.
+ * Failing to preserve order may lead to incorrect or misaligned output rows.
+ */
+ @Override
+ public void addInferenceResponse(InferenceAction.Response inferenceResponse) {
+ EmbeddingValueReader embeddingValueReader = EmbeddingValueReader.of(inferenceResponse, dimensions);
+ while (embeddingValueReader.hasNext()) {
+ writeEmbeddings(embeddingValueReader.next());
+ }
+ }
+
+ private void writeEmbeddings(float[] values) {
+ outputBlockBuilder.beginPositionEntry();
+ for (float value : values) {
+ outputBlockBuilder.appendFloat(value);
+ }
+ outputBlockBuilder.endPositionEntry();
+ }
+
+ private static class EmbeddingValueReader implements Iterator<float[]> {
+ private final int dimensions;
+
+ private final Iterator<? extends EmbeddingResults.Embedding<?>> embeddingsIterator;
+
+ private EmbeddingValueReader(Iterator<? extends EmbeddingResults.Embedding<?>> embeddingsIterator, int dimensions) {
+ this.dimensions = dimensions;
+ this.embeddingsIterator = embeddingsIterator;
+ }
+
+ public boolean hasNext() {
+ return embeddingsIterator.hasNext();
+ }
+
+ public float[] next() {
+ EmbeddingResults.Embedding<?> embedding = embeddingsIterator.next();
+ float[] values = switch (embedding) {
+ case TextEmbeddingFloatResults.Embedding textEmbeddingFloat -> textEmbeddingFloat.values();
+ case TextEmbeddingByteResults.Embedding textEmbeddingBytes -> toFloatArray(textEmbeddingBytes.values());
+ default -> throw new IllegalStateException("Unsupported embedding type [" + embedding.getClass() + "]");
+ };
+
+ assert values.length == dimensions : "Unexpected vector size: " + values.length;
+
+ return values;
+ }
+
+ private static float[] toFloatArray(byte[] bytes) {
+ float[] floatValues = new float[bytes.length];
+ IntStream.range(0, floatValues.length).forEach(i -> floatValues[i] = ((Byte) bytes[i]).floatValue());
+ return floatValues;
+ }
+
+ public static EmbeddingValueReader of(InferenceAction.Response inferenceResponse, int dimensions) {
+ TextEmbeddingResults<?> inferenceResults = InferenceOperator.OutputBuilder.inferenceResults(
+ inferenceResponse,
+ TextEmbeddingResults.class
+ );
+ return new EmbeddingValueReader(inferenceResults.embeddings().iterator(), dimensions);
+ }
+ }
+}
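
For byte-quantized embeddings (TextEmbeddingByteResults), each byte value is widened to a float before being appended to the output block. A trivial standalone sketch of that widening, with plain arrays standing in for the inference result and block types:

// Sketch of the byte-to-float widening used for TextEmbeddingByteResults above;
// plain arrays stand in for the inference result and block types.
public class ByteEmbeddingWidening {

    static float[] toFloatArray(byte[] bytes) {
        float[] floats = new float[bytes.length];
        for (int i = 0; i < bytes.length; i++) {
            floats[i] = bytes[i]; // implicit widening conversion byte -> float
        }
        return floats;
    }

    public static void main(String[] args) {
        float[] vector = toFloatArray(new byte[] { 1, -2, 3 });
        System.out.println(vector[0] + ", " + vector[1] + ", " + vector[2]); // 1.0, -2.0, 3.0
    }
}
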
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorRequestIterator.java
new file mode 100644
index 0000000000000..e1f3a3c4ebec1
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/embedding/DenseEmbeddingOperatorRequestIterator.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference.embedding;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.lucene.BytesRefs;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+/**
+ * Iterator over input data blocks to create batched inference requests for the dense vector text embedding task.
+ *
+ * <p>This iterator reads from a {@link BytesRefBlock} containing text to be embedded. It slices the input into batches
+ * of configurable size and converts each batch into an {@link InferenceAction.Request} with the task type {@link TaskType#TEXT_EMBEDDING}.
+ */
+public class DenseEmbeddingOperatorRequestIterator implements BulkInferenceRequestIterator {
+ private final BytesRefBlock inputBlock;
+ private final String inferenceId;
+ private final int batchSize;
+ private int remainingPositions;
+
+ public DenseEmbeddingOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, int batchSize) {
+ this.inputBlock = inputBlock;
+ this.inferenceId = inferenceId;
+ this.batchSize = batchSize;
+ this.remainingPositions = inputBlock.getPositionCount();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return remainingPositions > 0;
+ }
+
+ @Override
+ public InferenceAction.Request next() {
+ if (hasNext() == false) {
+ throw new NoSuchElementException();
+ }
+
+ final int inputSize = Math.min(remainingPositions, batchSize);
+ final List<String> inputs = new ArrayList<>(inputSize);
+ BytesRef scratch = new BytesRef();
+
+ int startIndex = inputBlock.getPositionCount() - remainingPositions;
+ for (int i = 0; i < inputSize; i++) {
+ int pos = startIndex + i;
+ if (inputBlock.isNull(pos)) {
+ inputs.add("");
+ } else {
+ scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(pos), scratch);
+ inputs.add(BytesRefs.toString(scratch));
+ }
+ }
+
+ remainingPositions -= inputSize;
+ return inferenceRequest(inputs);
+ }
+
+ @Override
+ public int estimatedSize() {
+ return inputBlock.getPositionCount();
+ }
+
+ private InferenceAction.Request inferenceRequest(List<String> inputs) {
+ return InferenceAction.Request.builder(inferenceId, TaskType.TEXT_EMBEDDING).setInput(inputs).build();
+ }
+
+ @Override
+ public void close() {
+
+ }
+}
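
The iterator consumes the input block front to back, batchSize positions per request, so the final request may carry fewer inputs. A standalone sketch of the same slicing arithmetic over a plain list of strings (no BytesRefBlock or inference client involved):

import java.util.ArrayList;
import java.util.List;

// Sketch of the batching arithmetic used by the request iterator above:
// consume positions front to back, batchSize at a time; the last batch may be smaller.
public class BatchSlicerSketch {

    public static List<List<String>> slice(List<String> rows, int batchSize) {
        List<List<String>> batches = new ArrayList<>();
        int remainingPositions = rows.size();
        while (remainingPositions > 0) {
            int inputSize = Math.min(remainingPositions, batchSize);
            int startIndex = rows.size() - remainingPositions;
            batches.add(new ArrayList<>(rows.subList(startIndex, startIndex + inputSize)));
            remainingPositions -= inputSize;
        }
        return batches;
    }

    public static void main(String[] args) {
        // 5 rows, batch size 2 -> requests of 2, 2 and 1 inputs
        System.out.println(slice(List.of("a", "b", "c", "d", "e"), 2)); // [[a, b], [c, d], [e]]
    }
}
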
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java
index 2fe9f5182ae00..940d8adda546c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java
@@ -26,6 +26,7 @@
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
+import org.elasticsearch.xpack.esql.plan.logical.inference.embedding.DenseVectorEmbedding;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
@@ -55,6 +56,7 @@
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.embedding.DenseVectorEmbeddingExec;
import java.util.ArrayList;
import java.util.List;
@@ -72,6 +74,7 @@ public static List<NamedWriteableRegistry.Entry> logical() {
return List.of(
Aggregate.ENTRY,
Completion.ENTRY,
+ DenseVectorEmbedding.ENTRY,
Dissect.ENTRY,
Enrich.ENTRY,
EsRelation.ENTRY,
@@ -99,6 +102,7 @@ public static List<NamedWriteableRegistry.Entry> physical() {
return List.of(
AggregateExec.ENTRY,
CompletionExec.ENTRY,
+ DenseVectorEmbeddingExec.ENTRY,
DissectExec.ENTRY,
EnrichExec.ENTRY,
EsQueryExec.ENTRY,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
index 620e8726865d6..1ad945dffa4a8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/InferencePlan.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.plan.logical.inference;
import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
@@ -69,4 +70,9 @@ public int hashCode() {
public PlanType withInferenceResolutionError(String inferenceId, String error) {
return withInferenceId(new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
}
+
+ @SuppressWarnings("unchecked")
+ public PlanType withModelConfigurations(ModelConfigurations modelConfig) {
+ return (PlanType) this;
+ }
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java
new file mode 100644
index 0000000000000..c654bcff590be
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/embedding/DenseVectorEmbedding.java
@@ -0,0 +1,210 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.logical.inference.embedding;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.lucene.BytesRefs;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.esql.capabilities.TelemetryAware;
+import org.elasticsearch.xpack.esql.core.capabilities.Unresolvable;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.expression.NameId;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
+
+public class DenseVectorEmbedding extends InferencePlan<DenseVectorEmbedding> implements TelemetryAware {
+
+ public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+ LogicalPlan.class,
+ "DenseVectorEmbedding",
+ DenseVectorEmbedding::new
+ );
+
+ private final Expression input;
+ private final Expression dimensions;
+ private final Attribute targetField;
+ private List<Attribute> lazyOutput;
+
+ public DenseVectorEmbedding(Source source, LogicalPlan child, Expression inferenceId, Expression input, Attribute targetField) {
+ this(source, child, inferenceId, new UnresolvedDimensions(inferenceId), input, targetField);
+ }
+
+ DenseVectorEmbedding(
+ Source source,
+ LogicalPlan child,
+ Expression inferenceId,
+ Expression dimensions,
+ Expression input,
+ Attribute targetField
+ ) {
+ super(source, child, inferenceId);
+ this.input = input;
+ this.targetField = targetField;
+ this.dimensions = dimensions;
+ }
+
+ public DenseVectorEmbedding(StreamInput in) throws IOException {
+ this(
+ Source.readFrom((PlanStreamInput) in),
+ in.readNamedWriteable(LogicalPlan.class),
+ in.readNamedWriteable(Expression.class),
+ in.readNamedWriteable(Expression.class),
+ in.readNamedWriteable(Expression.class),
+ in.readNamedWriteable(Attribute.class)
+ );
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ source().writeTo(out);
+ out.writeNamedWriteable(child());
+ out.writeNamedWriteable(inferenceId());
+ out.writeNamedWriteable(dimensions);
+ out.writeNamedWriteable(input);
+ out.writeNamedWriteable(targetField);
+ }
+
+ public Expression input() {
+ return input;
+ }
+
+ public Attribute embeddingField() {
+ return targetField;
+ }
+
+ @Override
+ public TaskType taskType() {
+ return TaskType.TEXT_EMBEDDING;
+ }
+
+ public Expression dimensions() {
+ return dimensions;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return ENTRY.name;
+ }
+
+ @Override
+ public List<Attribute> output() {
+ if (lazyOutput == null) {
+ lazyOutput = mergeOutputAttributes(List.of(targetField), child().output());
+ }
+ return lazyOutput;
+ }
+
+ @Override
+ public List<Attribute> generatedAttributes() {
+ return List.of(targetField);
+ }
+
+ @Override
+ public DenseVectorEmbedding withGeneratedNames(List<String> newNames) {
+ checkNumberOfNewNames(newNames);
+ return new DenseVectorEmbedding(source(), child(), inferenceId(), dimensions, input, this.renameTargetField(newNames.get(0)));
+ }
+
+ private Attribute renameTargetField(String newName) {
+ if (newName.equals(targetField.name())) {
+ return targetField;
+ }
+
+ return targetField.withName(newName).withId(new NameId());
+ }
+
+ @Override
+ public boolean expressionsResolved() {
+ return super.expressionsResolved() && input.resolved() && targetField.resolved() && dimensions.resolved();
+ }
+
+ @Override
+ public DenseVectorEmbedding withInferenceId(Expression newInferenceId) {
+ return new DenseVectorEmbedding(source(), child(), newInferenceId, dimensions, input, targetField);
+ }
+
+ public DenseVectorEmbedding withDimensions(Expression newDimensions) {
+ return new DenseVectorEmbedding(source(), child(), inferenceId(), newDimensions, input, targetField);
+ }
+
+ public DenseVectorEmbedding withTargetField(Attribute targetField) {
+ return new DenseVectorEmbedding(source(), child(), inferenceId(), dimensions, input, targetField);
+ }
+
+ @Override
+ public DenseVectorEmbedding withModelConfigurations(ModelConfigurations modelConfig) {
+ boolean hasChanged = false;
+ Expression newDimensions = dimensions;
+
+ if (dimensions.resolved() == false
+ && modelConfig.getServiceSettings() != null
+ && modelConfig.getServiceSettings().dimensions() > 0) {
+ hasChanged = true;
+ newDimensions = new Literal(Source.EMPTY, modelConfig.getServiceSettings().dimensions(), DataType.INTEGER);
+ }
+
+ return hasChanged ? withDimensions(newDimensions) : this;
+ }
+
+ @Override
+ public DenseVectorEmbedding replaceChild(LogicalPlan newChild) {
+ return new DenseVectorEmbedding(source(), newChild, inferenceId(), dimensions, input, targetField);
+ }
+
+ @Override
+ protected NodeInfo<? extends LogicalPlan> info() {
+ return NodeInfo.create(this, DenseVectorEmbedding::new, child(), inferenceId(), dimensions, input, targetField);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ if (super.equals(o) == false) return false;
+ DenseVectorEmbedding that = (DenseVectorEmbedding) o;
+ return Objects.equals(input, that.input)
+ && Objects.equals(dimensions, that.dimensions)
+ && Objects.equals(targetField, that.targetField);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), input, targetField, dimensions);
+ }
+
+ private static class UnresolvedDimensions extends Literal implements Unresolvable {
+
+ private final String inferenceId;
+
+ private UnresolvedDimensions(Expression inferenceId) {
+ super(Source.EMPTY, null, DataType.NULL);
+ this.inferenceId = BytesRefs.toString(inferenceId.fold(FoldContext.small()));
+ }
+
+ @Override
+ public String unresolvedMessage() {
+ return "Dimensions cannot be resolved for inference endpoint[" + inferenceId + "]";
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java
new file mode 100644
index 0000000000000..4eae30daeef8e
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/inference/embedding/DenseVectorEmbeddingExec.java
@@ -0,0 +1,130 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.physical.inference.embedding;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
+import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.InferenceExec;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
+
+public class DenseVectorEmbeddingExec extends InferenceExec {
+
+ public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+ PhysicalPlan.class,
+ "DenseVectorEmbeddingExec",
+ DenseVectorEmbeddingExec::new
+ );
+
+ private final Expression input;
+ private final Expression dimensions;
+ private final Attribute targetField;
+ private List<Attribute> lazyOutput;
+
+ public DenseVectorEmbeddingExec(
+ Source source,
+ PhysicalPlan child,
+ Expression inferenceId,
+ Expression dimensions,
+ Expression input,
+ Attribute targetField
+ ) {
+ super(source, child, inferenceId);
+ this.input = input;
+ this.dimensions = dimensions;
+ this.targetField = targetField;
+ }
+
+ public DenseVectorEmbeddingExec(StreamInput in) throws IOException {
+ this(
+ Source.readFrom((PlanStreamInput) in),
+ in.readNamedWriteable(PhysicalPlan.class),
+ in.readNamedWriteable(Expression.class),
+ in.readNamedWriteable(Expression.class),
+ in.readNamedWriteable(Expression.class),
+ in.readNamedWriteable(Attribute.class)
+ );
+ }
+
+ public Expression input() {
+ return input;
+ }
+
+ public Attribute targetField() {
+ return targetField;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return ENTRY.name;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ out.writeNamedWriteable(dimensions);
+ out.writeNamedWriteable(input);
+ out.writeNamedWriteable(targetField);
+ }
+
+ public Expression dimensions() {
+ return dimensions;
+ }
+
+ @Override
+ protected NodeInfo<? extends PhysicalPlan> info() {
+ return NodeInfo.create(this, DenseVectorEmbeddingExec::new, child(), inferenceId(), dimensions, input, targetField);
+ }
+
+ @Override
+ public UnaryExec replaceChild(PhysicalPlan newChild) {
+ return new DenseVectorEmbeddingExec(source(), newChild, inferenceId(), dimensions, input, targetField);
+ }
+
+ @Override
+ public List<Attribute> output() {
+ if (lazyOutput == null) {
+ lazyOutput = mergeOutputAttributes(List.of(targetField), child().output());
+ }
+ return lazyOutput;
+ }
+
+ @Override
+ protected AttributeSet computeReferences() {
+ return input.references();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ if (super.equals(o) == false) return false;
+ DenseVectorEmbeddingExec that = (DenseVectorEmbeddingExec) o;
+ return Objects.equals(input, that.input)
+ && Objects.equals(dimensions, that.dimensions)
+ && Objects.equals(targetField, that.targetField);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), input, dimensions, targetField);
+ }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
index a92d2f439a0ea..333eab10b217a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
@@ -88,6 +88,7 @@
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.inference.XContentRowEncoder;
import org.elasticsearch.xpack.esql.inference.completion.CompletionOperator;
+import org.elasticsearch.xpack.esql.inference.embedding.DenseEmbeddingOperator;
import org.elasticsearch.xpack.esql.inference.rerank.RerankOperator;
import org.elasticsearch.xpack.esql.plan.logical.Fork;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
@@ -119,6 +120,7 @@
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.embedding.DenseVectorEmbeddingExec;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders.ShardContext;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.score.ScoreMapper;
@@ -266,6 +268,8 @@ private PhysicalOperation plan(PhysicalPlan node, LocalExecutionPlannerContext c
return planChangePoint(changePoint, context);
} else if (node instanceof CompletionExec completion) {
return planCompletion(completion, context);
+ } else if (node instanceof DenseVectorEmbeddingExec embedding) {
+ return planDenseVectorEmbedding(embedding, context);
} else if (node instanceof SampleExec Sample) {
return planSample(Sample, context);
}
@@ -319,6 +323,31 @@ private PhysicalOperation planCompletion(CompletionExec completion, LocalExecuti
return source.with(new CompletionOperator.Factory(inferenceRunner, inferenceId, promptEvaluatorFactory), outputLayout);
}
+ private PhysicalOperation planDenseVectorEmbedding(DenseVectorEmbeddingExec embedding, LocalExecutionPlannerContext context) {
+ PhysicalOperation source = plan(embedding.child(), context);
+ String inferenceId = BytesRefs.toString(embedding.inferenceId().fold(context.foldCtx()));
+
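+ // The dimensions option must be a constant: fold the literal here so the operator can be configured up front.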
+ int dimensions;
+ if (embedding.dimensions() instanceof Literal literal) {
+ Object val = literal.value() instanceof BytesRef br ? BytesRefs.toString(br) : literal.value();
+ dimensions = stringToInt(val.toString());
+ } else {
+ throw new EsqlIllegalArgumentException("dimensions only supported with literal values");
+ }
+
+ Layout outputLayout = source.layout.builder().append(embedding.targetField()).build();
+ EvalOperator.ExpressionEvaluator.Factory inputEvaluatorFactory = EvalMapper.toEvaluator(
+ context.foldCtx(),
+ embedding.input(),
+ source.layout
+ );
+
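+ // Wire the embedding operator after its source: it evaluates the input text per row, and the appended
+ // target field is expected to carry the resulting dense vector.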
+ return source.with(
+ new DenseEmbeddingOperator.Factory(inferenceRunner, inferenceId, dimensions, inputEvaluatorFactory),
+ outputLayout
+ );
+ }
+
private PhysicalOperation planRrfScoreEvalExec(RrfScoreEvalExec rrf, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(rrf.child(), context);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
index 4851de1616844..884f50b71a549 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
@@ -29,6 +29,7 @@
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
+import org.elasticsearch.xpack.esql.plan.logical.inference.embedding.DenseVectorEmbedding;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
import org.elasticsearch.xpack.esql.plan.logical.show.ShowInfo;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
@@ -47,6 +48,7 @@
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
+import org.elasticsearch.xpack.esql.plan.physical.inference.embedding.DenseVectorEmbeddingExec;
import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders;
import java.util.List;
@@ -106,6 +108,17 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) {
return new CompletionExec(completion.source(), child, completion.inferenceId(), completion.prompt(), completion.targetField());
}
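+ // Map the logical DenseVectorEmbedding node onto its physical counterpart, preserving the inference id,
+ // dimensions, input and target field.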
+ if (p instanceof DenseVectorEmbedding embedding) {
+ return new DenseVectorEmbeddingExec(
+ embedding.source(),
+ child,
+ embedding.inferenceId(),
+ embedding.dimensions(),
+ embedding.input(),
+ embedding.embeddingField()
+ );
+ }
+
if (p instanceof Enrich enrich) {
return new EnrichExec(
enrich.source(),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
index 4ff65f59bbd72..920201897199b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
@@ -85,7 +85,6 @@
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
-import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;
@@ -372,7 +371,7 @@ public void analyzedPlan(
l -> enrichPolicyResolver.resolvePolicies(unresolvedPolicies, executionInfo, l)
)
.andThen((l, enrichResolution) -> resolveFieldNames(parsed, enrichResolution, l))
- .andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferencePlans, preAnalysisResult, l));
+ .andThen((l, preAnalysisResult) -> resolveInferences(preAnalysis.inferenceIds, preAnalysisResult, l));
// first resolve the lookup indices, then the main indices
for (var index : preAnalysis.lookupIndices) {
listener = listener.andThen((l, preAnalysisResult) -> { preAnalyzeLookupIndex(index, preAnalysisResult, l); });
@@ -588,12 +587,8 @@ private static void resolveFieldNames(LogicalPlan parsed, EnrichResolution enric
}
}
- private void resolveInferences(
- List<InferencePlan<?>> inferencePlans,
- PreAnalysisResult preAnalysisResult,
- ActionListener<PreAnalysisResult> l
- ) {
- inferenceRunner.resolveInferenceIds(inferencePlans, l.map(preAnalysisResult::withInferenceResolution));
+ private void resolveInferences(Set<String> inferenceIds, PreAnalysisResult preAnalysisResult, ActionListener<PreAnalysisResult> l) {
+ inferenceRunner.resolveInferenceIds(inferenceIds, l.map(preAnalysisResult::withInferenceResolution));
}
static PreAnalysisResult fieldNames(LogicalPlan parsed, Set<String> enrichPolicyMatchFields, PreAnalysisResult result) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
index 5e6c37545a396..a016efa137956 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.analysis;
import org.elasticsearch.index.IndexMode;
+import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
@@ -39,6 +40,8 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
public final class AnalyzerTestUtils {
@@ -189,12 +192,18 @@ public static EnrichResolution defaultEnrichResolution() {
public static InferenceResolution defaultInferenceResolution() {
return InferenceResolution.builder()
- .withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK))
- .withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION))
+ .withResolvedInference(mockedResolvedInference("reranking-inference-id", TaskType.RERANK))
+ .withResolvedInference(mockedResolvedInference("completion-inference-id", TaskType.COMPLETION))
.withError("error-inference-id", "error with inference resolution")
.build();
}
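+ // ResolvedInference now takes the full ModelConfigurations; the tests only need the task type, so a minimal mock is enough.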
+ private static ResolvedInference mockedResolvedInference(String inferenceId, TaskType taskType) {
+ ModelConfigurations modelConfigurations = mock(ModelConfigurations.class);
+ when(modelConfigurations.getTaskType()).thenReturn(taskType);
+ return new ResolvedInference(inferenceId, modelConfigurations);
+ }
+
public static void loadEnrichPolicyResolution(
EnrichResolution enrich,
String policyType,
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java
index 00f20b9376a6f..4e256a1563af2 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java
@@ -798,6 +798,7 @@ public static void testFunctionInfo() {
Set<String> returnTypes = Arrays.stream(description.returnType())
.filter(t -> DataType.UNDER_CONSTRUCTION.containsKey(DataType.fromNameOrAlias(t)) == false)
.collect(Collectors.toCollection(TreeSet::new));
+
assertEquals(returnFromSignature, returnTypes);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingErrorTests.java
new file mode 100644
index 0000000000000..4158b2cc9380e
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingErrorTests.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matcher;
+import org.junit.Before;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Set;
+import java.util.stream.Stream;
+
+import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
+import static org.hamcrest.Matchers.equalTo;
+
+public class DenseVectorEmbeddingErrorTests extends ErrorsForCasesWithoutExamplesTestCase {
+
+ @Before
+ public void checkCapability() {
+ assumeTrue("DENSE_VECTOR_EMBEDDING_FUNCTION is not enabled", EsqlCapabilities.Cap.DENSE_VECTOR_EMBEDDING_FUNCTION.isEnabled());
+ }
+
+ @Override
+ protected List<TestCaseSupplier> cases() {
+ return paramsToSuppliers(DenseVectorEmbeddingFunctionTests.parameters());
+ }
+
+ @Override
+ protected Stream<List<DataType>> testCandidates(List<TestCaseSupplier> cases, Set<List<DataType>> valid) {
+ // Don't test null, as it is not allowed but the expected message is not a type error - so we check it separately in VerifierTests
+ return super.testCandidates(cases, valid).filter(sig -> false == sig.contains(DataType.NULL));
+ }
+
+ @Override
+ protected Expression build(Source source, List<Expression> args) {
+ return new DenseVectorEmbeddingFunction(source, args.get(0), args.get(1));
+ }
+
+ @Override
+ protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
+ return equalTo(errorMessageString(validPerPosition, signature, (v, p) -> "string"));
+ }
+
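+ // Builds the expected type error, special-casing the trailing argument which must be a map expression rather than a plain type.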
+ private static String errorMessageString(
+ List<Set<DataType>> validPerPosition,
+ List<DataType> signature,
+ AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier
+ ) {
+ for (int i = 0; i < signature.size(); i++) {
+ if (validPerPosition.get(i).contains(signature.get(i)) == false) {
+ // Map expressions have different error messages
+ if (i == signature.size() - 1) {
+ return format(
+ null,
+ "{} argument of [{}] must be a map expression, received []",
+ TypeResolutions.ParamOrdinal.fromIndex(i).name().toLowerCase(Locale.ROOT),
+ sourceForSignature(signature)
+ );
+ }
+ break;
+ }
+ }
+
+ return typeErrorMessage(true, validPerPosition, signature, positionalErrorMessageSupplier);
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunctionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunctionTests.java
new file mode 100644
index 0000000000000..c493cf7177043
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/DenseVectorEmbeddingFunctionTests.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.FieldExpression;
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.expression.MapExpression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.FunctionName;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Supplier;
+import java.util.stream.Stream;
+
+import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
+import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
+import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
+import static org.hamcrest.Matchers.equalTo;
+
+@FunctionName("text_dense_vector_embedding")
+public class DenseVectorEmbeddingFunctionTests extends AbstractFunctionTestCase {
+ @Before
+ public void checkCapability() {
+ assumeTrue("DENSE_VECTOR_EMBEDDING_FUNCTION is not enabled", EsqlCapabilities.Cap.DENSE_VECTOR_EMBEDDING_FUNCTION.isEnabled());
+ }
+
+ public DenseVectorEmbeddingFunctionTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+ this.testCase = testCaseSupplier.get();
+ }
+
+ @ParametersFactory
+ public static Iterable<Object[]>