Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
59087fb
Add text embedding function definition.
afoucret Sep 9, 2025
8f4c409
Fix tests.
afoucret Sep 9, 2025
847c998
InferenceResolution for text embedding function.
afoucret Sep 11, 2025
4fe169a
Text embedding analysis and verification.
afoucret Sep 11, 2025
44ace1a
Merge branch 'main' of https://github.com/elastic/elasticsearch into …
afoucret Sep 11, 2025
8056fde
Fix test can fail if byte or bit vectors are not supported
afoucret Sep 11, 2025
c4a0871
Add the text_embedding function to xpack usage tests
afoucret Sep 13, 2025
2b95815
Merge branch 'main' of https://github.com/elastic/elasticsearch into …
afoucret Sep 13, 2025
cee2935
Fix error in xpack usage
afoucret Sep 13, 2025
9d7a23f
Create the text embedding request iterator
afoucret Sep 13, 2025
b8c5f11
Clean analyzer tests to avoid forbidden api usage.
afoucret Sep 13, 2025
678438b
Add text embedding output builder.
afoucret Sep 13, 2025
2068e91
Text embedding inference operator.
afoucret Sep 13, 2025
9298450
More flexible output builder.
afoucret Sep 16, 2025
0ae1e0e
Init inference function evaluator.
afoucret Sep 16, 2025
54e9ca6
Merge branch 'main' of https://github.com/elastic/elasticsearch into …
afoucret Sep 18, 2025
6d064f0
Implementing inference evaluation in the pre-optimizer.
afoucret Sep 18, 2025
6fc5f99
[CI] Auto commit changes from spotless
Sep 18, 2025
5ee2382
Remove overengineered type param on the InferenceOperator.OutputBuilder
afoucret Sep 18, 2025
3e9d3d5
Unit tests for InferenceFunctionEvaluator
afoucret Sep 18, 2025
5604055
Unit tests for InferenceFunctionEvaluator
afoucret Sep 18, 2025
238c0c2
LogicalPlanPreOptimizerTests rule chain tests.
afoucret Sep 18, 2025
816c410
More unit tests.
afoucret Sep 18, 2025
ea3de8b
More CSV tests.
afoucret Sep 18, 2025
42bc936
Lint
afoucret Sep 18, 2025
1d1cc3a
Remove useless changes.
afoucret Sep 18, 2025
75ed988
More CSV tests :tada:
afoucret Sep 18, 2025
2d4749b
Fix a typo error
afoucret Sep 18, 2025
3fe8248
Merge branch 'main' into esql_text_embedding_function
afoucret Sep 18, 2025
e789dc2
Update docs/changelog/134573.yaml
afoucret Sep 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/134573.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 134573
summary: ES|QL text embedding function
area: ES|QL
type: feature
issues: []
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public String toString() {
return Strings.toString(this);
}

float[] toFloatArray() {
public float[] toFloatArray() {
float[] floatArray = new float[values.length];
for (int i = 0; i < values.length; i++) {
floatArray[i] = ((Byte) values[i]).floatValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SOURCE_FIELD_MAPPING;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION;
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.assertNotPartial;
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.hasCapabilities;

Expand Down Expand Up @@ -205,7 +206,8 @@ protected boolean requiresInferenceEndpoint() {
SEMANTIC_TEXT_FIELD_CAPS.capabilityName(),
RERANK.capabilityName(),
COMPLETION.capabilityName(),
KNN_FUNCTION_V5.capabilityName()
KNN_FUNCTION_V5.capabilityName(),
TEXT_EMBEDDING_FUNCTION.capabilityName()
).anyMatch(testCase.requiredCapabilities::contains);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
text_embedding using a row source operator
required_capability: text_embedding_function
required_capability: dense_vector_field_type

ROW input="Who is Victor Hugo?"
| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference")
;

input:keyword | embedding:dense_vector
Who is Victor Hugo? | [56.0, 50.0, 48.0]
;


text_embedding using a row source operator with query build using CONCAT
required_capability: text_embedding_function
required_capability: dense_vector_field_type

ROW input="Who is Victor Hugo?"
| EVAL embedding = TEXT_EMBEDDING(CONCAT("Who is ", "Victor Hugo?"), "test_dense_inference")
;

input:keyword | embedding:dense_vector
Who is Victor Hugo? | [56.0, 50.0, 48.0]
;


text_embedding with knn on semantic_text_dense_field
required_capability: text_embedding_function
required_capability: dense_vector_field_type
required_capability: knn_function_v5
required_capability: semantic_text_field_caps

FROM semantic_text METADATA _score
| EVAL query_embedding = TEXT_EMBEDDING("be excellent to each other", "test_dense_inference")
| WHERE KNN(semantic_text_dense_field, query_embedding)
| KEEP semantic_text_field, query_embedding, _score
| EVAL _score = ROUND(_score, 4)
| SORT _score DESC
| LIMIT 10
;

semantic_text_field:text | query_embedding:dense_vector | _score:double
be excellent to each other | [45.0, 55.0, 54.0] | 1.0
live long and prosper | [45.0, 55.0, 54.0] | 0.0295
all we have to decide is what to do with the time that is given to us | [45.0, 55.0, 54.0] | 0.0214
;

text_embedding with knn (inline) on semantic_text_dense_field
required_capability: text_embedding_function
required_capability: dense_vector_field_type
required_capability: knn_function_v5
required_capability: semantic_text_field_caps

FROM semantic_text METADATA _score
| WHERE KNN(semantic_text_dense_field, TEXT_EMBEDDING("be excellent to each other", "test_dense_inference"))
| KEEP semantic_text_field, _score
| EVAL _score = ROUND(_score, 4)
| SORT _score DESC
| LIMIT 10
;

semantic_text_field:text | _score:double
be excellent to each other | 1.0
live long and prosper | 0.0295
all we have to decide is what to do with the time that is given to us | 0.0214
;
Original file line number Diff line number Diff line change
Expand Up @@ -1309,6 +1309,11 @@ public enum Cap {
*/
KNN_FUNCTION_V5(Build.current().isSnapshot()),

/**
* Support for the {@code TEXT_EMBEDDING} function for generating dense vector embeddings.
*/
TEXT_EMBEDDING_FUNCTION(Build.current().isSnapshot()),

/**
* Support for the LIKE operator with a list of wildcards.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
Expand Down Expand Up @@ -1414,7 +1415,8 @@ private static class ResolveInference extends ParameterizedRule<LogicalPlan, Log

@Override
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
}

private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {
Expand Down Expand Up @@ -1443,6 +1445,36 @@ private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext

return plan;
}

/**
 * Checks the inference endpoint referenced by an inference function against the inference resolution held by
 * the analyzer context.
 * <p>
 * The check only runs when the {@code inferenceId} expression is resolved, foldable and of a string type; in any
 * other case the function is returned as is. When it runs, the folded id is looked up in the inference resolution
 * and the function is replaced by the result of {@code withInferenceResolutionError} if either no resolved
 * endpoint is found for that id, or the endpoint task type differs from the task type required by the function.
 *
 * @param inferenceFunction the function whose inference endpoint is checked
 * @param context           analyzer context giving access to the inference resolution and the function registry
 * @return the given function when no problem is detected (or the check does not apply), otherwise the function
 *         returned by {@code withInferenceResolutionError}
 */
private InferenceFunction<?> resolveInferenceFunction(InferenceFunction<?> inferenceFunction, AnalyzerContext context) {
    // Evaluated left to right with short-circuiting: foldable() and dataType() are only consulted once the
    // previous condition holds.
    boolean canCheckEndpoint = inferenceFunction.inferenceId().resolved()
        && inferenceFunction.inferenceId().foldable()
        && DataType.isString(inferenceFunction.inferenceId().dataType());
    if (canCheckEndpoint == false) {
        return inferenceFunction;
    }

    String endpointId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small()));
    ResolvedInference endpoint = context.inferenceResolution().getResolvedInference(endpointId);

    // No resolved endpoint for this id: surface the error recorded by the inference resolution.
    if (endpoint == null) {
        return inferenceFunction.withInferenceResolutionError(endpointId, context.inferenceResolution().getError(endpointId));
    }

    // Endpoint found and its task type is the one the function needs: nothing to report.
    if (endpoint.taskType() == inferenceFunction.taskType()) {
        return inferenceFunction;
    }

    StringBuilder message = new StringBuilder("cannot use inference endpoint [").append(endpointId)
        .append("] with task type [")
        .append(endpoint.taskType())
        .append("] within a ")
        .append(context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass()))
        .append(" function. Only inference endpoints with the task type [")
        .append(inferenceFunction.taskType())
        .append("] are supported.");
    return inferenceFunction.withInferenceResolutionError(endpointId, message.toString());
}
}

private static class AddImplicitLimit extends ParameterizedRule<LogicalPlan, LogicalPlan, AnalyzerContext> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public void esql(
indexResolver,
enrichPolicyResolver,
preAnalyzer,
new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext)),
new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext, services.inferenceService())),
functionRegistry,
new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)),
mapper,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables;
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceWritables;
import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
Expand Down Expand Up @@ -120,6 +121,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.addAll(fullText());
entries.addAll(unaryScalars());
entries.addAll(vector());
entries.addAll(inference());
return entries;
}

Expand Down Expand Up @@ -265,4 +267,8 @@ private static List<NamedWriteableRegistry.Entry> fullText() {
/** Returns the named writeable entries provided by {@link VectorWritables#getNamedWritables()}. */
private static List<NamedWriteableRegistry.Entry> vector() {
return VectorWritables.getNamedWritables();
}

/** Returns the named writeable entries provided by {@link InferenceWritables#getNamedWritables()}. */
private static List<NamedWriteableRegistry.Entry> inference() {
return InferenceWritables.getNamedWritables();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket;
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
Expand Down Expand Up @@ -539,7 +540,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(Magnitude.class, Magnitude::new, "v_magnitude"),
def(Hamming.class, Hamming::new, "v_hamming"),
def(UrlEncode.class, UrlEncode::new, "url_encode"),
def(UrlDecode.class, UrlDecode::new, "url_decode") } };
def(UrlDecode.class, UrlDecode::new, "url_decode"),
def(TextEmbedding.class, bi(TextEmbedding::new), "text_embedding") } };
}

public EsqlFunctionRegistry snapshotRegistry() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.inference;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.Source;

import java.util.List;

/**
 * Common parent of the ESQL functions that rely on an inference endpoint (for example {@code TEXT_EMBEDDING}).
 *
 * @param <PlanType> the concrete function type, used as the return type of
 *                   {@link #withInferenceResolutionError(String, String)}
 */
public abstract class InferenceFunction<PlanType extends InferenceFunction<PlanType>> extends Function {

    /** Name of the parameter carrying the inference endpoint identifier. */
    public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id";

    /**
     * @param source   source of the function, passed to the parent constructor
     * @param children child expressions of the function, passed to the parent constructor
     */
    protected InferenceFunction(Source source, List<Expression> children) {
        super(source, children);
    }

    /**
     * @return the expression identifying the inference endpoint used by this function
     */
    public abstract Expression inferenceId();

    /**
     * @return the task type this function requires (e.g. {@code TEXT_EMBEDDING})
     */
    public abstract TaskType taskType();

    /**
     * Creates a copy of this function carrying an inference resolution error, so it can be shown to the user.
     *
     * @param inferenceId the inference endpoint identifier the error relates to
     * @param error       the error message
     * @return the copy of this function holding the error
     */
    public abstract PlanType withInferenceResolutionError(String inferenceId, String error);

    /**
     * Tells whether this function contains a nested inference function call.
     *
     * @return {@code true} if {@code anyMatch} finds an {@link InferenceFunction} that is not this instance
     */
    public boolean hasNestedInferenceFunction() {
        // The identity check keeps this function itself from counting as a match.
        return anyMatch(node -> node != this && node instanceof InferenceFunction);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.inference;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Holder of the {@link NamedWriteableRegistry.Entry} instances contributed by the ESQL inference functions.
 */
public final class InferenceWritables {

    /** Static-only utility class: instantiation always fails. */
    private InferenceWritables() {
        throw new UnsupportedOperationException();
    }

    /**
     * Collects the named writeable entries of the inference functions.
     * <p>
     * The {@link TextEmbedding} entry is only included when the {@code TEXT_EMBEDDING_FUNCTION} capability is enabled.
     *
     * @return an unmodifiable list of entries, empty when the capability is disabled
     */
    public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
        final boolean textEmbeddingEnabled = EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled();
        final List<NamedWriteableRegistry.Entry> writables = new ArrayList<>();
        if (textEmbeddingEnabled) {
            writables.add(TextEmbedding.ENTRY);
        }
        return Collections.unmodifiableList(writables);
    }
}
Loading