Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
da21972
ESQL: Add asynchronous pre-optimization step for logical plan
afoucret Jul 17, 2025
8b5082b
[CI] Auto commit changes from spotless
Jul 17, 2025
658b54b
Revert useless change in QueryPlanningBenchmark
afoucret Jul 17, 2025
76831a0
Add EMBED_TEXT function infrastructure for dense vector embeddings
afoucret Jul 11, 2025
f735f46
Extend InferenceResolver to collect inference IDs from functions and …
afoucret Jul 11, 2025
f03e155
Add Analyzer support for EMBED_TEXT inference functions
afoucret Jul 28, 2025
d365b0d
Rename EMBED_TEXT function to TEXT_EMBEDDING and update all references
afoucret Jul 11, 2025
df1936d
Add PreOptimizer infrastructure for async pre-optimization steps
afoucret Jul 11, 2025
1f0c62c
[ESQL] Add async transformation methods to Node and QueryPlan
afoucret Jul 16, 2025
1f034a0
[ESQL] Add TEXT_EMBEDDING function evaluator for dense vector embeddings
afoucret Jul 16, 2025
419d523
[ESQL] Complete TEXT_EMBEDDING function integration
afoucret Jul 16, 2025
3a506c5
Fix error.
afoucret Jul 16, 2025
5e55664
Lint
afoucret Jul 16, 2025
1336099
Introduce a new PRE_OPTIMIZED to the LogicalPlan
afoucret Jul 16, 2025
9c619b9
Add basic csv tests.
afoucret Jul 16, 2025
e90971d
Fix forbidden API usage
afoucret Jul 17, 2025
567175d
Exclude TEXT_EMBEDDING from CsvTests.
afoucret Jul 17, 2025
0497e7e
Refactored inference folding in PreOptimizer.
afoucret Jul 17, 2025
35c27a4
Revert useless node transform changes.
afoucret Jul 17, 2025
24e242f
Fix typo
afoucret Jul 17, 2025
b459ef3
Renamed PreOptimizer into LogicalPlanPreOptimizer
afoucret Jul 17, 2025
4040292
Add TEXT_EMBEDDING_FUNCTION capability to EsqlSpecTestCase::requiresI…
afoucret Jul 17, 2025
0fda848
Restore previous implementation of foldable for TextEmbedding function.
afoucret Jul 17, 2025
cd03311
Fix CsvTests
afoucret Jul 17, 2025
f92e1f5
Get rid of the InferenceServices class.
afoucret Jul 17, 2025
f3c9324
Improved the inference pre-optimization.
afoucret Jul 18, 2025
3824f4f
Checkstyle!
afoucret Jul 18, 2025
70819e1
Huge simplification of the bulk inference runner & operator
afoucret Jul 24, 2025
c09800a
Lint
afoucret Jul 24, 2025
cab92a1
Fix tests after rebasing.
afoucret Jul 28, 2025
71d591b
Update from main.
afoucret Aug 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,15 @@ public String toString() {
return Strings.toString(this);
}

float[] toFloatArray() {
public float[] toFloatArray() {
float[] floatArray = new float[values.length];
for (int i = 0; i < values.length; i++) {
floatArray[i] = ((Byte) values[i]).floatValue();
}
return floatArray;
}

double[] toDoubleArray() {
public double[] toDoubleArray() {
double[] doubleArray = new double[values.length];
for (int i = 0; i < values.length; i++) {
doubleArray[i] = ((Byte) values[i]).doubleValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SOURCE_FIELD_MAPPING;
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION;
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.assertNotPartial;
import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.hasCapabilities;

Expand Down Expand Up @@ -211,8 +212,12 @@ protected boolean supportsInferenceTestService() {
}

protected boolean requiresInferenceEndpoint() {
return Stream.of(SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName(), COMPLETION.capabilityName())
.anyMatch(testCase.requiredCapabilities::contains);
return Stream.of(
SEMANTIC_TEXT_FIELD_CAPS.capabilityName(),
RERANK.capabilityName(),
COMPLETION.capabilityName(),
TEXT_EMBEDDING_FUNCTION.capabilityName()
).anyMatch(testCase.requiredCapabilities::contains);
}

protected boolean supportsIndexModeLookup() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ public static void createInferenceEndpoints(RestClient client) throws IOExceptio
createSparseEmbeddingInferenceEndpoint(client);
}

if (clusterHasDenseEmbeddingInferenceEndpoint(client) == false) {
createDenseEmbeddingInferenceEndpoint(client);
}

if (clusterHasRerankInferenceEndpoint(client) == false) {
createRerankInferenceEndpoint(client);
}
Expand All @@ -414,6 +418,7 @@ public static void createInferenceEndpoints(RestClient client) throws IOExceptio

/**
 * Deletes the test inference endpoints by calling, in order, the sparse embedding, dense embedding,
 * rerank and completion delete helpers with the given client.
 *
 * @param client the REST client handed to each delete helper
 * @throws IOException declared by this method; presumably raised by the delete helpers (not visible here)
 */
public static void deleteInferenceEndpoints(RestClient client) throws IOException {
deleteSparseEmbeddingInferenceEndpoint(client);
deleteDenseEmbeddingInferenceEndpoint(client);
deleteRerankInferenceEndpoint(client);
deleteCompletionInferenceEndpoint(client);
}
Expand All @@ -437,6 +442,24 @@ public static boolean clusterHasSparseEmbeddingInferenceEndpoint(RestClient clie
return clusterHasInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference");
}

/**
 * Creates the {@code test_dense_inference} inference endpoint with task type {@link TaskType#TEXT_EMBEDDING}.
 * <p>
 * The request body passed to {@code createInferenceEndpoint} selects the {@code text_embedding_test_service}
 * service and sets {@code dimensions} to 10 in its service settings.
 *
 * @param client the REST client handed to {@code createInferenceEndpoint}
 * @throws IOException declared by this method; presumably raised by {@code createInferenceEndpoint} (not visible here)
 */
public static void createDenseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
createInferenceEndpoint(client, TaskType.TEXT_EMBEDDING, "test_dense_inference", """
{
"service": "text_embedding_test_service",
"service_settings": { "model": "my_model", "api_key": "abc64", "dimensions": 10 },
"task_settings": { }
}
""");
}

/**
 * Deletes the {@code test_dense_inference} inference endpoint by delegating to {@code deleteInferenceEndpoint}.
 *
 * @param client the REST client handed to {@code deleteInferenceEndpoint}
 * @throws IOException declared by this method; presumably raised by {@code deleteInferenceEndpoint} (not visible here)
 */
public static void deleteDenseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
    final String endpointId = "test_dense_inference";
    deleteInferenceEndpoint(client, endpointId);
}

/**
 * Reports the result of {@code clusterHasInferenceEndpoint} for the {@code test_dense_inference} endpoint
 * with task type {@link TaskType#TEXT_EMBEDDING}.
 *
 * @param client the REST client handed to {@code clusterHasInferenceEndpoint}
 * @return whatever {@code clusterHasInferenceEndpoint} returns for that task type and endpoint id
 * @throws IOException declared by this method; presumably raised by {@code clusterHasInferenceEndpoint} (not visible here)
 */
public static boolean clusterHasDenseEmbeddingInferenceEndpoint(RestClient client) throws IOException {
    final String endpointId = "test_dense_inference";
    final boolean endpointExists = clusterHasInferenceEndpoint(client, TaskType.TEXT_EMBEDDING, endpointId);
    return endpointExists;
}

public static void createRerankInferenceEndpoint(RestClient client) throws IOException {
createInferenceEndpoint(client, TaskType.RERANK, "test_reranker", """
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
text_embedding using a ROW source operator
required_capability: text_embedding_function
required_capability: dense_vector_field_type

ROW input="Who is Victor Hugo?"
| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference")
;

input:keyword | embedding:dense_vector
Who is Victor Hugo? | [56.0, 50.0, 48.0, 50.0, 54.0, 52.0, 49.0, 51.0, 51.0, 56.0]
;
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,11 @@ public enum Cap {
*/
LIKE_ON_INDEX_FIELDS,

/**
* Support for the {@code TEXT_EMBEDDING} function for generating dense vector embeddings.
*/
TEXT_EMBEDDING_FUNCTION(Build.current().isSnapshot()),

/**
* Forbid usage of brackets in unquoted index and enrich policy names
* https://github.com/elastic/elasticsearch/issues/130378
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
Expand Down Expand Up @@ -1311,7 +1312,35 @@ private static class ResolveInference extends ParameterizedRule<LogicalPlan, Log

@Override
public LogicalPlan apply(LogicalPlan plan, AnalyzerContext context) {
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context));
return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context))
.transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context));
}

/**
 * Checks the inference endpoint referenced by an inference function against the inference resolution
 * held by the analyzer context.
 * <p>
 * The function is returned unchanged when an endpoint was resolved for its inference id and the endpoint's
 * task type equals the task type declared by the function. Otherwise the result of
 * {@code withInferenceResolutionError} is returned, carrying either the error recorded for the inference id
 * or a task type mismatch message.
 *
 * @param inferenceFunction the function to check; its inference id is asserted to be resolved and foldable
 * @param context the analyzer context supplying the inference resolution and the function registry
 * @return the original function, or the value produced by {@code withInferenceResolutionError}
 */
private InferenceFunction<?> resolveInferenceFunction(InferenceFunction<?> inferenceFunction, AnalyzerContext context) {
    assert inferenceFunction.inferenceId().resolved() && inferenceFunction.inferenceId().foldable();

    final String id = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small()));
    final ResolvedInference resolution = context.inferenceResolution().getResolvedInference(id);

    // No resolved endpoint for this id: report the error recorded by the inference resolution.
    if (resolution == null) {
        return inferenceFunction.withInferenceResolutionError(id, context.inferenceResolution().getError(id));
    }

    // Task types agree: nothing to report.
    if (resolution.taskType() == inferenceFunction.taskType()) {
        return inferenceFunction;
    }

    // Task types differ: build the mismatch message, naming the function via the snapshot registry.
    StringBuilder mismatch = new StringBuilder("cannot use inference endpoint [").append(id)
        .append("] with task type [")
        .append(resolution.taskType())
        .append("] within a ")
        .append(context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass()))
        .append(" function. Only inference endpoints with the task type [")
        .append(inferenceFunction.taskType())
        .append("] are supported.");
    return inferenceFunction.withInferenceResolutionError(id, mismatch.toString());
}

private LogicalPlan resolveInferencePlan(InferencePlan<?> plan, AnalyzerContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void esql(
indexResolver,
enrichPolicyResolver,
preAnalyzer,
new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext)),
new LogicalPlanPreOptimizer(services, new LogicalPreOptimizerContext(foldContext)),
functionRegistry,
new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)),
mapper,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
package org.elasticsearch.xpack.esql.expression;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables;
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
Expand Down Expand Up @@ -119,6 +121,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
entries.addAll(fullText());
entries.addAll(unaryScalars());
entries.addAll(vector());
entries.addAll(inference());
return entries;
}

Expand Down Expand Up @@ -262,4 +265,11 @@ private static List<NamedWriteableRegistry.Entry> fullText() {
private static List<NamedWriteableRegistry.Entry> vector() {
return VectorWritables.getNamedWritables();
}

/**
 * Returns the named writeable entries for inference functions: {@code TextEmbedding.ENTRY} when the
 * {@code TEXT_EMBEDDING_FUNCTION} capability is enabled, otherwise an empty list.
 */
private static List<NamedWriteableRegistry.Entry> inference() {
    return EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled() ? List.of(TextEmbedding.ENTRY) : List.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.elasticsearch.xpack.esql.expression.function.fulltext.Term;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least;
Expand Down Expand Up @@ -485,6 +486,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(Score.class, uni(Score::new), Score.NAME),
def(Term.class, bi(Term::new), "term"),
def(Knn.class, quad(Knn::new), "knn"),
def(TextEmbedding.class, TextEmbedding::new, "text_embedding"),
def(StGeohash.class, StGeohash::new, "st_geohash"),
def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.inference;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;

import java.util.List;

/**
 * Base class for functions that are evaluated using an inference model.
 * <p>
 * Subclasses expose the inference id expression, the task type of the inference model, a way to attach an
 * inference resolution error, and the factory of the evaluator used to run the function.
 *
 * @param <PlanType> the concrete function type, returned by {@link #withInferenceResolutionError(String, String)}
 */
public abstract class InferenceFunction<PlanType extends InferenceFunction<PlanType>> extends Function {

    /** Name of the parameter that holds the inference id. */
    public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id";

    /**
     * Passes the source and the children straight to the {@link Function} constructor.
     */
    protected InferenceFunction(Source source, List<Expression> children) {
        super(source, children);
    }

    /**
     * @return the expression holding the inference model ID
     */
    public abstract Expression inferenceId();

    /**
     * @return the task type of the inference model
     */
    public abstract TaskType taskType();

    /**
     * Produces a new instance of this function that carries the given inference resolution error.
     *
     * @param inferenceId the inference id the error relates to
     * @param error the inference resolution error
     * @return the new function instance
     */
    public abstract PlanType withInferenceResolutionError(String inferenceId, String error);

    /**
     * @return the factory of the inference function evaluator
     */
    public abstract InferenceFunctionEvaluator.Factory inferenceEvaluatorFactory();

    /**
     * Tells whether an inference function is nested in this one.
     *
     * @return the result of {@code anyMatch} for a predicate accepting every {@link InferenceFunction}
     *         other than this instance
     */
    public boolean hasNestedInferenceFunction() {
        return anyMatch(node -> node != this && node instanceof InferenceFunction<?>);
    }
}
Loading
Loading