diff --git a/docs/reference/query-languages/esql/_snippets/functions/description/text_embedding.md b/docs/reference/query-languages/esql/_snippets/functions/description/text_embedding.md new file mode 100644 index 0000000000000..70ebddbdb7c0a --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/text_embedding.md @@ -0,0 +1,6 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Description** + +Generates dense vector embeddings for text using a specified inference endpoint. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/text_embedding.md b/docs/reference/query-languages/esql/_snippets/functions/examples/text_embedding.md new file mode 100644 index 0000000000000..71d05c9524350 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/text_embedding.md @@ -0,0 +1,13 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Example** + +Generate text embeddings using the 'test_dense_inference' inference endpoint. + +```esql +ROW input="Who is Victor Hugo?" +| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference") +; +``` + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/text_embedding.md b/docs/reference/query-languages/esql/_snippets/functions/layout/text_embedding.md new file mode 100644 index 0000000000000..a120fff2d7a22 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/text_embedding.md @@ -0,0 +1,27 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +## `TEXT_EMBEDDING` [esql-text_embedding] +```{applies_to} +stack: development +serverless: preview +``` + +**Syntax** + +:::{image} ../../../images/functions/text_embedding.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/text_embedding.md +::: + +:::{include} ../description/text_embedding.md +::: + +:::{include} ../types/text_embedding.md +::: + +:::{include} ../examples/text_embedding.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/text_embedding.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/text_embedding.md index e2b852912c5f5..80175caaf09dd 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/parameters/text_embedding.md +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/text_embedding.md @@ -1,4 +1,4 @@ -% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. **Parameters** diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/text_embedding.md b/docs/reference/query-languages/esql/_snippets/functions/types/text_embedding.md new file mode 100644 index 0000000000000..6e45a6eb84c5c --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/types/text_embedding.md @@ -0,0 +1,8 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +**Supported types** + +| text | inference_id | result | +| --- | --- | --- | +| keyword | keyword | dense_vector | + diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json index 5f1f68a2b14bd..343cf597a0a53 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json @@ -3,7 +3,26 @@ "type" : "scalar", "name" : "text_embedding", "description" : "Generates dense vector embeddings for text using a specified inference endpoint.", - "signatures" : [ ], + "signatures" : [ + { + "params" : [ + { + "name" : "text", + "type" : "keyword", + "optional" : false, + "description" : "Text to generate embeddings from" + }, + { + "name" : "inference_id", + "type" : "keyword", + "optional" : false, + "description" : "Identifier of the inference endpoint" + } + ], + "variadic" : false, + "returnType" : "dense_vector" + } + ], "examples" : [ "ROW input=\"Who is Victor Hugo?\"\n| EVAL embedding = TEXT_EMBEDDING(\"Who is Victor Hugo?\", \"test_dense_inference\")\n;" ], diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java index fed250051d84d..2ecd089dedd2f 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java @@ -77,6 +77,7 @@ import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SOURCE_FIELD_MAPPING; +import static 
org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION; import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.assertNotPartial; import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.hasCapabilities; @@ -224,7 +225,8 @@ protected boolean requiresInferenceEndpoint() { SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName(), COMPLETION.capabilityName(), - KNN_FUNCTION_V5.capabilityName() + KNN_FUNCTION_V5.capabilityName(), + TEXT_EMBEDDING_FUNCTION.capabilityName() ).anyMatch(testCase.requiredCapabilities::contains); } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec index f026800598e10..86e0fcd0eb6a4 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec @@ -1,6 +1,6 @@ -placeholder +text_embedding using a row source operator required_capability: text_embedding_function -required_capability: not_existing_capability +required_capability: dense_vector_field_type_released // tag::embedding-eval[] ROW input="Who is Victor Hugo?" @@ -8,8 +8,84 @@ ROW input="Who is Victor Hugo?" ; // end::embedding-eval[] +input:keyword | embedding:dense_vector +Who is Victor Hugo? | [56.0, 50.0, 48.0] +; + + +text_embedding using a row source operator with query build using CONCAT +required_capability: text_embedding_function +required_capability: dense_vector_field_type_released + +ROW input="Who is Victor Hugo?" +| EVAL embedding = TEXT_EMBEDDING(CONCAT("Who is ", "Victor Hugo?"), "test_dense_inference") +; input:keyword | embedding:dense_vector Who is Victor Hugo? 
| [56.0, 50.0, 48.0] ; + +text_embedding with knn on semantic_text_dense_field +required_capability: text_embedding_function +required_capability: dense_vector_field_type_released +required_capability: knn_function_v5 +required_capability: semantic_text_field_caps + +FROM semantic_text METADATA _score +| EVAL query_embedding = TEXT_EMBEDDING("be excellent to each other", "test_dense_inference") +| WHERE KNN(semantic_text_dense_field, query_embedding) +| SORT _score DESC +| LIMIT 10 +| KEEP semantic_text_field, query_embedding +; + +semantic_text_field:text | query_embedding:dense_vector +be excellent to each other | [45.0, 55.0, 54.0] +live long and prosper | [45.0, 55.0, 54.0] +all we have to decide is what to do with the time that is given to us | [45.0, 55.0, 54.0] +; + +text_embedding with knn (inline) on semantic_text_dense_field +required_capability: text_embedding_function +required_capability: dense_vector_field_type_released +required_capability: knn_function_v5 +required_capability: semantic_text_field_caps + +FROM semantic_text METADATA _score +| WHERE KNN(semantic_text_dense_field, TEXT_EMBEDDING("be excellent to each other", "test_dense_inference")) +| SORT _score DESC +| LIMIT 10 +| KEEP semantic_text_field +; + +semantic_text_field:text +be excellent to each other +live long and prosper +all we have to decide is what to do with the time that is given to us +; + + +text_embedding with multiple knn queries in fork +required_capability: text_embedding_function +required_capability: dense_vector_field_type_released +required_capability: knn_function_v5 +required_capability: fork_v9 +required_capability: semantic_text_field_caps + +FROM semantic_text METADATA _score +| FORK (EVAL query_embedding = TEXT_EMBEDDING("be excellent to each other", "test_dense_inference") | WHERE KNN(semantic_text_dense_field, query_embedding)) + (EVAL query_embedding = TEXT_EMBEDDING("live long and prosper", "test_dense_inference") | WHERE KNN(semantic_text_dense_field, 
query_embedding)) +| SORT _score DESC, _fork ASC +| LIMIT 10 +| KEEP semantic_text_field, query_embedding, _fork +; + +semantic_text_field:text | query_embedding:dense_vector | _fork:keyword +be excellent to each other | [45.0, 55.0, 54.0] | fork1 +live long and prosper | [50.0, 57.0, 56.0] | fork2 +live long and prosper | [45.0, 55.0, 54.0] | fork1 +be excellent to each other | [50.0, 57.0, 56.0] | fork2 +all we have to decide is what to do with the time that is given to us | [45.0, 55.0, 54.0] | fork1 +all we have to decide is what to do with the time that is given to us | [50.0, 57.0, 56.0] | fork2 +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java index 2043176f24a29..974b73718ff0b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java @@ -88,7 +88,7 @@ public void esql( indexResolver, enrichPolicyResolver, preAnalyzer, - new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext)), + new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext, services.inferenceService())), functionRegistry, new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)), mapper, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java new file mode 100644 index 0000000000000..e5cc4301c5683 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java @@ -0,0 +1,228 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.indices.breaker.AllCircuitBreakerStats; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.indices.breaker.CircuitBreakerStats; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.evaluator.EvalMapper; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; +import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingOperator; + +/** + * Evaluator for inference functions that performs constant folding by executing inference operations + * at optimization time and replacing them with their computed results. 
+ */ +public class InferenceFunctionEvaluator { + + private static final Factory FACTORY = new Factory(); + + public static InferenceFunctionEvaluator.Factory factory() { + return FACTORY; + } + + private final FoldContext foldContext; + private final InferenceOperatorProvider inferenceOperatorProvider; + + /** + * Creates a new inference function evaluator with a custom operator provider. + * This constructor is primarily used for testing to inject mock operator providers. + * + * @param foldContext the fold context containing circuit breakers and evaluation settings + * @param inferenceOperatorProvider custom provider for creating inference operators + */ + InferenceFunctionEvaluator(FoldContext foldContext, InferenceOperatorProvider inferenceOperatorProvider) { + this.foldContext = foldContext; + this.inferenceOperatorProvider = inferenceOperatorProvider; + } + + /** + * Folds an inference function by executing it and replacing it with its computed result. + *

<p>
+ * This method performs the following steps:
+ * <ol>
+ *   <li>Validates that the function is foldable (has constant parameters)</li>
+ *   <li>Sets up a minimal execution context with appropriate circuit breakers</li>
+ *   <li>Creates and configures the appropriate inference operator</li>
+ *   <li>Executes the inference operation asynchronously</li>
+ *   <li>Converts the result to a {@link Literal} expression</li>
+ * </ol>
+ * + * @param f the inference function to fold - must be foldable (have constant parameters) + * @param listener the listener to notify when folding completes successfully or fails + */ + public void fold(InferenceFunction f, ActionListener listener) { + if (f.foldable() == false) { + listener.onFailure(new IllegalArgumentException("Inference function must be foldable")); + return; + } + + // Set up a DriverContext for executing the inference operator. + // This follows the same pattern as EvaluatorMapper but in a simplified context + // suitable for constant folding during optimization. + CircuitBreaker breaker = foldContext.circuitBreakerView(f.source()); + BigArrays bigArrays = new BigArrays(null, new CircuitBreakerService() { + @Override + public CircuitBreaker getBreaker(String name) { + if (name.equals(CircuitBreaker.REQUEST) == false) { + throw new UnsupportedOperationException("Only REQUEST circuit breaker is supported"); + } + return breaker; + } + + @Override + public AllCircuitBreakerStats stats() { + throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context"); + } + + @Override + public CircuitBreakerStats stats(String name) { + throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context"); + } + }, CircuitBreaker.REQUEST).withCircuitBreaking(); + + DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); + + // Create the inference operator for the specific function type using the provider + try { + Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext); + + try { + // Feed the operator with a single page to trigger execution + // The actual input data is already bound in the operator through expression evaluators + inferenceOperator.addInput(new Page(1)); + + // Execute the inference operation asynchronously and handle the result + // The operator will perform the actual inference call and return a page with the 
result + driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> { + Page output = inferenceOperator.getOutput(); + + try { + if (output == null) { + l.onFailure(new IllegalStateException("Expected output page from inference operator")); + return; + } + + if (output.getPositionCount() != 1 || output.getBlockCount() != 1) { + l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator")); + return; + } + + // Convert the operator result back to an ESQL expression (Literal) + l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0))); + } finally { + Releasables.close(inferenceOperator); + if (output != null) { + output.releaseBlocks(); + } + } + })); + } catch (Exception e) { + Releasables.close(inferenceOperator); + listener.onFailure(e); + } + } catch (Exception e) { + listener.onFailure(e); + } finally { + driverContext.finish(); + } + } + + /** + * Functional interface for providing inference operators. + *

+ * This interface abstracts the creation of inference operators for different function types, + * allowing for easier testing and potential future extensibility. The provider is responsible + * for creating an appropriate operator instance given an inference function and driver context. + */ + interface InferenceOperatorProvider { + /** + * Creates an inference operator for the given function and driver context. + * + * @param f the inference function to create an operator for + * @param driverContext the driver context to use for operator creation + * @return an operator instance configured for the given function + */ + Operator getOperator(InferenceFunction f, DriverContext driverContext); + } + + /** + * Factory for creating {@link InferenceFunctionEvaluator} instances. + */ + public static class Factory { + private Factory() {} + + /** + * Creates a new inference function evaluator. + * + * @param foldContext the fold context + * @param inferenceService the inference service + * @return a new instance of {@link InferenceFunctionEvaluator} + */ + public InferenceFunctionEvaluator create(FoldContext foldContext, InferenceService inferenceService) { + return new InferenceFunctionEvaluator(foldContext, createInferenceOperatorProvider(foldContext, inferenceService)); + } + + /** + * Creates an {@link InferenceOperatorProvider} that can produce operators for all supported inference functions. 
+ */ + private InferenceOperatorProvider createInferenceOperatorProvider(FoldContext foldContext, InferenceService inferenceService) { + return (inferenceFunction, driverContext) -> { + Operator.OperatorFactory operatorFactory = switch (inferenceFunction) { + case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory( + inferenceService, + inferenceId(inferenceFunction, foldContext), + expressionEvaluatorFactory(textEmbedding.inputText(), foldContext) + ); + default -> throw new IllegalArgumentException("Unknown inference function: " + inferenceFunction.getClass().getName()); + }; + + return operatorFactory.get(driverContext); + }; + } + + /** + * Extracts the inference endpoint ID from an inference function. + * + * @param f the inference function containing the inference ID + * @return the inference endpoint ID as a string + */ + private String inferenceId(InferenceFunction f, FoldContext foldContext) { + return BytesRefs.toString(f.inferenceId().fold(foldContext)); + } + + /** + * Creates an expression evaluator factory for a foldable expression. + *

+ * This method converts a foldable expression into an evaluator factory that can be used by inference + * operators. The expression is first folded to its constant value and then wrapped in a literal. + * + * @param e the foldable expression to create an evaluator factory for + * @return an expression evaluator factory for the given expression + */ + private EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e, FoldContext foldContext) { + assert e.foldable() : "Input expression must be foldable"; + return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null); + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java index fdd8e1318f636..11c57b38c2331 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java @@ -8,8 +8,13 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.FoldInferenceFunctions; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.LogicalPlanPreOptimizerRule; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import java.util.List; + /** * The class is responsible for invoking any steps that need to be applied to the logical plan, * before this is being optimized. @@ -18,11 +23,14 @@ *

*/ public class LogicalPlanPreOptimizer { - - private final LogicalPreOptimizerContext preOptimizerContext; + private final List preOptimizerRules; public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) { - this.preOptimizerContext = preOptimizerContext; + this(List.of(new FoldInferenceFunctions(preOptimizerContext))); + } + + LogicalPlanPreOptimizer(List preOptimizerRules) { + this.preOptimizerRules = preOptimizerRules; } /** @@ -43,8 +51,17 @@ public void preOptimize(LogicalPlan plan, ActionListener listener) })); } + /** + * Loop over the rules and apply them sequentially to the logical plan. + * + * @param plan the analyzed logical plan to pre-optimize + * @param listener the listener returning the pre-optimized plan when pre-optimization is complete + */ private void doPreOptimize(LogicalPlan plan, ActionListener listener) { - // this is where we will be executing async tasks - listener.onResponse(plan); + SubscribableListener rulesListener = SubscribableListener.newSucceeded(plan); + for (LogicalPlanPreOptimizerRule preOptimizerRule : preOptimizerRules) { + rulesListener = rulesListener.andThen((l, p) -> preOptimizerRule.apply(p, l)); + } + rulesListener.addListener(listener); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java index d082bd56fc46d..d08c639d0470b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java @@ -8,36 +8,11 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.inference.InferenceService; -import java.util.Objects; - -public class LogicalPreOptimizerContext { - - private 
final FoldContext foldCtx; - - public LogicalPreOptimizerContext(FoldContext foldCtx) { - this.foldCtx = foldCtx; - } - - public FoldContext foldCtx() { - return foldCtx; - } - - @Override - public boolean equals(Object obj) { - if (obj == this) return true; - if (obj == null || obj.getClass() != this.getClass()) return false; - var that = (LogicalPreOptimizerContext) obj; - return this.foldCtx.equals(that.foldCtx); - } - - @Override - public int hashCode() { - return Objects.hash(foldCtx); - } +/** + * Context passed to logical pre-optimizer rules. + */ +public record LogicalPreOptimizerContext(FoldContext foldCtx, InferenceService inferenceService) { - @Override - public String toString() { - return "LogicalPreOptimizerContext[foldCtx=" + foldCtx + ']'; - } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java new file mode 100644 index 0000000000000..4a3fc87c43795 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; +import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; +import org.elasticsearch.xpack.esql.optimizer.LogicalPreOptimizerContext; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Pre-optimizer rule that performs constant folding for inference functions in logical plans. + *

<p>
+ * This rule identifies inference functions with constant parameters and evaluates them at optimization time,
+ * replacing them with their computed results.
+ * <p>
+ * The folding process is recursive and handles nested inference functions by processing them in multiple
+ * passes until no more foldable functions remain.
+ * <p>

+ * Example transformation: + * {@code TEXT_EMBEDDING("hello world", "model1")} → {@code [0.1, 0.2, 0.3, ...]} + */ +public class FoldInferenceFunctions implements LogicalPlanPreOptimizerRule { + + private final InferenceFunctionEvaluator inferenceFunctionEvaluator; + + public FoldInferenceFunctions(LogicalPreOptimizerContext preOptimizerContext) { + this(InferenceFunctionEvaluator.factory().create(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService())); + } + + protected FoldInferenceFunctions(InferenceFunctionEvaluator inferenceFunctionEvaluator) { + this.inferenceFunctionEvaluator = inferenceFunctionEvaluator; + } + + @Override + public void apply(LogicalPlan plan, ActionListener listener) { + foldInferenceFunctions(plan, listener); + } + + /** + * Recursively folds inference functions in the logical plan. + *

<p>
+ * This method collects all foldable inference functions, evaluates them in parallel,
+ * and then replaces them with their computed results. If new foldable inference functions remain
+ * after the first round of folding (due to nested function resolution), it recursively processes
+ * them until no more foldable functions remain.
+ * </p>

+ * + * @param plan the logical plan to fold inference functions in + * @param listener the listener to notify when the folding is complete + */ + private void foldInferenceFunctions(LogicalPlan plan, ActionListener listener) { + // Collect all foldable inference functions from the current plan + List> inferenceFunctions = collectFoldableInferenceFunctions(plan); + + if (inferenceFunctions.isEmpty()) { + // No foldable inference functions were found - return the original plan unchanged + listener.onResponse(plan); + return; + } + + // Map to store the computed results for each inference function + Map, Expression> inferenceFunctionsToResults = new HashMap<>(); + + // Create a countdown listener that will be triggered when all inference functions complete + // Once all are done, replace the functions in the plan with their results and recursively + // process any remaining foldable inference functions + CountDownActionListener completionListener = new CountDownActionListener( + inferenceFunctions.size(), + listener.delegateFailureIgnoreResponseAndWrap(l -> { + // Transform the plan by replacing inference functions with their computed results + LogicalPlan transformedPlan = plan.transformExpressionsUp( + InferenceFunction.class, + f -> inferenceFunctionsToResults.getOrDefault(f, f) + ); + + // Recursively process the transformed plan to handle any remaining inference functions + foldInferenceFunctions(transformedPlan, l); + }) + ); + + // Evaluate each inference function asynchronously + for (InferenceFunction inferenceFunction : inferenceFunctions) { + inferenceFunctionEvaluator.fold(inferenceFunction, completionListener.delegateFailureAndWrap((l, result) -> { + inferenceFunctionsToResults.put(inferenceFunction, result); + l.onResponse(null); + })); + } + } + + /** + * Collects all foldable inference functions from the logical plan. + *

<p>
+ * A function is considered foldable if it meets all of the following criteria:
+ * <ol>
+ *   <li>It's an instance of {@link InferenceFunction}</li>
+ *   <li>It's marked as foldable (all parameters are constants)</li>
+ *   <li>It doesn't contain nested inference functions (to avoid dependency issues)</li>
+ * </ol>
+ * <p>

+ * Functions with nested inference functions are excluded to ensure proper evaluation order. + * They will be considered for folding in subsequent recursive passes after their nested + * functions have been resolved. + * + * @param plan the logical plan to collect inference functions from + * @return a list of foldable inference functions, may be empty if none are found + */ + private List> collectFoldableInferenceFunctions(LogicalPlan plan) { + List> inferenceFunctions = new ArrayList<>(); + + plan.forEachExpressionUp(InferenceFunction.class, f -> { + if (f.foldable() && f.hasNestedInferenceFunction() == false) { + inferenceFunctions.add(f); + } + }); + + return inferenceFunctions; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/LogicalPlanPreOptimizerRule.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/LogicalPlanPreOptimizerRule.java new file mode 100644 index 0000000000000..29d9be564f1bf --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/LogicalPlanPreOptimizerRule.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +/** + * A rule that can be applied to an analyzed logical plan before it is optimized. + */ +public interface LogicalPlanPreOptimizerRule { + + /** + * Apply the rule to the logical plan. 
+ * + * @param plan the analyzed logical plan to pre-optimize + * @param listener the listener returning the pre-optimized plan when pre-optimization rule is applied + */ + void apply(LogicalPlan plan, ActionListener listener); +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 56b97759db83c..c3eece1f85b4d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -589,7 +589,7 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception { null, null, null, - new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldCtx)), + new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldCtx, mock(InferenceService.class))), functionRegistry, new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldCtx)), mapper, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java new file mode 100644 index 0000000000000..3864c8ccc8685 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matcher;
+import org.junit.Before;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Set;
+
+import static org.hamcrest.Matchers.equalTo;
+
+/**
+ * Tests error conditions and type validation for the TEXT_EMBEDDING function.
+ */
+public class TextEmbeddingErrorTests extends ErrorsForCasesWithoutExamplesTestCase {
+
+    @Before
+    public void checkCapability() {
+        assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+    }
+
+    @Override
+    protected List<TestCaseSupplier> cases() {
+        return paramsToSuppliers(TextEmbeddingTests.parameters());
+    }
+
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        return new TextEmbedding(source, args.get(0), args.get(1));
+    }
+
+    @Override
+    protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
+        return equalTo(typeErrorMessage(true, validPerPosition, signature, (v, p) -> "string"));
+    }
+
+    protected static String typeErrorMessage( // NULL-aware wrapper around the base-class message builder
+        boolean includeOrdinal,
+        List<Set<DataType>> validPerPosition,
+        List<DataType> signature,
+        AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier
+    ) {
+        for (int i = 0; i < signature.size(); i++) {
+            if (signature.get(i) == DataType.NULL) { // a NULL reached before any invalid type gets its own message
+                String ordinal = includeOrdinal ? TypeResolutions.ParamOrdinal.fromIndex(i).name().toLowerCase(Locale.ROOT) + " " : "";
+                return ordinal + "argument of [" + sourceForSignature(signature) + "] cannot be null, received []";
+            }
+
+            if (validPerPosition.get(i).contains(signature.get(i)) == false) {
+                break; // an invalid type comes first: fall through to the generic message below
+            }
+        }
+
+        return ErrorsForCasesWithoutExamplesTestCase.typeErrorMessage(
+            includeOrdinal,
+            validPerPosition,
+            signature,
+            positionalErrorMessageSupplier
+        );
+    }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java
new file mode 100644
index 0000000000000..f7a1ed4cf2025
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.inference;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.FunctionName;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+
+import java.util.List;
+import java.util.function.Supplier;
+
+import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
+import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
+import static org.hamcrest.Matchers.equalTo;
+
+@FunctionName("text_embedding")
+public class TextEmbeddingTests extends AbstractFunctionTestCase {
+    @Before
+    public void checkCapability() {
+        assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled());
+    }
+
+    public TextEmbeddingTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+        this.testCase = testCaseSupplier.get();
+    }
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() {
+        return parameterSuppliersFromTypedData(
+            List.of(
+                new TestCaseSupplier(
+                    List.of(KEYWORD, KEYWORD), // single signature: (text, inference_id), both keywords
+                    () -> new TestCaseSupplier.TestCase(
+                        List.of(
+                            new TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), KEYWORD, "text"),
+                            new TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), KEYWORD, "inference_id")
+                        ),
+                        Matchers.blankOrNullString(),
+                        DENSE_VECTOR,
+                        equalTo(true)
+                    )
+                )
+            )
+        );
+    }
+
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        return new TextEmbedding(source, args.get(0), args.get(1)); // args follow the (text, inference_id) order above
+    }
+
+    @Override
+    protected boolean canSerialize() {
+        return false;
+    }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java
new file mode 100644
index 0000000000000..b4c63e6553b95
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java
@@ -0,0 +1,186 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.inference;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.compute.test.ComputeTestCase;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
+import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
+import org.junit.After;
+import org.junit.Before;
+
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static
org.hamcrest.Matchers.instanceOf;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class InferenceFunctionEvaluatorTests extends ComputeTestCase {
+
+    private ThreadPool threadPool;
+
+    @Before
+    public void setupThreadPool() {
+        this.threadPool = createThreadPool();
+    }
+
+    @After
+    public void tearDownThreadPool() {
+        terminate(threadPool);
+    }
+
+    public void testFoldTextEmbeddingFunction() throws Exception {
+        // Build a real TextEmbedding call with literal arguments: the text first, the inference id second
+        TextEmbedding textEmbeddingFunction = new TextEmbedding(
+            Source.EMPTY,
+            Literal.keyword(Source.EMPTY, "test input"),
+            Literal.keyword(Source.EMPTY, "test-model")
+        );
+
+        // Create a mock operator whose output is a single multi-valued float position holding the embedding
+        Operator operator = mock(Operator.class);
+
+        Float[] embedding = randomArray(1, 100, Float[]::new, ESTestCase::randomFloat);
+
+        when(operator.getOutput()).thenAnswer(i -> {
+            FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(1).beginPositionEntry();
+
+            for (int j = 0; j < embedding.length; j++) {
+                outputBlockBuilder.appendFloat(embedding[j]);
+            }
+
+            outputBlockBuilder.endPositionEntry();
+
+            return new Page(outputBlockBuilder.build());
+        });
+
+        InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator;
+
+        // Execute the fold operation
+        InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(FoldContext.small(), inferenceOperatorProvider);
+
+        AtomicReference<Expression> resultExpression = new AtomicReference<>();
+        evaluator.fold(textEmbeddingFunction, ActionListener.wrap(resultExpression::set, ESTestCase::fail));
+
+        assertBusy(() -> {
+            assertNotNull(resultExpression.get());
+            Literal result = as(resultExpression.get(), Literal.class);
+            assertThat(result.dataType(), equalTo(DataType.DENSE_VECTOR));
+            assertThat(as(result.value(), List.class).toArray(), equalTo(embedding));
+        });
+
+        // Check all breakers are empty after the operation is executed
+        allBreakersEmpty();
+    }
+
+    public void testFoldWithNonFoldableFunction() {
+        // A function with a non-literal argument is not foldable.
+        TextEmbedding textEmbeddingFunction = new TextEmbedding(
+            Source.EMPTY,
+            mock(Attribute.class),
+            Literal.keyword(Source.EMPTY, "test-model")
+        );
+
+        InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(
+            FoldContext.small(),
+            (f, driverContext) -> mock(Operator.class)
+        );
+
+        AtomicReference<Exception> error = new AtomicReference<>();
+        evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set));
+
+        assertNotNull(error.get());
+        assertThat(error.get(), instanceOf(IllegalArgumentException.class));
+        assertThat(error.get().getMessage(), equalTo("Inference function must be foldable"));
+    }
+
+    public void testFoldWithAsyncFailure() throws Exception {
+        TextEmbedding textEmbeddingFunction = new TextEmbedding(
+            Source.EMPTY,
+            Literal.keyword(Source.EMPTY, "test input"),
+            Literal.keyword(Source.EMPTY, "test-model")
+        );
+
+        // Mock an operator that will trigger an async failure
+        Operator operator = mock(Operator.class);
+        doAnswer(invocation -> {
+            // Simulate the operator finishing and then immediately calling the failure listener.
+            // In that case getOutput() will replay the failure when called allowing us to catch the error.
+            throw new RuntimeException("async failure");
+        }).when(operator).getOutput();
+
+        InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator;
+        InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(FoldContext.small(), inferenceOperatorProvider);
+
+        AtomicReference<Exception> error = new AtomicReference<>();
+        evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set));
+
+        assertBusy(() -> assertNotNull(error.get()));
+        assertThat(error.get(), instanceOf(RuntimeException.class));
+        assertThat(error.get().getMessage(), equalTo("async failure"));
+
+        allBreakersEmpty();
+    }
+
+    public void testFoldWithNullOutputPage() throws Exception {
+        TextEmbedding textEmbeddingFunction = new TextEmbedding(
+            Source.EMPTY,
+            Literal.keyword(Source.EMPTY, "test input"),
+            Literal.keyword(Source.EMPTY, "test-model")
+        );
+
+        Operator operator = mock(Operator.class);
+        when(operator.getOutput()).thenReturn(null); // no page at all: fold must report an IllegalStateException
+
+        InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator;
+        InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(FoldContext.small(), inferenceOperatorProvider);
+
+        AtomicReference<Exception> error = new AtomicReference<>();
+        evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set));
+
+        assertBusy(() -> assertNotNull(error.get()));
+        assertThat(error.get(), instanceOf(IllegalStateException.class));
+        assertThat(error.get().getMessage(), equalTo("Expected output page from inference operator"));
+
+        allBreakersEmpty();
+    }
+
+    public void testFoldWithUnsupportedFunction() throws Exception {
+        InferenceFunction<?> unsupported = mock(InferenceFunction.class);
+        when(unsupported.foldable()).thenReturn(true); // foldable, so the failure has to come from the provider
+
+        InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(FoldContext.small(), (f, driverContext) -> {
+            throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName());
+        });
+
+        AtomicReference<Exception> error = new AtomicReference<>();
+        evaluator.fold(unsupported, ActionListener.wrap(r -> fail("should have failed"), error::set));
+
+        assertNotNull(error.get());
+        assertThat(error.get(), instanceOf(IllegalArgumentException.class));
+        assertThat(error.get().getMessage(), containsString("Unknown inference function"));
+
+        allBreakersEmpty();
+    }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java
index 8e573dd1cf3c9..ad7c1dbea1a50 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java
@@ -30,6 +30,7 @@
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.of;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.notNullValue;
+import static org.mockito.Mockito.mock;
 
 public class LogicalPlanPreOptimizerTests extends ESTestCase {
 
@@ -72,7 +73,8 @@ public LogicalPlan preOptimizedPlan(LogicalPlan plan) throws Exception {
     }
 
     private LogicalPlanPreOptimizer preOptimizer() {
-        LogicalPreOptimizerContext preOptimizerContext = new LogicalPreOptimizerContext(FoldContext.small());
+        var inferenceService = mock(org.elasticsearch.xpack.esql.inference.InferenceService.class);
+        LogicalPreOptimizerContext preOptimizerContext = new LogicalPreOptimizerContext(FoldContext.small(), inferenceService);
         return new LogicalPlanPreOptimizer(preOptimizerContext);
     }