From da21972506db6a46e10a98237ce1156f4e85a2ee Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 12:49:25 +0200 Subject: [PATCH 01/31] ESQL: Add asynchronous pre-optimization step for logical plan --- .../_nightly/esql/QueryPlanningBenchmark.java | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java index 3b4d445002073..2a99bcb96f1e0 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java @@ -9,6 +9,7 @@ package org.elasticsearch.benchmark._nightly.esql; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexMode; @@ -26,6 +27,8 @@ import org.elasticsearch.xpack.esql.inference.InferenceResolution; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; +import org.elasticsearch.xpack.esql.optimizer.LogicalPlanPreOptimizer; +import org.elasticsearch.xpack.esql.optimizer.LogicalPreOptimizerContext; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.parser.QueryParams; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -70,6 +73,7 @@ public class QueryPlanningBenchmark { private EsqlParser defaultParser; private Analyzer manyFieldsAnalyzer; private LogicalPlanOptimizer defaultOptimizer; + private LogicalPlanPreOptimizer defaultPreOptimizer; private Configuration config; @Setup @@ -112,18 +116,22 @@ public void setup() { ), new Verifier(new Metrics(functionRegistry), new XPackLicenseState(() -> 0L)) ); + defaultOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(config, FoldContext.small())); + defaultPreOptimizer = new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(FoldContext.small())); } - private LogicalPlan plan(EsqlParser parser, Analyzer analyzer, LogicalPlanOptimizer optimizer, String query) { + private LogicalPlan plan(EsqlParser parser, Analyzer analyzer, LogicalPlanOptimizer optimizer, LogicalPlanPreOptimizer preOptimizer, String query) { + PlainActionFuture future = new PlainActionFuture<>(); var parsed = parser.createStatement(query, new QueryParams(), telemetry, config); var analyzed = analyzer.analyze(parsed); - var optimized = optimizer.optimize(analyzed); - return optimized; + analyzed.setAnalyzed(); + preOptimizer.preOptimize(analyzed, future.map(optimizer::optimize)); + return future.actionGet(); } @Benchmark public void manyFields(Blackhole blackhole) { - blackhole.consume(plan(defaultParser, manyFieldsAnalyzer, defaultOptimizer, "FROM test | LIMIT 10")); + blackhole.consume(plan(defaultParser, manyFieldsAnalyzer, defaultOptimizer, defaultPreOptimizer, "FROM test | LIMIT 10")); } } From 8b5082b6218499d1e7d4ee927bbe9d77cfb156a0 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 17 Jul 2025 12:45:57 +0000 Subject: [PATCH 02/31] [CI] Auto commit changes from spotless --- .../benchmark/_nightly/esql/QueryPlanningBenchmark.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git 
a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java index 2a99bcb96f1e0..db97db98e6a2b 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java @@ -121,7 +121,13 @@ public void setup() { defaultPreOptimizer = new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(FoldContext.small())); } - private LogicalPlan plan(EsqlParser parser, Analyzer analyzer, LogicalPlanOptimizer optimizer, LogicalPlanPreOptimizer preOptimizer, String query) { + private LogicalPlan plan( + EsqlParser parser, + Analyzer analyzer, + LogicalPlanOptimizer optimizer, + LogicalPlanPreOptimizer preOptimizer, + String query + ) { PlainActionFuture future = new PlainActionFuture<>(); var parsed = parser.createStatement(query, new QueryParams(), telemetry, config); var analyzed = analyzer.analyze(parsed); From 658b54b4a1240aab8b75bb5eb24e8bd4862467c8 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 15:41:34 +0200 Subject: [PATCH 03/31] Revert useless change in QueryPlanningBenchmark --- .../_nightly/esql/QueryPlanningBenchmark.java | 22 ++++--------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java index db97db98e6a2b..3b4d445002073 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java @@ -9,7 +9,6 @@ package org.elasticsearch.benchmark._nightly.esql; -import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexMode; @@ -27,8 +26,6 @@ import org.elasticsearch.xpack.esql.inference.InferenceResolution; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; -import org.elasticsearch.xpack.esql.optimizer.LogicalPlanPreOptimizer; -import org.elasticsearch.xpack.esql.optimizer.LogicalPreOptimizerContext; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.parser.QueryParams; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -73,7 +70,6 @@ public class QueryPlanningBenchmark { private EsqlParser defaultParser; private Analyzer manyFieldsAnalyzer; private LogicalPlanOptimizer defaultOptimizer; - private LogicalPlanPreOptimizer defaultPreOptimizer; private Configuration config; @Setup @@ -116,28 +112,18 @@ public void setup() { ), new Verifier(new Metrics(functionRegistry), new XPackLicenseState(() -> 0L)) ); - defaultOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(config, FoldContext.small())); - defaultPreOptimizer = new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(FoldContext.small())); } - private LogicalPlan plan( - EsqlParser parser, - Analyzer analyzer, - LogicalPlanOptimizer optimizer, - LogicalPlanPreOptimizer preOptimizer, - String query - ) { - PlainActionFuture future = new PlainActionFuture<>(); + private LogicalPlan plan(EsqlParser parser, Analyzer 
analyzer, LogicalPlanOptimizer optimizer, String query) { var parsed = parser.createStatement(query, new QueryParams(), telemetry, config); var analyzed = analyzer.analyze(parsed); - analyzed.setAnalyzed(); - preOptimizer.preOptimize(analyzed, future.map(optimizer::optimize)); - return future.actionGet(); + var optimized = optimizer.optimize(analyzed); + return optimized; } @Benchmark public void manyFields(Blackhole blackhole) { - blackhole.consume(plan(defaultParser, manyFieldsAnalyzer, defaultOptimizer, defaultPreOptimizer, "FROM test | LIMIT 10")); + blackhole.consume(plan(defaultParser, manyFieldsAnalyzer, defaultOptimizer, "FROM test | LIMIT 10")); } } From 76831a096727c53ca63abd699c20f42aec8eb479 Mon Sep 17 00:00:00 2001 From: afoucret Date: Fri, 11 Jul 2025 10:58:59 +0200 Subject: [PATCH 04/31] Add EMBED_TEXT function infrastructure for dense vector embeddings --- .../esql/images/functions/embed_text.svg | 1 + .../definition/functions/embed_text.json | 9 ++ .../esql/kibana/docs/functions/embed_text.md | 4 + .../xpack/esql/action/EsqlCapabilities.java | 5 + .../esql/expression/ExpressionWritables.java | 9 ++ .../function/EsqlFunctionRegistry.java | 2 + .../function/inference/EmbedText.java | 150 ++++++++++++++++++ .../function/inference/InferenceFunction.java | 21 +++ .../inference/EmbedTextErrorTests.java | 73 +++++++++ .../EmbedTextSerializationTests.java | 45 ++++++ .../function/inference/EmbedTextTests.java | 72 +++++++++ 11 files changed, 391 insertions(+) create mode 100644 docs/reference/query-languages/esql/images/functions/embed_text.svg create mode 100644 docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json create mode 100644 docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java diff --git a/docs/reference/query-languages/esql/images/functions/embed_text.svg b/docs/reference/query-languages/esql/images/functions/embed_text.svg new file mode 100644 index 0000000000000..9bb6cab692c4e --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/embed_text.svg @@ -0,0 +1 @@ +EMBED_TEXT(text,inference_id) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json b/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json new file mode 100644 index 0000000000000..edfafb213e16d --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json @@ -0,0 +1,9 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. 
See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "embed_text", + "description" : "Generates dense vector embeddings for text using a specified inference deployment.", + "signatures" : [ ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md b/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md new file mode 100644 index 0000000000000..fc6d8d0772d64 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md @@ -0,0 +1,4 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### EMBED TEXT +Generates dense vector embeddings for text using a specified inference deployment. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 6a2b112b58deb..0e1771400bd87 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1296,6 +1296,11 @@ public enum Cap { */ LIKE_ON_INDEX_FIELDS, + /** + * Support for the {@code EMBED_TEXT} function for generating dense vector embeddings. + */ + EMBED_TEXT_FUNCTION(Build.current().isSnapshot()), + /** * Forbid usage of brackets in unquoted index and enrich policy names * https://github.com/elastic/elasticsearch/issues/130378 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index 311f666581279..983dea6fd58ba 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables; +import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText; import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble; @@ -119,6 +120,7 @@ public static List getNamedWriteables() { entries.addAll(fullText()); entries.addAll(unaryScalars()); entries.addAll(vector()); + entries.addAll(inference()); return entries; } @@ -262,4 +264,11 @@ private static List fullText() { private static List vector() { return VectorWritables.getNamedWritables(); } + + private static List inference() { + if (EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()) { + return List.of(EmbedText.ENTRY); + } + return List.of(); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 649503b1443d2..b643e4f49005b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -53,6 +53,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Term; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; +import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least; @@ -485,6 +486,7 @@ private static FunctionDefinition[][] snapshotFunctions() { def(Score.class, uni(Score::new), Score.NAME), def(Term.class, bi(Term::new), "term"), def(Knn.class, quad(Knn::new), "knn"), + def(EmbedText.class, EmbedText::new, "embed_text"), def(StGeohash.class, StGeohash::new, "st_geohash"), def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"), def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java new file mode 100644 index 0000000000000..4c621aa3beba0 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java @@ -0,0 +1,150 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; + +/** + * EMBED_TEXT function that generates dense vector embeddings for text using a specified inference deployment. 
+ */ +public class EmbedText extends Function implements InferenceFunction { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "EmbedText", + EmbedText::new + ); + + private final Expression inferenceId; + private final Expression inputText; + + @FunctionInfo( + returnType = "dense_vector", + description = "Generates dense vector embeddings for text using a specified inference deployment.", + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }, + preview = true + ) + public EmbedText( + Source source, + @Param(name = "text", type = { "keyword", "text" }, description = "Text to embed") Expression inputText, + @Param(name = "inference_id", type = { "keyword", "text" }, description = "Inference deployment ID") Expression inferenceId + ) { + super(source, List.of(inputText, inferenceId)); + this.inferenceId = inferenceId; + this.inputText = inputText; + } + + private EmbedText(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(inputText); + out.writeNamedWriteable(inferenceId); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + public Expression inputText() { + return inputText; + } + + @Override + public Expression inferenceId() { + return inferenceId; + } + + @Override + public DataType dataType() { + return DataType.DENSE_VECTOR; + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + TypeResolution textResolution = isNotNull(inputText, sourceText(), FIRST).and(isFoldable(inferenceId, sourceText(), FIRST)) + .and(isString(inputText, sourceText(), FIRST)); + + if (textResolution.unresolved()) { + return textResolution; + } + + TypeResolution inferenceIdResolution = isNotNull(inferenceId, sourceText(), SECOND).and(isString(inferenceId, sourceText(), SECOND)) + .and(isFoldable(inferenceId, sourceText(), SECOND)); + + if (inferenceIdResolution.unresolved()) { + return inferenceIdResolution; + } + + return TypeResolution.TYPE_RESOLVED; + } + + @Override + public boolean foldable() { + // The function is foldable only if both arguments are foldable + return inputText.foldable() && inferenceId.foldable(); + } + + @Override + public Expression replaceChildren(List newChildren) { + return new EmbedText(source(), newChildren.get(0), newChildren.get(1)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, EmbedText::new, inputText, inferenceId); + } + + @Override + public String toString() { + return "EMBED_TEXT(" + inputText + ", " + inferenceId + ")"; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + EmbedText embedText = (EmbedText) o; + return Objects.equals(inferenceId, embedText.inferenceId) && Objects.equals(inputText, embedText.inputText); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), inferenceId, inputText); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java new file mode 100644 index 0000000000000..44af928db6a60 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.xpack.esql.core.expression.Expression; + +/** + * A function that uses an inference model. + */ +public interface InferenceFunction { + + /** + * Returns the inference model ID expression. + */ + Expression inferenceId(); +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java new file mode 100644 index 0000000000000..4b490a86b7587 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; +import org.junit.Before; + +import java.util.List; +import java.util.Locale; +import java.util.Set; + +import static org.hamcrest.Matchers.equalTo; + +public class EmbedTextErrorTests extends ErrorsForCasesWithoutExamplesTestCase { + + @Before + public void checkCapability() { + assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()); + } + + @Override + protected List cases() { + return paramsToSuppliers(EmbedTextTests.parameters()); + } + + @Override + protected Expression build(Source source, List args) { + return new EmbedText(source, args.get(0), args.get(1)); + } + + @Override + protected Matcher expectedTypeErrorMatcher(List> validPerPosition, List signature) { + return equalTo(typeErrorMessage(true, validPerPosition, signature, (v, p) -> "string")); + } + + protected static String typeErrorMessage( + boolean includeOrdinal, + List> validPerPosition, + List signature, + AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier + ) { + for (int i = 0; i < signature.size(); i++) { + if (signature.get(i) == DataType.NULL) { + String ordinal = includeOrdinal ? 
TypeResolutions.ParamOrdinal.fromIndex(i).name().toLowerCase(Locale.ROOT) + " " : ""; + return ordinal + "argument of [" + sourceForSignature(signature) + "] cannot be null, received []"; + } + + if (validPerPosition.get(i).contains(signature.get(i)) == false) { + break; + } + } + + return ErrorsForCasesWithoutExamplesTestCase.typeErrorMessage( + includeOrdinal, + validPerPosition, + signature, + positionalErrorMessageSupplier + ); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java new file mode 100644 index 0000000000000..59d5377e36f76 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.junit.Before; + +import java.io.IOException; + +public class EmbedTextSerializationTests extends AbstractExpressionSerializationTests { + + @Before + public void checkCapability() { + assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()); + } + + @Override + protected EmbedText createTestInstance() { + Source source = randomSource(); + Expression inputText = randomChild(); + Expression inferenceId = randomChild(); + return new EmbedText(source, inputText, inferenceId); + } + + @Override + protected EmbedText mutateInstance(EmbedText instance) throws IOException { + Source source = instance.source(); + Expression inputText = instance.inputText(); + Expression inferenceId = instance.inferenceId(); + if (randomBoolean()) { + inputText = randomValueOtherThan(inputText, AbstractExpressionSerializationTests::randomChild); + } else { + inferenceId = randomValueOtherThan(inferenceId, AbstractExpressionSerializationTests::randomChild); + } + return new EmbedText(source, inputText, inferenceId); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java new file mode 100644 index 0000000000000..c342ee143e6af --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.FunctionName; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matchers; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.hamcrest.Matchers.equalTo; + +@FunctionName("embed_text") +public class EmbedTextTests extends AbstractFunctionTestCase { + @Before + public void checkCapability() { + assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()); + } + + public EmbedTextTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List suppliers = new ArrayList<>(); + + // Valid cases with string types for input text and inference_id + for (DataType inputTextDataType : DataType.stringTypes()) { + for (DataType inferenceIdDataType : DataType.stringTypes()) { + suppliers.add( + new TestCaseSupplier( + List.of(inputTextDataType, inferenceIdDataType), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), inputTextDataType, "inputText"), + new TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), inferenceIdDataType, "inference_id") + ), + Matchers.blankOrNullString(), + DENSE_VECTOR, + equalTo(true) + ) + ) + ); + } + } + + return parameterSuppliersFromTypedData(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new EmbedText(source, args.get(0), args.get(1)); + } +} From f735f4651143a0b7a7deee0ee66a25b59aae40db Mon Sep 17 00:00:00 2001 From: afoucret Date: Fri, 11 Jul 2025 18:39:28 +0200 Subject: [PATCH 05/31] Extend InferenceResolver to collect inference IDs from functions and plans --- .../xpack/esql/expression/function/inference/EmbedText.java | 6 +++++- .../expression/function/inference/InferenceFunction.java | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java index 4c621aa3beba0..5709d734c71d4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java @@ -54,7 +54,11 @@ public class EmbedText extends Function implements InferenceFunction { public EmbedText( Source source, @Param(name = "text", type = { "keyword", "text" }, description = "Text to embed") Expression inputText, - @Param(name = "inference_id", type = { "keyword", "text" }, description = "Inference deployment ID") Expression inferenceId + @Param( + name = 
InferenceFunction.INFERENCE_ID_PARAMETER_NAME, + type = { "keyword", "text" }, + description = "Inference deployment ID" + ) Expression inferenceId ) { super(source, List.of(inputText, inferenceId)); this.inferenceId = inferenceId; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java index 44af928db6a60..dbd28608b3495 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -14,6 +14,8 @@ */ public interface InferenceFunction { + String INFERENCE_ID_PARAMETER_NAME = "inference_id"; + /** * Returns the inference model ID expression. */ From f03e1553243c9c794b4957b80b8900db3638d191 Mon Sep 17 00:00:00 2001 From: afoucret Date: Mon, 28 Jul 2025 11:49:19 +0200 Subject: [PATCH 06/31] Add Analyzer support for EMBED_TEXT inference functions # Conflicts: # x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java # x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java # x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java # x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java --- .../function/inference/EmbedText.java | 21 ++- .../function/inference/InferenceFunction.java | 22 ++- .../esql/analysis/AnalyzerTestUtils.java | 1 + .../xpack/esql/analysis/AnalyzerTests.java | 130 +++++++++++++++++- 4 files changed, 164 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java index 5709d734c71d4..cc63f7e465f48 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java @@ -10,8 +10,9 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -25,16 +26,14 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.*; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; /** * EMBED_TEXT function that 
generates dense vector embeddings for text using a specified inference deployment. */ -public class EmbedText extends Function implements InferenceFunction { +public class EmbedText extends InferenceFunction { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, @@ -101,7 +100,7 @@ protected TypeResolution resolveType() { return new TypeResolution("Unresolved children"); } - TypeResolution textResolution = isNotNull(inputText, sourceText(), FIRST).and(isFoldable(inferenceId, sourceText(), FIRST)) + TypeResolution textResolution = isNotNull(inputText, sourceText(), FIRST).and(isFoldable(inputText, sourceText(), FIRST)) .and(isString(inputText, sourceText(), FIRST)); if (textResolution.unresolved()) { @@ -124,6 +123,16 @@ public boolean foldable() { return inputText.foldable() && inferenceId.foldable(); } + @Override + public TaskType taskType() { + return TaskType.TEXT_EMBEDDING; + } + + @Override + public EmbedText withInferenceResolutionError(String inferenceId, String error) { + return new EmbedText(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error)); + } + @Override public Expression replaceChildren(List newChildren) { return new EmbedText(source(), newChildren.get(0), newChildren.get(1)); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java index dbd28608b3495..d32587dc71058 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -7,17 +7,33 @@ package org.elasticsearch.xpack.esql.expression.function.inference; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.List; /** * A function is a function using an inference model. */ -public interface InferenceFunction { +public abstract class InferenceFunction> extends Function { + + public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id"; - String INFERENCE_ID_PARAMETER_NAME = "inference_id"; + protected InferenceFunction(Source source, List children) { + super(source, children); + } /** * Returns the inference model ID expression. */ - Expression inferenceId(); + public abstract Expression inferenceId(); + + /** + * Returns the task type of the inference model. 
+ */ + public abstract TaskType taskType(); + + public abstract PlanType withInferenceResolutionError(String inferenceId, String error); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java index fbfa18dccc477..a1704be3c3cfa 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java @@ -199,6 +199,7 @@ public static InferenceResolution defaultInferenceResolution() { return InferenceResolution.builder() .withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK)) .withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION)) + .withResolvedInference(new ResolvedInference("text-embedding-inference-id", TaskType.TEXT_EMBEDDING)) .withError("error-inference-id", "error with inference resolution") .build(); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 7d6ef0b5fae46..17725a63aa777 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -50,6 +50,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; +import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; @@ -3520,7 +3521,11 @@ private void assertProjectionWithMapping(String query, String mapping, QueryPara } private void assertError(String query, String mapping, QueryParams params, String error) { - Throwable e = expectThrows(VerificationException.class, () -> analyze(query, mapping, params)); + assertError(query, mapping, params, error, VerificationException.class); + } + + private void assertError(String query, String mapping, QueryParams params, String error, Class clazz) { + Throwable e = expectThrows(clazz, () -> analyze(query, mapping, params)); assertThat(e.getMessage(), containsString(error)); } @@ -3835,6 +3840,129 @@ public void testResolveCompletionOutputFieldOverwriteInputField() { assertThat(getAttributeByName(esRelation.output(), "description"), not(equalTo(completion.targetField()))); } + public void testResolveEmbedTextInferenceId() { + LogicalPlan plan = analyze(""" + FROM books METADATA _score + | EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id") + """, "mapping-books.json"); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var embedTextAlias = as(eval.fields().get(0), Alias.class); + var embedText = as(embedTextAlias.child(), EmbedText.class); + + assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); + assertThat(embedText.inputText(), equalTo(string("description"))); + } + + public void testResolveEmbedTextInferenceIdInvalidTaskType() { + 
assertError( + """ + FROM books METADATA _score + | EVAL embedding = EMBED_TEXT(description, "completion-inference-id") + """, + "mapping-books.json", + new QueryParams(), + "cannot use inference endpoint [completion-inference-id] with task type [completion] within a embed_text function." + + " Only inference endpoints with the task type [text_embedding] are supported" + ); + } + + public void testResolveEmbedTextInferenceMissingInferenceId() { + assertError(""" + FROM books METADATA _score + | EVAL embedding = EMBED_TEXT(description, "unknown-inference-id") + """, "mapping-books.json", new QueryParams(), "unresolved inference [unknown-inference-id]"); + } + + public void testResolveEmbedTextInferenceIdResolutionError() { + assertError(""" + FROM books METADATA _score + | EVAL embedding = EMBED_TEXT(description, "error-inference-id") + """, "mapping-books.json", new QueryParams(), "error with inference resolution"); + } + + public void testResolveEmbedTextInNestedExpression() { + LogicalPlan plan = analyze(""" + FROM colors METADATA _score + | WHERE KNN(rgb_vector, EMBED_TEXT("blue", "text-embedding-inference-id"), 10) + """, "mapping-colors.json"); + + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + + // Navigate to the EMBED_TEXT function within the KNN function + filter.condition().forEachDown(EmbedText.class, embedText -> { + assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); + assertThat(embedText.inputText(), equalTo(string("blue"))); + }); + } + + public void testResolveEmbedTextDataType() { + LogicalPlan plan = analyze(""" + FROM books METADATA _score + | EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id") + """, "mapping-books.json"); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var embedTextAlias = as(eval.fields().get(0), Alias.class); + var embedText = as(embedTextAlias.child(), EmbedText.class); + + assertThat(embedText.dataType(), equalTo(DataType.DENSE_VECTOR)); + } + + public void testResolveEmbedTextInvalidParameters() { + assertError( + "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description, \"text-embedding-inference-id\")", + "mapping-books.json", + new QueryParams(), + "first argument of [EMBED_TEXT(description, \"text-embedding-inference-id\")] must be a constant, received [description]" + ); + + assertError( + "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description)", + "mapping-books.json", + new QueryParams(), + "error building [embed_text]: function [embed_text] expects exactly two arguments, it received 1", + ParsingException.class + ); + } + + public void testResolveEmbedTextWithPositionalQueryParams() { + LogicalPlan plan = analyze( + "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?, ?)", + "mapping-books.json", + new QueryParams(List.of(paramAsConstant(null, "description"), paramAsConstant(null, "text-embedding-inference-id"))) + ); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var embedTextAlias = as(eval.fields().get(0), Alias.class); + var embedText = as(embedTextAlias.child(), EmbedText.class); + + assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); + assertThat(embedText.inputText(), equalTo(string("description"))); + } + + public void testResolveEmbedTextWithNamedQueryParams() { + LogicalPlan plan = analyze( + "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?inputText, ?inferenceId)", + "mapping-books.json", 
+ new QueryParams( + List.of(paramAsConstant("inputText", "description"), paramAsConstant("inferenceId", "text-embedding-inference-id")) + ) + ); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var embedTextAlias = as(eval.fields().get(0), Alias.class); + var embedText = as(embedTextAlias.child(), EmbedText.class); + + assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); + assertThat(embedText.inputText(), equalTo(string("description"))); + } + public void testResolveGroupingsBeforeResolvingImplicitReferencesToGroupings() { var plan = analyze(""" FROM test From d365b0de173a0d96b1927a7da66e985fccbb50b6 Mon Sep 17 00:00:00 2001 From: afoucret Date: Fri, 11 Jul 2025 23:22:13 +0200 Subject: [PATCH 07/31] Rename EMBED_TEXT function to TEXT_EMBEDDING and update all references --- .../esql/images/functions/embed_text.svg | 1 - .../esql/images/functions/text_embedding.svg | 1 + .../{embed_text.json => text_embedding.json} | 2 +- .../{embed_text.md => text_embedding.md} | 2 +- .../xpack/esql/action/EsqlCapabilities.java | 4 +-- .../esql/expression/ExpressionWritables.java | 6 ++-- .../function/EsqlFunctionRegistry.java | 4 +-- .../{EmbedText.java => TextEmbedding.java} | 30 ++++++++++--------- ...ests.java => TextEmbeddingErrorTests.java} | 8 ++--- ...a => TextEmbeddingSerializationTests.java} | 12 ++++---- ...TextTests.java => TextEmbeddingTests.java} | 10 +++---- 11 files changed, 41 insertions(+), 39 deletions(-) delete mode 100644 docs/reference/query-languages/esql/images/functions/embed_text.svg create mode 100644 docs/reference/query-languages/esql/images/functions/text_embedding.svg rename docs/reference/query-languages/esql/kibana/definition/functions/{embed_text.json => text_embedding.json} (91%) rename docs/reference/query-languages/esql/kibana/docs/functions/{embed_text.md => text_embedding.md} (91%) rename x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/{EmbedText.java => TextEmbedding.java} (80%) rename x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/{EmbedTextErrorTests.java => TextEmbeddingErrorTests.java} (88%) rename x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/{EmbedTextSerializationTests.java => TextEmbeddingSerializationTests.java} (72%) rename x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/{EmbedTextTests.java => TextEmbeddingTests.java} (87%) diff --git a/docs/reference/query-languages/esql/images/functions/embed_text.svg b/docs/reference/query-languages/esql/images/functions/embed_text.svg deleted file mode 100644 index 9bb6cab692c4e..0000000000000 --- a/docs/reference/query-languages/esql/images/functions/embed_text.svg +++ /dev/null @@ -1 +0,0 @@ -EMBED_TEXT(text,inference_id) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/images/functions/text_embedding.svg b/docs/reference/query-languages/esql/images/functions/text_embedding.svg new file mode 100644 index 0000000000000..dab58c5e5bda0 --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/text_embedding.svg @@ -0,0 +1 @@ +TEXT_EMBEDDING(text,inference_id) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json similarity index 91% rename from 
docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json rename to docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json index edfafb213e16d..09a07d6851005 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/embed_text.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json @@ -1,7 +1,7 @@ { "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", "type" : "scalar", - "name" : "embed_text", + "name" : "text_embedding", "description" : "Generates dense vector embeddings for text using a specified inference deployment.", "signatures" : [ ], "preview" : true, diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md b/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md similarity index 91% rename from docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md rename to docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md index fc6d8d0772d64..f2bebfa37ac2b 100644 --- a/docs/reference/query-languages/esql/kibana/docs/functions/embed_text.md +++ b/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md @@ -1,4 +1,4 @@ % This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. -### EMBED TEXT +### TEXT EMBEDDING Generates dense vector embeddings for text using a specified inference deployment. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 0e1771400bd87..13f1d4b2d0466 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1297,9 +1297,9 @@ public enum Cap { LIKE_ON_INDEX_FIELDS, /** - * Support for the {@code EMBED_TEXT} function for generating dense vector embeddings. + * Support for the {@code TEXT_EMBEDDING} function for generating dense vector embeddings. 
*/ - EMBED_TEXT_FUNCTION(Build.current().isSnapshot()), + TEXT_EMBEDDING_FUNCTION(Build.current().isSnapshot()), /** * Forbid usage of brackets in unquoted index and enrich policy names diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index 983dea6fd58ba..fedae9ec54d9e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -12,7 +12,7 @@ import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables; -import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble; @@ -266,8 +266,8 @@ private static List vector() { } private static List inference() { - if (EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()) { - return List.of(EmbedText.ENTRY); + if (EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()) { + return List.of(TextEmbedding.ENTRY); } return List.of(); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index b643e4f49005b..c0fc91413e760 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -53,7 +53,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Term; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; -import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least; @@ -486,7 +486,7 @@ private static FunctionDefinition[][] snapshotFunctions() { def(Score.class, uni(Score::new), Score.NAME), def(Term.class, bi(Term::new), "term"), def(Knn.class, quad(Knn::new), "knn"), - def(EmbedText.class, EmbedText::new, "embed_text"), + def(TextEmbedding.class, TextEmbedding::new, "text_embedding"), def(StGeohash.class, StGeohash::new, "st_geohash"), def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"), def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java similarity index 80% rename from x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java index cc63f7e465f48..7d1cd356ea1ac 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedText.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java @@ -26,19 +26,21 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.*; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; /** - * EMBED_TEXT function that generates dense vector embeddings for text using a specified inference deployment. + * TEXT_EMBEDDING function that generates dense vector embeddings for text using a specified inference deployment. */ -public class EmbedText extends InferenceFunction { +public class TextEmbedding extends InferenceFunction { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, - "EmbedText", - EmbedText::new + "TextEmbedding", + TextEmbedding::new ); private final Expression inferenceId; @@ -50,7 +52,7 @@ public class EmbedText extends InferenceFunction { appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }, preview = true ) - public EmbedText( + public TextEmbedding( Source source, @Param(name = "text", type = { "keyword", "text" }, description = "Text to embed") Expression inputText, @Param( @@ -64,7 +66,7 @@ public EmbedText( this.inputText = inputText; } - private EmbedText(StreamInput in) throws IOException { + private TextEmbedding(StreamInput in) throws IOException { this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class)); } @@ -129,31 +131,31 @@ public TaskType taskType() { } @Override - public EmbedText withInferenceResolutionError(String inferenceId, String error) { - return new EmbedText(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error)); + public TextEmbedding withInferenceResolutionError(String inferenceId, String error) { + return new TextEmbedding(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error)); } @Override public Expression replaceChildren(List newChildren) { - return new EmbedText(source(), newChildren.get(0), newChildren.get(1)); + return new TextEmbedding(source(), newChildren.get(0), newChildren.get(1)); } @Override protected NodeInfo info() { - return NodeInfo.create(this, EmbedText::new, inputText, inferenceId); + return NodeInfo.create(this, TextEmbedding::new, inputText, inferenceId); } @Override public String toString() { - return "EMBED_TEXT(" + inputText + ", " + inferenceId + ")"; + return "TEXT_EMBEDDING(" + inputText + ", " + inferenceId + ")"; } @Override public boolean equals(Object 
o) { if (o == null || getClass() != o.getClass()) return false; if (super.equals(o) == false) return false; - EmbedText embedText = (EmbedText) o; - return Objects.equals(inferenceId, embedText.inferenceId) && Objects.equals(inputText, embedText.inputText); + TextEmbedding textEmbedding = (TextEmbedding) o; + return Objects.equals(inferenceId, textEmbedding.inferenceId) && Objects.equals(inputText, textEmbedding.inputText); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java similarity index 88% rename from x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java index 4b490a86b7587..03a332d21eb1c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextErrorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java @@ -24,21 +24,21 @@ import static org.hamcrest.Matchers.equalTo; -public class EmbedTextErrorTests extends ErrorsForCasesWithoutExamplesTestCase { +public class TextEmbeddingErrorTests extends ErrorsForCasesWithoutExamplesTestCase { @Before public void checkCapability() { - assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()); + assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); } @Override protected List cases() { - return paramsToSuppliers(EmbedTextTests.parameters()); + return paramsToSuppliers(TextEmbeddingTests.parameters()); } @Override protected Expression build(Source source, List args) { - return new EmbedText(source, args.get(0), args.get(1)); + return new TextEmbedding(source, args.get(0), args.get(1)); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingSerializationTests.java similarity index 72% rename from x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingSerializationTests.java index 59d5377e36f76..57164605e3328 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingSerializationTests.java @@ -15,23 +15,23 @@ import java.io.IOException; -public class EmbedTextSerializationTests extends AbstractExpressionSerializationTests { +public class TextEmbeddingSerializationTests extends AbstractExpressionSerializationTests { @Before public void checkCapability() { - assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()); + assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); } @Override - protected EmbedText createTestInstance() { + protected TextEmbedding createTestInstance() { Source 
source = randomSource(); Expression inputText = randomChild(); Expression inferenceId = randomChild(); - return new EmbedText(source, inputText, inferenceId); + return new TextEmbedding(source, inputText, inferenceId); } @Override - protected EmbedText mutateInstance(EmbedText instance) throws IOException { + protected TextEmbedding mutateInstance(TextEmbedding instance) throws IOException { Source source = instance.source(); Expression inputText = instance.inputText(); Expression inferenceId = instance.inferenceId(); @@ -40,6 +40,6 @@ protected EmbedText mutateInstance(EmbedText instance) throws IOException { } else { inferenceId = randomValueOtherThan(inferenceId, AbstractExpressionSerializationTests::randomChild); } - return new EmbedText(source, inputText, inferenceId); + return new TextEmbedding(source, inputText, inferenceId); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java similarity index 87% rename from x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java index c342ee143e6af..30eeba0bfcc8f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/EmbedTextTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java @@ -27,14 +27,14 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; import static org.hamcrest.Matchers.equalTo; -@FunctionName("embed_text") -public class EmbedTextTests extends AbstractFunctionTestCase { +@FunctionName("text_embedding") +public class TextEmbeddingTests extends AbstractFunctionTestCase { @Before public void checkCapability() { - assumeTrue("EMBED_TEXT is not enabled", EsqlCapabilities.Cap.EMBED_TEXT_FUNCTION.isEnabled()); + assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); } - public EmbedTextTests(@Name("TestCase") Supplier testCaseSupplier) { + public TextEmbeddingTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -67,6 +67,6 @@ public static Iterable parameters() { @Override protected Expression build(Source source, List args) { - return new EmbedText(source, args.get(0), args.get(1)); + return new TextEmbedding(source, args.get(0), args.get(1)); } } From df1936d2a63cc3de52cc377fb1a067852e609c22 Mon Sep 17 00:00:00 2001 From: afoucret Date: Sat, 12 Jul 2025 00:04:50 +0200 Subject: [PATCH 08/31] Add PreOptimizer infrastructure for async pre-optimization steps --- .../xpack/esql/optimizer/PreOptimizer.java | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java new file mode 100644 index 0000000000000..4390dc48be587 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +/** + * This class is responsible for invoking any steps that need to be applied to the logical plan + * before it is optimized. + * <p> + * This is useful, especially if you need to execute some async tasks before the plan is optimized. + * <p>
+ */ +public class PreOptimizer { + + public PreOptimizer() { + + } + + public void preOptimize(LogicalPlan plan, ActionListener listener) { + listener.onResponse(plan); + } +} From 1f0c62c4430e9a8e0c317a5e2e4f1d12741cd0e5 Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 16 Jul 2025 18:16:13 +0200 Subject: [PATCH 09/31] [ESQL] Add async transformation methods to Node and QueryPlan --- .../xpack/esql/core/tree/Node.java | 129 ++++- .../xpack/esql/core/tree/NodeInfo.java | 236 ++++++++ .../esql/core/tree/NodeTransformTests.java | 516 ++++++++++++++++++ .../xpack/esql/plan/QueryPlan.java | 173 ++++++ .../xpack/esql/plan/QueryPlanTests.java | 73 +++ 5 files changed, 1126 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/tree/NodeTransformTests.java diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java index 613f5b0ae76c2..190c4c327e099 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java @@ -6,6 +6,8 @@ */ package org.elasticsearch.xpack.esql.core.tree; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; @@ -14,6 +16,8 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; @@ -187,16 +191,45 @@ public T transformDown(Function rule) { return node.transformChildren(child -> child.transformDown(rule)); } + @SuppressWarnings("unchecked") + public void transformDown(BiConsumer> rule, ActionListener listener) { + // First apply the rule to the current node (top-down) + rule.accept((T) this, listener.delegateFailureAndWrap((l, transformedNode) -> { + // Then recursively transform the children with the same rule + transformedNode.transformChildren((child, childListener) -> child.transformDown(rule, childListener), l); + })); + } + @SuppressWarnings("unchecked") public T transformDown(Class typeToken, Function rule) { return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t)); } + @SuppressWarnings("unchecked") + public void transformDown(Class typeToken, BiConsumer> rule, ActionListener listener) { + transformDown(typeToken::isInstance, rule, listener); + } + @SuppressWarnings("unchecked") public T transformDown(Predicate> nodePredicate, Function rule) { return transformDown((t) -> (nodePredicate.test(t) ? 
rule.apply((E) t) : t)); } + @SuppressWarnings("unchecked") + public void transformDown( + Predicate> nodePredicate, + BiConsumer> rule, + ActionListener listener + ) { + transformDown((T node, ActionListener l) -> { + if (nodePredicate.test(node)) { + rule.accept((E) node, l); + } else { + l.onResponse(node); + } + }, listener); + } + @SuppressWarnings("unchecked") public T transformUp(Function rule) { T transformed = transformChildren(child -> child.transformUp(rule)); @@ -205,8 +238,25 @@ public T transformUp(Function rule) { } @SuppressWarnings("unchecked") + public void transformUp(BiConsumer> rule, ActionListener listener) { + // First, recursively transform the children (depth-first, bottom-up) using the same async rule + transformChildren( + // traversal operation applied to each child + (child, childListener) -> child.transformUp(rule, childListener), + // After all children are transformed, apply the rule to the (possibly) new current node + listener.delegateFailureAndWrap((l, transformedChildrenNode) -> { + T node = transformedChildrenNode.equals(this) ? (T) this : transformedChildrenNode; + rule.accept(node, l); + }) + ); + } + public T transformUp(Class typeToken, Function rule) { - return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t)); + return transformUp(typeToken::isInstance, rule); + } + + public void transformUp(Class typeToken, BiConsumer> rule, ActionListener listener) { + transformUp(typeToken::isInstance, rule, listener); } @SuppressWarnings("unchecked") @@ -214,6 +264,22 @@ public T transformUp(Predicate> nodePredicate, Function (nodePredicate.test(t) ? rule.apply((E) t) : t)); } + @SuppressWarnings("unchecked") + public void transformUp( + Predicate> nodePredicate, + BiConsumer> rule, + ActionListener listener + ) { + transformUp((T node, ActionListener l) -> { + if (nodePredicate.test(node)) { + E typedNode = (E) node; + rule.accept((E) node, l); + } else { + l.onResponse(node); + } + }, listener); + } + @SuppressWarnings("unchecked") protected > T transformChildren(Function traversalOperation) { boolean childrenChanged = false; @@ -238,6 +304,35 @@ public T transformUp(Predicate> nodePredicate, Function> traversalOperation, ActionListener listener) { + if (children.isEmpty()) { + listener.onResponse((T) this); + return; + } + + final AtomicReference> transformedChildren = new AtomicReference<>(null); + + CountDownActionListener countDownListener = new CountDownActionListener( + children.size(), + listener.delegateFailureIgnoreResponseAndWrap((l) -> { + l.onResponse(transformedChildren.get() != null ? 
replaceChildren(transformedChildren.get()) : (T) this); + }) + ); + + for (int i = 0, s = children.size(); i < s; i++) { + T child = children.get(i); + final int childId = i; + traversalOperation.accept(child, countDownListener.map(next -> { + if (child.equals(next) == false) { + transformedChildren.compareAndSet(null, new ArrayList<>(children)); + transformedChildren.get().set(childId, next); + } + return null; + })); + } + } + public final T replaceChildrenSameSize(List newChildren) { if (newChildren.size() != children.size()) { throw new QlIllegalArgumentException( @@ -257,14 +352,38 @@ public T transformPropertiesOnly(Class typeToken, Function void transformPropertiesOnly( + Class typeToken, + BiConsumer> rule, + ActionListener listener + ) { + transformNodeProps(typeToken, rule, listener); + } + public T transformPropertiesDown(Class typeToken, Function rule) { return transformDown(t -> t.transformNodeProps(typeToken, rule)); } + public void transformPropertiesDown( + Class typeToken, + BiConsumer> rule, + ActionListener listener + ) { + transformDown((t, l) -> t.transformNodeProps(typeToken, rule, l), listener); + } + public T transformPropertiesUp(Class typeToken, Function rule) { return transformUp(t -> t.transformNodeProps(typeToken, rule)); } + public void transformPropertiesUp( + Class typeToken, + BiConsumer> rule, + ActionListener listener + ) { + transformUp((t, l) -> t.transformNodeProps(typeToken, rule, l), listener); + } + /** * Transform this node's properties. *
<p>
@@ -277,6 +396,14 @@ protected final T transformNodeProps(Class typeToken, Function void transformNodeProps( + Class typeToken, + BiConsumer> rule, + ActionListener listener + ) { + info().transform(rule, typeToken, listener); + } + /** * Return the information about this node. *
<p>
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java index 28e4e739085d4..bea432cfb0d13 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java @@ -6,9 +6,15 @@ */ package org.elasticsearch.xpack.esql.core.tree; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.CountDownActionListener; + +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; @@ -60,8 +66,27 @@ final T transform(Function rule, Class typeToken) return innerTransform(realRule); } + public void transform( + BiConsumer> rule, + Class typeToken, + ActionListener listener + ) { + List children = node.children(); + BiConsumer> realRule = (p, l) -> { + if (p != children && (p == null || typeToken.isInstance(p)) && false == children.contains(p)) { + rule.accept(typeToken.cast(p), l); + } else { + l.onResponse(p); + } + }; + + innerTransform(realRule, listener); + } + protected abstract T innerTransform(Function rule); + protected abstract void innerTransform(BiConsumer> rule, ActionListener listener); + /** * Builds a {@link NodeInfo} for Nodes without any properties. */ @@ -75,6 +100,10 @@ protected List innerProperties() { protected T innerTransform(Function rule) { return node; } + + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + listener.onResponse(node); + } }; } @@ -94,6 +123,16 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1); } + + @SuppressWarnings("unchecked") + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + transformProperties( + rule, + listener.safeMap( + newProps -> innerProperties().equals(newProps) ? node : ctor.apply(node.source(), (P1) newProps.get(0)) + ) + ); + } }; } @@ -116,6 +155,18 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2); } + + @SuppressWarnings("unchecked") + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + transformProperties( + rule, + listener.safeMap( + newProps -> innerProperties().equals(newProps) + ? node + : ctor.apply(node.source(), (P1) newProps.get(0), (P2) newProps.get(1)) + ) + ); + } }; } @@ -145,6 +196,18 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3); } + + @SuppressWarnings("unchecked") + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + transformProperties( + rule, + listener.safeMap( + newProps -> innerProperties().equals(newProps) + ? node + : ctor.apply(node.source(), (P1) newProps.get(0), (P2) newProps.get(1), (P3) newProps.get(2)) + ) + ); + } }; } @@ -184,6 +247,24 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4); } + + @SuppressWarnings("unchecked") + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + transformProperties( + rule, + listener.safeMap( + newProps -> innerProperties().equals(newProps) + ? 
node + : ctor.apply( + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3) + ) + ) + ); + } }; } @@ -227,6 +308,25 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5); } + + @SuppressWarnings("unchecked") + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + transformProperties( + rule, + listener.safeMap( + newProps -> innerProperties().equals(newProps) + ? node + : ctor.apply( + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4) + ) + ) + ); + } }; } @@ -274,6 +374,26 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6); } + + @SuppressWarnings("unchecked") + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + transformProperties( + rule, + listener.safeMap( + newProps -> innerProperties().equals(newProps) + ? node + : ctor.apply( + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4), + (P6) newProps.get(5) + ) + ) + ); + } }; } @@ -325,6 +445,27 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6, newP7); } + + @SuppressWarnings("unchecked") + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + transformProperties( + rule, + listener.safeMap( + newProps -> innerProperties().equals(newProps) + ? node + : ctor.apply( + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4), + (P6) newProps.get(5), + (P7) newProps.get(6) + ) + ) + ); + } }; } @@ -380,6 +521,28 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6, newP7, newP8); } + + @SuppressWarnings("unchecked") + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + transformProperties( + rule, + listener.safeMap( + newProps -> innerProperties().equals(newProps) + ? node + : ctor.apply( + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4), + (P6) newProps.get(5), + (P7) newProps.get(6), + (P8) newProps.get(7) + ) + ) + ); + } }; } @@ -439,6 +602,29 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6, newP7, newP8, newP9); } + + @SuppressWarnings("unchecked") + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + transformProperties( + rule, + listener.safeMap( + newProps -> innerProperties().equals(newProps) + ? node + : ctor.apply( + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4), + (P6) newProps.get(5), + (P7) newProps.get(6), + (P8) newProps.get(7), + (P9) newProps.get(8) + ) + ) + ); + } }; } @@ -502,10 +688,60 @@ protected T innerTransform(Function rule) { return same ? 
node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6, newP7, newP8, newP9, newP10); } + + @SuppressWarnings("unchecked") + protected void innerTransform(BiConsumer> rule, ActionListener listener) { + transformProperties( + rule, + listener.safeMap( + newProps -> innerProperties().equals(newProps) + ? node + : ctor.apply( + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4), + (P6) newProps.get(5), + (P7) newProps.get(6), + (P8) newProps.get(7), + (P9) newProps.get(8), + (P10) newProps.get(9) + ) + ) + ); + } }; } public interface NodeCtor10 { T apply(Source l, P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8, P9 p9, P10 p10); } + + protected void transformProperties(BiConsumer> rule, ActionListener> listener) { + List properties = innerProperties(); + if (properties.isEmpty()) { + listener.onResponse(properties); + return; + } + + AtomicReference> transformedProperties = new AtomicReference<>(null); + CountDownActionListener completionListener = new CountDownActionListener(properties.size(), ActionListener.wrap(ignored -> { + List result = transformedProperties.get() != null ? transformedProperties.get() : properties; + listener.onResponse(result); + }, listener::onFailure)); + + for (int i = 0, size = properties.size(); i < size; i++) { + final int currentIndex = i; + Object property = properties.get(currentIndex); + rule.accept(property, completionListener.delegateFailureAndWrap((l, transformed) -> { + if (properties.get(currentIndex).equals(transformed) == false) { + transformedProperties.compareAndSet(null, new ArrayList<>(properties)); + transformedProperties.get().set(currentIndex, transformed); + } + l.onResponse(null); + })); + } + } } diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/tree/NodeTransformTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/tree/NodeTransformTests.java new file mode 100644 index 0000000000000..6f22009ee39bd --- /dev/null +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/tree/NodeTransformTests.java @@ -0,0 +1,516 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.core.tree; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.test.ESTestCase; + +import java.util.List; +import java.util.function.Function; +import java.util.function.Predicate; + +import static org.elasticsearch.xpack.esql.core.tree.SourceTests.randomSource; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class NodeTransformTests extends ESTestCase { + // Transform Up Tests + public void testTransformUpSimpleLeafTransformation() throws Exception { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + Function transformer = createLeafTransformer(); + + NodeTests.Dummy result = tree.transformUp(transformer); + + assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); + NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; + assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); + assertThat(transformed.children().get(1).thing(), equalTo("leaf2_transformed")); + + // Verify async version matches + assertAsyncTransformMatches(tree, transformer, result); + } + + public void testTransformUpWithTypeToken() throws Exception { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + Function transformer = n -> new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); + + NodeTests.Dummy result = tree.transformUp(NodeTests.NoChildren.class, transformer); + + assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); + NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; + assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); + assertThat(transformed.children().get(1).thing(), equalTo("leaf2_transformed")); + + // Verify async version matches + SetOnce asyncResult = new SetOnce<>(); + tree.transformUp( + NodeTests.NoChildren.class, + (n, listener) -> listener.onResponse(transformer.apply(n)), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); + } + + public void testTransformUpWithPredicate() throws Exception { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + Predicate> predicate = n -> n instanceof NodeTests.NoChildren && ((NodeTests.NoChildren) n).thing().equals("leaf1"); + Function transformer = n -> new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); + + NodeTests.Dummy result = tree.transformUp(predicate, transformer); + + assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); + NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; + assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); + assertThat(transformed.children().get(1).thing(), equalTo("leaf2")); // Not transformed + + // Verify async version matches + SetOnce asyncResult = new SetOnce<>(); + tree.transformUp( + predicate, + (n, listener) -> listener.onResponse(transformer.apply(n)), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); + } + + public void testTransformUpErrorHandling() { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + + RuntimeException e = expectThrows( + RuntimeException.class, + () -> tree.transformUp(n -> { throw new RuntimeException("test error"); }) + ); + assertThat(e.getMessage(), equalTo("test error")); + } + + 
public void testTransformUpAsyncErrorHandling() throws Exception { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + SetOnce exception = new SetOnce<>(); + + tree.transformUp( + (n, l) -> l.onFailure(new RuntimeException("test error")), + ActionListener.wrap(r -> fail("should not be called"), e -> exception.set(asInstanceOf(RuntimeException.class, e))) + ); + + assertBusy(() -> assertThat(exception.get().getMessage(), equalTo("test error"))); + } + + public void testTransformUpNestedStructures() throws Exception { + NodeTests.ChildrenAreAProperty tree = createNestedTree(); + Function transformer = createAllNodesTransformer(); + + NodeTests.Dummy result = tree.transformUp(transformer); + + assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); + NodeTests.ChildrenAreAProperty transformedOuter = (NodeTests.ChildrenAreAProperty) result; + assertThat(transformedOuter.thing(), equalTo("outer_transformed")); + + NodeTests.Dummy innerResult = transformedOuter.children().get(0); + assertThat(innerResult, instanceOf(NodeTests.ChildrenAreAProperty.class)); + NodeTests.ChildrenAreAProperty transformedInner = (NodeTests.ChildrenAreAProperty) innerResult; + assertThat(transformedInner.thing(), equalTo("inner_transformed")); + assertThat(transformedInner.children().get(0).thing(), equalTo("leaf1_transformed")); + assertThat(transformedInner.children().get(1).thing(), equalTo("leaf2_transformed")); + + // Verify async version matches + assertAsyncTransformMatches(tree, transformer, result); + } + + // Transform Down Tests + public void testTransformDownSimpleLeafTransformation() throws Exception { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + Function transformer = createLeafTransformer(); + + NodeTests.Dummy result = tree.transformDown(transformer); + + assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); + NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; + assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); + assertThat(transformed.children().get(1).thing(), equalTo("leaf2_transformed")); + + // Verify async version matches + assertAsyncTransformDownMatches(tree, transformer, result); + } + + public void testTransformDownWithTypeToken() throws Exception { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + Function transformer = n -> new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); + + NodeTests.Dummy result = tree.transformDown(NodeTests.NoChildren.class, transformer); + + assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); + NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; + assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); + assertThat(transformed.children().get(1).thing(), equalTo("leaf2_transformed")); + + // Verify async version matches + SetOnce asyncResult = new SetOnce<>(); + tree.transformDown( + NodeTests.NoChildren.class, + (n, listener) -> listener.onResponse(transformer.apply(n)), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); + } + + public void testTransformDownWithPredicate() throws Exception { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + Predicate> predicate = n -> n instanceof NodeTests.NoChildren && ((NodeTests.NoChildren) n).thing().equals("leaf1"); + Function transformer = n -> new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); + + 
NodeTests.Dummy result = tree.transformDown(predicate, transformer); + + assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); + NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; + assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); + assertThat(transformed.children().get(1).thing(), equalTo("leaf2")); // Not transformed + + // Verify async version matches + SetOnce asyncResult = new SetOnce<>(); + tree.transformDown( + predicate, + (n, listener) -> listener.onResponse(transformer.apply(n)), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); + } + + public void testTransformDownErrorHandling() { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + + RuntimeException e = expectThrows( + RuntimeException.class, + () -> tree.transformDown(n -> { throw new RuntimeException("test error"); }) + ); + assertThat(e.getMessage(), equalTo("test error")); + } + + public void testTransformDownAsyncErrorHandling() throws Exception { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + SetOnce exception = new SetOnce<>(); + + tree.transformDown((n, listener) -> { + if (n instanceof NodeTests.NoChildren) { + listener.onFailure(new RuntimeException("test error")); + } else { + listener.onResponse(n); + } + }, ActionListener.wrap(r -> fail("should not be called"), e -> exception.set(asInstanceOf(RuntimeException.class, e)))); + + assertBusy(() -> { + assertNotNull(exception.get()); + assertThat(exception.get().getMessage(), equalTo("test error")); + }); + } + + public void testTransformDownNestedStructures() throws Exception { + NodeTests.ChildrenAreAProperty tree = createNestedTree(); + Function transformer = createAllNodesTransformer(); + + NodeTests.Dummy result = tree.transformDown(transformer); + + assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); + NodeTests.ChildrenAreAProperty transformedOuter = (NodeTests.ChildrenAreAProperty) result; + assertThat(transformedOuter.thing(), equalTo("outer_transformed")); + + NodeTests.Dummy innerResult = transformedOuter.children().get(0); + assertThat(innerResult, instanceOf(NodeTests.ChildrenAreAProperty.class)); + NodeTests.ChildrenAreAProperty transformedInner = (NodeTests.ChildrenAreAProperty) innerResult; + assertThat(transformedInner.thing(), equalTo("inner_transformed")); + assertThat(transformedInner.children().get(0).thing(), equalTo("leaf1_transformed")); + assertThat(transformedInner.children().get(1).thing(), equalTo("leaf2_transformed")); + + // Verify async version matches + assertAsyncTransformDownMatches(tree, transformer, result); + } + + public void testTransformPropertiesOnly() throws Exception { + NodeTests.ChildrenAreAProperty tree = createSimpleTree(); + + // Sync transformation: change the root mode property only + NodeTests.Dummy result = tree.transformPropertiesOnly(String.class, s -> s + "_changed"); + NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; + assertThat(transformed.thing(), equalTo("node_changed")); + // Children should remain unchanged because transformPropertiesOnly does not traverse + assertThat(transformed.children().get(0).thing(), equalTo("leaf1")); + assertThat(transformed.children().get(1).thing(), equalTo("leaf2")); + + // Async variant should yield identical result + SetOnce asyncResult = new SetOnce<>(); + tree.transformPropertiesOnly( + String.class, + (s, l) -> l.onResponse(s + 
"_changed"), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); + } + + public void testTransformPropertiesDown() throws Exception { + NodeTests.ChildrenAreAProperty tree = createNestedTree(); + + // Sync transformation: propagate "_changed" to every "thing" property top-down + NodeTests.Dummy result = tree.transformPropertiesDown(String.class, s -> s + "_changed"); + // Root, inner, and leaves should all have suffix + NodeTests.ChildrenAreAProperty outer = (NodeTests.ChildrenAreAProperty) result; + assertThat(outer.thing(), equalTo("outer_changed")); + NodeTests.ChildrenAreAProperty inner = (NodeTests.ChildrenAreAProperty) outer.children().get(0); + assertThat(inner.thing(), equalTo("inner_changed")); + assertThat(inner.children().get(0).thing(), equalTo("leaf1_changed")); + assertThat(inner.children().get(1).thing(), equalTo("leaf2_changed")); + + // Async variant should yield identical result + SetOnce asyncResult = new SetOnce<>(); + tree.transformPropertiesDown( + String.class, + (s, l) -> l.onResponse(s + "_changed"), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); + } + + public void testTransformPropertiesUp() throws Exception { + NodeTests.ChildrenAreAProperty tree = createNestedTree(); + + // Sync transformation: propagate "_changed" to every "thing" property bottom-up + NodeTests.Dummy result = tree.transformPropertiesUp(String.class, s -> s + "_changed"); + NodeTests.ChildrenAreAProperty outer = (NodeTests.ChildrenAreAProperty) result; + assertThat(outer.thing(), equalTo("outer_changed")); + NodeTests.ChildrenAreAProperty inner = (NodeTests.ChildrenAreAProperty) outer.children().get(0); + assertThat(inner.thing(), equalTo("inner_changed")); + assertThat(inner.children().get(0).thing(), equalTo("leaf1_changed")); + assertThat(inner.children().get(1).thing(), equalTo("leaf2_changed")); + + // Async variant should yield identical result + SetOnce asyncResult = new SetOnce<>(); + tree.transformPropertiesUp( + String.class, + (s, l) -> l.onResponse(s + "_changed"), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); + } + + // Tests demonstrating behavioral differences between transformUp and transformDown + public void testTransformUpVsDownOrderDependentTransformation() { + NodeTests.NoChildren leaf1 = new NodeTests.NoChildren(randomSource(), "leaf"); + NodeTests.NoChildren leaf2 = new NodeTests.NoChildren(randomSource(), "leaf"); + NodeTests.ChildrenAreAProperty innerNode = new NodeTests.ChildrenAreAProperty(randomSource(), List.of(leaf1, leaf2), "inner"); + NodeTests.ChildrenAreAProperty outerNode = new NodeTests.ChildrenAreAProperty(randomSource(), List.of(innerNode), "outer"); + + Function transformerDown = n -> { + if (n instanceof NodeTests.ChildrenAreAProperty) { + NodeTests.ChildrenAreAProperty cn = (NodeTests.ChildrenAreAProperty) n; + return new NodeTests.ChildrenAreAProperty(cn.source(), cn.children(), cn.thing() + "_DOWN"); + } + return n; + }; + + Function transformerUp = n -> { + if (n instanceof NodeTests.ChildrenAreAProperty) { + NodeTests.ChildrenAreAProperty cn = (NodeTests.ChildrenAreAProperty) n; + return new NodeTests.ChildrenAreAProperty(cn.source(), cn.children(), cn.thing() + "_UP"); + } + return n; + }; + + // Transform down: parent first, then children + NodeTests.Dummy resultDown = 
outerNode.transformDown(transformerDown); + NodeTests.ChildrenAreAProperty outerDown = (NodeTests.ChildrenAreAProperty) resultDown; + NodeTests.ChildrenAreAProperty innerDown = (NodeTests.ChildrenAreAProperty) outerDown.children().get(0); + + // Transform up: children first, then parent + NodeTests.Dummy resultUp = outerNode.transformUp(transformerUp); + NodeTests.ChildrenAreAProperty outerUp = (NodeTests.ChildrenAreAProperty) resultUp; + NodeTests.ChildrenAreAProperty innerUp = (NodeTests.ChildrenAreAProperty) outerUp.children().get(0); + + // Verify transformation order is reflected in results + assertThat(outerDown.thing(), equalTo("outer_DOWN")); + assertThat(innerDown.thing(), equalTo("inner_DOWN")); + assertThat(outerUp.thing(), equalTo("outer_UP")); + assertThat(innerUp.thing(), equalTo("inner_UP")); + } + + public void testTransformUpVsDownChildDependentLogic() { + NodeTests.NoChildren leaf1 = new NodeTests.NoChildren(randomSource(), "A"); + NodeTests.NoChildren leaf2 = new NodeTests.NoChildren(randomSource(), "B"); + NodeTests.ChildrenAreAProperty node = new NodeTests.ChildrenAreAProperty(randomSource(), List.of(leaf1, leaf2), "parent"); + + // Transformer that changes parent based on children's current state + Function transformer = n -> { + if (n instanceof NodeTests.ChildrenAreAProperty) { + NodeTests.ChildrenAreAProperty cn = (NodeTests.ChildrenAreAProperty) n; + // Count how many children have "transformed" in their name + long transformedChildrenCount = cn.children().stream().filter(child -> child.thing().contains("transformed")).count(); + return new NodeTests.ChildrenAreAProperty( + cn.source(), + cn.children(), + cn.thing() + "_sees_" + transformedChildrenCount + "_transformed_children" + ); + } else if (n instanceof NodeTests.NoChildren) { + return new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); + } + return n; + }; + + // Transform down: parent sees children in original state + NodeTests.Dummy resultDown = node.transformDown(transformer); + NodeTests.ChildrenAreAProperty parentDown = (NodeTests.ChildrenAreAProperty) resultDown; + + // Transform up: parent sees children after they've been transformed + NodeTests.Dummy resultUp = node.transformUp(transformer); + NodeTests.ChildrenAreAProperty parentUp = (NodeTests.ChildrenAreAProperty) resultUp; + + // Key difference: transformDown parent sees 0 transformed children, + // transformUp parent sees 2 transformed children + assertThat(parentDown.thing(), equalTo("parent_sees_0_transformed_children")); + assertThat(parentUp.thing(), equalTo("parent_sees_2_transformed_children")); + + // Both should have transformed children + assertThat(parentDown.children().get(0).thing(), equalTo("A_transformed")); + assertThat(parentDown.children().get(1).thing(), equalTo("B_transformed")); + assertThat(parentUp.children().get(0).thing(), equalTo("A_transformed")); + assertThat(parentUp.children().get(1).thing(), equalTo("B_transformed")); + } + + public void testTransformUpVsDownConditionalTransformation() { + NodeTests.NoChildren leaf1 = new NodeTests.NoChildren(randomSource(), "child1"); + NodeTests.NoChildren leaf2 = new NodeTests.NoChildren(randomSource(), "child2"); + NodeTests.ChildrenAreAProperty node = new NodeTests.ChildrenAreAProperty(randomSource(), List.of(leaf1, leaf2), "STOP"); + + // Transformer that stops transformation if parent has "STOP" in name + Function transformer = n -> { + if (n instanceof NodeTests.ChildrenAreAProperty) { + NodeTests.ChildrenAreAProperty cn = (NodeTests.ChildrenAreAProperty) n; + 
if (cn.thing().contains("STOP")) { + // Return node unchanged + return cn; + } else { + return new NodeTests.ChildrenAreAProperty(cn.source(), cn.children(), cn.thing() + "_processed"); + } + } else if (n instanceof NodeTests.NoChildren) { + return new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); + } + return n; + }; + + NodeTests.Dummy resultDown = node.transformDown(transformer); + NodeTests.ChildrenAreAProperty parentDown = (NodeTests.ChildrenAreAProperty) resultDown; + + NodeTests.Dummy resultUp = node.transformUp(transformer); + NodeTests.ChildrenAreAProperty parentUp = (NodeTests.ChildrenAreAProperty) resultUp; + + // Both parents should remain unchanged (contain "STOP") + assertThat(parentDown.thing(), equalTo("STOP")); + assertThat(parentUp.thing(), equalTo("STOP")); + + // Both should have transformed children + assertThat(parentDown.children().get(0).thing(), equalTo("child1_transformed")); + assertThat(parentUp.children().get(0).thing(), equalTo("child1_transformed")); + } + + public void testTransformUpVsDownAccumulativeChanges() { + NodeTests.NoChildren leaf = new NodeTests.NoChildren(randomSource(), "0"); + NodeTests.AChildIsAProperty innerNode = new NodeTests.AChildIsAProperty(randomSource(), leaf, "0"); + NodeTests.AChildIsAProperty outerNode = new NodeTests.AChildIsAProperty(randomSource(), innerNode, "0"); + + // Transformer that increments numeric values + Function transformer = n -> { + try { + int currentValue = Integer.parseInt(n.thing()); + String newValue = String.valueOf(currentValue + 1); + + if (n instanceof NodeTests.NoChildren) { + return new NodeTests.NoChildren(n.source(), newValue); + } else if (n instanceof NodeTests.AChildIsAProperty) { + NodeTests.AChildIsAProperty an = (NodeTests.AChildIsAProperty) n; + return new NodeTests.AChildIsAProperty(an.source(), an.child(), newValue); + } + } catch (NumberFormatException e) { + // If not a number, leave unchanged + } + return n; + }; + + NodeTests.Dummy resultDown = outerNode.transformDown(transformer); + NodeTests.Dummy resultUp = outerNode.transformUp(transformer); + + // Extract the final values + NodeTests.AChildIsAProperty outerDown = (NodeTests.AChildIsAProperty) resultDown; + NodeTests.AChildIsAProperty innerDown = (NodeTests.AChildIsAProperty) outerDown.child(); + NodeTests.NoChildren leafDown = (NodeTests.NoChildren) innerDown.child(); + + NodeTests.AChildIsAProperty outerUp = (NodeTests.AChildIsAProperty) resultUp; + NodeTests.AChildIsAProperty innerUp = (NodeTests.AChildIsAProperty) outerUp.child(); + NodeTests.NoChildren leafUp = (NodeTests.NoChildren) innerUp.child(); + + // All nodes should be incremented to "1" + assertThat(leafDown.thing(), equalTo("1")); + assertThat(leafUp.thing(), equalTo("1")); + assertThat(innerDown.thing(), equalTo("1")); + assertThat(innerUp.thing(), equalTo("1")); + assertThat(outerDown.thing(), equalTo("1")); + assertThat(outerUp.thing(), equalTo("1")); + } + + // Helper methods for transform tests + private NodeTests.ChildrenAreAProperty createSimpleTree() { + NodeTests.NoChildren leaf1 = new NodeTests.NoChildren(randomSource(), "leaf1"); + NodeTests.NoChildren leaf2 = new NodeTests.NoChildren(randomSource(), "leaf2"); + return new NodeTests.ChildrenAreAProperty(randomSource(), List.of(leaf1, leaf2), "node"); + } + + private NodeTests.ChildrenAreAProperty createNestedTree() { + NodeTests.NoChildren leaf1 = new NodeTests.NoChildren(randomSource(), "leaf1"); + NodeTests.NoChildren leaf2 = new NodeTests.NoChildren(randomSource(), "leaf2"); + 
NodeTests.ChildrenAreAProperty innerNode = new NodeTests.ChildrenAreAProperty(randomSource(), List.of(leaf1, leaf2), "inner"); + return new NodeTests.ChildrenAreAProperty(randomSource(), List.of(innerNode), "outer"); + } + + private Function createLeafTransformer() { + return n -> n instanceof NodeTests.NoChildren ? new NodeTests.NoChildren(n.source(), n.thing() + "_transformed") : n; + } + + private Function createAllNodesTransformer() { + return n -> { + if (n instanceof NodeTests.NoChildren) { + return new NodeTests.NoChildren(n.source(), ((NodeTests.NoChildren) n).thing() + "_transformed"); + } else if (n instanceof NodeTests.ChildrenAreAProperty) { + NodeTests.ChildrenAreAProperty cn = (NodeTests.ChildrenAreAProperty) n; + return new NodeTests.ChildrenAreAProperty(cn.source(), cn.children(), cn.thing() + "_transformed"); + } + return n; + }; + } + + private void assertAsyncTransformMatches( + NodeTests.Dummy node, + Function transformer, + NodeTests.Dummy expectedResult + ) throws Exception { + SetOnce asyncResult = new SetOnce<>(); + ((Node) node).transformUp( + (n, listener) -> listener.onResponse(transformer.apply(n)), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(expectedResult))); + } + + private void assertAsyncTransformDownMatches( + NodeTests.Dummy node, + Function transformer, + NodeTests.Dummy expectedResult + ) throws Exception { + SetOnce asyncResult = new SetOnce<>(); + ((Node) node).transformDown( + (n, listener) -> listener.onResponse(transformer.apply(n)), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(expectedResult))); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java index 81a89950b0a02..0e127776c30e6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java @@ -6,6 +6,8 @@ */ package org.elasticsearch.xpack.esql.plan; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -16,6 +18,8 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; @@ -102,18 +106,78 @@ public PlanType transformExpressionsOnly(Function doTransformExpression(e, exp -> exp.transformDown(rule))); } + public void transformExpressionsOnly(BiConsumer> rule, ActionListener listener) { + transformPropertiesOnly( + Object.class, + (prop, propListener) -> doTransformExpression( + prop, + (expr, exprListener) -> expr.transformDown(rule, exprListener), + propListener + ), + listener + ); + } + public PlanType transformExpressionsOnly(Class typeToken, Function rule) { return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule))); } + public void transformExpressionsOnly( + Class typeToken, + BiConsumer> rule, + ActionListener listener + ) { + transformPropertiesOnly( + 
Object.class, + (prop, propListener) -> doTransformExpression( + prop, + (expr, exprListener) -> expr.transformDown(typeToken, rule, exprListener), + propListener + ), + listener + ); + } + public PlanType transformExpressionsOnlyUp(Class typeToken, Function rule) { return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule))); } + public void transformExpressionsOnlyUp( + Class typeToken, + BiConsumer> rule, + ActionListener listener + ) { + transformPropertiesOnly( + Object.class, + (prop, propListener) -> doTransformExpression( + prop, + (expr, exprListener) -> expr.transformUp(typeToken, rule, exprListener), + propListener + ), + listener + ); + } + public PlanType transformExpressionsDown(Class typeToken, Function rule) { return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule))); } + public void transformExpressionsDown( + Class typeToken, + BiConsumer> rule, + ActionListener listener + ) { + transformPropertiesDown( + Object.class, + (prop, propListener) -> doTransformExpression( + prop, + (expr, exprListener) -> expr.transformDown(typeToken, rule, exprListener), + propListener + ), + listener + ); + } + public PlanType transformExpressionsDown( Predicate> shouldVisit, Class typeToken, @@ -125,10 +189,47 @@ public PlanType transformExpressionsDown( ); } + public void transformExpressionsDown( + Predicate> shouldVisit, + Class typeToken, + BiConsumer> rule, + ActionListener listener + ) { + transformDown( + shouldVisit, + (plan, planListener) -> plan.transformNodeProps( + Object.class, + (prop, propListener) -> doTransformExpression( + prop, + (expr, exprListener) -> expr.transformDown(typeToken, rule, exprListener), + propListener + ), + planListener + ), + listener + ); + } + public PlanType transformExpressionsUp(Class typeToken, Function rule) { return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule))); } + public void transformExpressionsUp( + Class typeToken, + BiConsumer> rule, + ActionListener listener + ) { + transformPropertiesUp( + Object.class, + (prop, propListener) -> doTransformExpression( + prop, + (expr, exprListener) -> expr.transformUp(typeToken, rule, exprListener), + propListener + ), + listener + ); + } + public PlanType transformExpressionsUp( Predicate> shouldVisit, Class typeToken, @@ -140,6 +241,27 @@ public PlanType transformExpressionsUp( ); } + public void transformExpressionsUp( + Predicate> shouldVisit, + Class typeToken, + BiConsumer> rule, + ActionListener listener + ) { + transformUp( + shouldVisit, + (plan, planListener) -> plan.transformNodeProps( + Object.class, + (prop, propListener) -> doTransformExpression( + prop, + (expr, exprListener) -> expr.transformUp(typeToken, rule, exprListener), + propListener + ), + planListener + ), + listener + ); + } + @SuppressWarnings("unchecked") private static Object doTransformExpression(Object arg, Function traversal) { if (arg instanceof Expression exp) { @@ -217,4 +339,55 @@ private static void doForEachExpression(Object arg, Consumer travers } } } + + private static void doTransformExpression( + Object arg, + BiConsumer> traversal, + ActionListener listener + ) { + if (arg instanceof Expression exp) { + traversal.accept(exp, listener.map(r -> (Object) r)); + } else if (arg instanceof Collection c && c.isEmpty()) { + listener.onResponse(arg); + } else if (arg instanceof List list) { + AtomicReference> transformed = new 
AtomicReference<>(null); + CountDownActionListener completionListener = new CountDownActionListener( + list.size(), + listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(transformed.get() != null ? transformed.get() : arg)) + ); + for (int i = 0; i < list.size(); i++) { + final int idx = i; + Object el = list.get(i); + doTransformExpression(el, traversal, completionListener.map(next -> { + if (el.equals(next) == false) { + transformed.compareAndSet(null, new ArrayList<>(list)); + transformed.get().set(idx, next); + } + return null; + })); + } + return; + } else if (arg instanceof Collection c) { + AtomicReference> transformed = new AtomicReference<>(null); + CountDownActionListener completionListener = new CountDownActionListener( + c.size(), + listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(transformed.get() != null ? transformed.get() : arg)) + ); + int i = 0; + for (Object el : c) { + final int idx = i++; + doTransformExpression(el, traversal, completionListener.map(next -> { + if (next.equals(el) == false) { + if (el.equals(next) == false) { + transformed.compareAndSet(null, new ArrayList<>(c)); + transformed.get().set(idx, next); + } + } + return null; + })); + } + } else { + listener.onResponse(arg); + } + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java index dadcd12b31030..fdc3817b6ef9d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.esql.plan; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expressions; @@ -14,6 +16,7 @@ import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.tree.Node; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Limit; @@ -23,6 +26,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.Predicate; import static java.util.Arrays.asList; import static java.util.Collections.emptyList; @@ -34,6 +38,7 @@ import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; public class QueryPlanTests extends ESTestCase { @@ -44,6 +49,15 @@ public void testTransformWithExpressionTopLevel() throws Exception { assertEquals(Limit.class, transformed.getClass()); Limit l = (Limit) transformed; assertEquals(24, l.limit().fold(FoldContext.small())); + + // Test async version is returning the same result as the sync version + SetOnce asyncResultHolder = new SetOnce<>(); + limit.transformExpressionsOnly( + Literal.class, + (e, listener) -> listener.onResponse(of(24)), + ActionListener.wrap(asyncResultHolder::set, ESTestCase::fail) + ); + assertBusy(() -> { assertThat(asyncResultHolder.get(), equalTo(transformed)); }); } public 
void testTransformWithExpressionTree() throws Exception { @@ -55,6 +69,15 @@ public void testTransformWithExpressionTree() throws Exception { OrderBy order = (OrderBy) transformed; assertEquals(Limit.class, order.child().getClass()); assertEquals(24, ((Limit) order.child()).limit().fold(FoldContext.small())); + + // Test async version is returning the same result as the sync version + SetOnce asyncResultHolder = new SetOnce<>(); + o.transformExpressionsDown( + Literal.class, + (e, listener) -> listener.onResponse(of(24)), + ActionListener.wrap(asyncResultHolder::set, ESTestCase::fail) + ); + assertBusy(() -> { assertThat(asyncResultHolder.get(), equalTo(transformed)); }); } public void testTransformWithExpressionTopLevelInCollection() throws Exception { @@ -74,6 +97,56 @@ public void testTransformWithExpressionTopLevelInCollection() throws Exception { NamedExpression o = p.projections().get(0); assertEquals("changed", o.name()); + + // Test async version is returning the same result as the sync version + SetOnce asyncResultHolder = new SetOnce<>(); + project.transformExpressionsOnly( + NamedExpression.class, + (n, listener) -> listener.onResponse(n.name().equals("one") ? new FieldAttribute(EMPTY, "changed", one.field()) : n), + ActionListener.wrap(asyncResultHolder::set, ESTestCase::fail) + ); + assertBusy(() -> { assertThat(asyncResultHolder.get(), equalTo(transformed)); }); + } + + public void testTransformExpressionsUpTree() throws Exception { + Limit limit = new Limit(EMPTY, of(42), relation()); + OrderBy order = new OrderBy(EMPTY, limit, emptyList()); + + LogicalPlan transformed = order.transformExpressionsUp(Literal.class, l -> of(24)); + + assertEquals(OrderBy.class, transformed.getClass()); + OrderBy out = (OrderBy) transformed; + assertEquals(24, ((Limit) out.child()).limit().fold(FoldContext.small())); + + // Test async version is returning the same result as the sync version + SetOnce asyncResult = new SetOnce<>(); + order.transformExpressionsUp( + Literal.class, + (lit, listener) -> listener.onResponse(of(24)), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(transformed))); + } + + public void testTransformExpressionsDownWithPredicate() throws Exception { + Limit limit = new Limit(EMPTY, of(42), relation()); + OrderBy outer = new OrderBy(EMPTY, limit, emptyList()); + + Predicate> onlyLimit = n -> n instanceof Limit; + + LogicalPlan transformed = outer.transformExpressionsDown(onlyLimit, Literal.class, lit -> of(24)); + + assertEquals(24, ((Limit) ((OrderBy) transformed).child()).limit().fold(FoldContext.small())); + + // Test async version is returning the same result as the sync version + SetOnce asyncResult = new SetOnce<>(); + outer.transformExpressionsDown( + onlyLimit, + Literal.class, + (lit, listener) -> listener.onResponse(of(24)), + ActionListener.wrap(asyncResult::set, ESTestCase::fail) + ); + assertBusy(() -> assertThat(asyncResult.get(), equalTo(transformed))); } public void testForEachWithExpressionTopLevel() throws Exception { From 1f034a0e885625f19c427f9884abde3841c5e2ad Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 16 Jul 2025 18:18:59 +0200 Subject: [PATCH 10/31] [ESQL] Add TEXT_EMBEDDING function evaluator for dense vector embeddings Implements the core evaluation logic for the TEXT_EMBEDDING function in ES|QL: - Add InferenceFunctionEvaluator interface for all inference functions - Implement TextEmbeddingFunctionEvaluator with support for float/byte/bit vectors - Integration with 
InferenceRunner for async model execution - Proper conversion of embedding results to DENSE_VECTOR data type --- .../results/TextEmbeddingByteResults.java | 4 +- .../inference/InferenceFunctionEvaluator.java | 27 ++++++++ .../TextEmbeddingFunctionEvaluator.java | 63 +++++++++++++++++++ 3 files changed, 92 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java index 54f858cb20ae0..d873a5ba37545 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java @@ -155,7 +155,7 @@ public String toString() { return Strings.toString(this); } - float[] toFloatArray() { + public float[] toFloatArray() { float[] floatArray = new float[values.length]; for (int i = 0; i < values.length; i++) { floatArray[i] = ((Byte) values[i]).floatValue(); @@ -163,7 +163,7 @@ float[] toFloatArray() { return floatArray; } - double[] toDoubleArray() { + public double[] toDoubleArray() { double[] doubleArray = new double[values.length]; for (int i = 0; i < values.length; i++) { doubleArray[i] = ((Byte) values[i]).doubleValue(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java new file mode 100644 index 0000000000000..49a4d4cf8c4bf --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; +import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingFunctionEvaluator; + +public interface InferenceFunctionEvaluator { + + void eval(FoldContext foldContext, ActionListener listener); + + static InferenceFunctionEvaluator get(InferenceFunction inferenceFunction, InferenceRunner inferenceRunner) { + return switch (inferenceFunction) { + case TextEmbedding textEmbedding -> new TextEmbeddingFunctionEvaluator(textEmbedding, inferenceRunner); + default -> throw new IllegalArgumentException("Unsupported inference function: " + inferenceFunction.getClass()); + }; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java new file mode 100644 index 0000000000000..474ffb87e3b4f --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; +import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; + +import java.util.List; + +public class TextEmbeddingFunctionEvaluator implements InferenceFunctionEvaluator { + + private final InferenceRunner inferenceRunner; + + private final TextEmbedding f; + + public TextEmbeddingFunctionEvaluator(TextEmbedding f, InferenceRunner inferenceRunner) { + this.f = f; + this.inferenceRunner = inferenceRunner; + } + + @Override + public void eval(FoldContext foldContext, ActionListener listener) { + assert f.inferenceId() != null && f.inferenceId().foldable() : "inferenceId should not be null and be foldable"; + assert f.inputText() != null && f.inputText().foldable() : "inputText should not be null and be foldable"; + + String inferenceId = BytesRefs.toString(f.inferenceId().fold(foldContext)); + String inputText = BytesRefs.toString(f.inputText().fold(foldContext)); + + 
inferenceRunner.execute(inferenceRequest(inferenceId, inputText), listener.map(this::parseInferenceResponse)); + } + + private InferenceAction.Request inferenceRequest(String inferenceId, String inputText) { + return InferenceAction.Request.builder(inferenceId, TaskType.TEXT_EMBEDDING).setInput(List.of(inputText)).build(); + } + + private Literal parseInferenceResponse(InferenceAction.Response response) { + float[] embeddingValues = switch (response.getResults()) { + case TextEmbeddingFloatResults floatEmbeddingResults -> floatEmbeddingResults.embeddings().get(0).values(); + case TextEmbeddingByteResults bytesEmbeddingResults -> bytesEmbeddingResults.embeddings().get(0).toFloatArray(); + case TextEmbeddingBitResults bitsEmbeddingResults -> bitsEmbeddingResults.embeddings().get(0).toFloatArray(); + default -> throw new IllegalArgumentException("Inference response should be of type TextEmbeddingResults"); + }; + + return new Literal(f.source(), embeddingValues, DataType.DENSE_VECTOR); + } +} From 419d523f643275a18592b8d1a870b29dec3a3c22 Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 16 Jul 2025 18:22:57 +0200 Subject: [PATCH 11/31] [ESQL] Complete TEXT_EMBEDDING function integration Integrates the TEXT_EMBEDDING function with the ESQL execution pipeline: - Update PreOptimizer to handle TEXT_EMBEDDING function evaluation - Add TextEmbedding function definition and type validation - Integrate with InferenceServices for model execution - Add comprehensive tests in PreOptimizerTests - Update session and execution components for async function support --- .../function/inference/InferenceFunction.java | 6 + .../function/inference/TextEmbedding.java | 6 - .../xpack/esql/optimizer/PreOptimizer.java | 34 ++- .../esql/optimizer/PreOptimizerTests.java | 196 ++++++++++++++++++ 4 files changed, 234 insertions(+), 8 deletions(-) create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java index d32587dc71058..331d688f7b0f9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -36,4 +36,10 @@ protected InferenceFunction(Source source, List children) { public abstract TaskType taskType(); public abstract PlanType withInferenceResolutionError(String inferenceId, String error); + + @Override + public boolean foldable() { + // Inference functions are not foldable and need to be evaluated using an async inference call. 
+ return false; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java index 7d1cd356ea1ac..0ba6d138a13f7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java @@ -119,12 +119,6 @@ protected TypeResolution resolveType() { return TypeResolution.TYPE_RESOLVED; } - @Override - public boolean foldable() { - // The function is foldable only if both arguments are foldable - return inputText.foldable() && inferenceId.foldable(); - } - @Override public TaskType taskType() { return TaskType.TEXT_EMBEDDING; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java index 4390dc48be587..6d40d93ad2bf6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java @@ -8,7 +8,13 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; +import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plugin.TransportActionServices; /** * The class is responsible for invoking any steps that need to be applied to the logical plan, @@ -19,11 +25,35 @@ */ public class PreOptimizer { - public PreOptimizer() { + private final InferencePreOptimizer inferencePreOptimizer; + public PreOptimizer(TransportActionServices services, FoldContext foldContext) { + this(services.inferenceRunner(), foldContext); + } + + PreOptimizer(InferenceRunner inferenceRunner, FoldContext foldContext) { + this.inferencePreOptimizer = new InferencePreOptimizer(inferenceRunner, foldContext); } public void preOptimize(LogicalPlan plan, ActionListener listener) { - listener.onResponse(plan); + inferencePreOptimizer.foldInferenceFunctions(plan, listener); + } + + private static class InferencePreOptimizer { + private final InferenceRunner inferenceRunner; + private final FoldContext foldContext; + + private InferencePreOptimizer(InferenceRunner inferenceRunner, FoldContext foldContext) { + this.inferenceRunner = inferenceRunner; + this.foldContext = foldContext; + } + + private void foldInferenceFunctions(LogicalPlan plan, ActionListener listener) { + plan.transformExpressionsUp(InferenceFunction.class, this::foldInferenceFunction, listener); + } + + private void foldInferenceFunction(InferenceFunction inferenceFunction, ActionListener listener) { + InferenceFunctionEvaluator.get(inferenceFunction, inferenceRunner).eval(foldContext, listener); + } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java new file mode 100644 index 
0000000000000..9cd5bd324da4a --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java @@ -0,0 +1,196 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; +import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; + +import java.util.List; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.notNullValue; + +public class PreOptimizerTests extends ESTestCase { + + public void testEvalFunctionEmbeddingBytes() throws Exception { + testEvalFunctionEmbedding(BYTES_EMBEDDING_MODEL); + } + + public void testEvalFunctionEmbeddingBits() throws Exception { + testEvalFunctionEmbedding(BIT_EMBEDDING_MODEL); + } + + public void testEvalFunctionEmbeddingFloats() throws Exception { + testEvalFunctionEmbedding(FLOAT_EMBEDDING_MODEL); + } + + public void testKnnFunctionEmbeddingBytes() throws Exception { + testKnnFunctionEmbedding(BYTES_EMBEDDING_MODEL); + } + + public void testKnnFunctionEmbeddingBits() throws Exception { + testKnnFunctionEmbedding(BIT_EMBEDDING_MODEL); + } + + public void testKnnFunctionEmbeddingFloats() throws Exception { + testKnnFunctionEmbedding(FLOAT_EMBEDDING_MODEL); + } + + private void testEvalFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) throws Exception { + String inferenceId = randomUUID(); + String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); + String fieldName = randomIdentifier(); + + PreOptimizer preOptimizer = new PreOptimizer(mockInferenceRunner(textEmbeddingModel), FoldContext.small()); + EsRelation relation = relation(); + Eval eval = new Eval( + Source.EMPTY, + relation, + List.of(new Alias(Source.EMPTY, 
fieldName, new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)))) + ); + + SetOnce preOptimizedPlanHolder = new SetOnce<>(); + preOptimizer.preOptimize(eval, ActionListener.wrap(preOptimizedPlanHolder::set, ESTestCase::fail)); + + assertBusy(() -> { + assertThat(preOptimizedPlanHolder.get(), notNullValue()); + Eval preOptimizedEval = as(preOptimizedPlanHolder.get(), Eval.class); + assertThat(preOptimizedEval.fields(), hasSize(1)); + assertThat(preOptimizedEval.fields().get(0).name(), equalTo(fieldName)); + Literal preOptimizedQuery = as(preOptimizedEval.fields().get(0).child(), Literal.class); + assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); + assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embedding(query))); + }); + } + + private void testKnnFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) throws Exception { + String inferenceId = randomUUID(); + String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); + + PreOptimizer preOptimizer = new PreOptimizer(mockInferenceRunner(textEmbeddingModel), FoldContext.small()); + EsRelation relation = relation(); + Filter filter = new Filter( + Source.EMPTY, + relation, + new Knn(Source.EMPTY, getFieldAttribute("a"), new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)), of(10), null) + ); + Knn knn = as(filter.condition(), Knn.class); + + SetOnce preOptimizedHolder = new SetOnce<>(); + preOptimizer.preOptimize(filter, ActionListener.wrap(preOptimizedHolder::set, ESTestCase::fail)); + + assertBusy(() -> { + assertThat(preOptimizedHolder.get(), notNullValue()); + Filter preOptimizedFilter = as(preOptimizedHolder.get(), Filter.class); + Knn preOptimizedKnn = as(preOptimizedFilter.condition(), Knn.class); + assertThat(preOptimizedKnn.field(), equalTo(knn.field())); + assertThat(preOptimizedKnn.k(), equalTo(knn.k())); + assertThat(preOptimizedKnn.options(), equalTo(knn.options())); + + Literal preOptimizedQuery = as(preOptimizedKnn.query(), Literal.class); + assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); + assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embedding(query))); + }); + } + + private InferenceRunner mockInferenceRunner(TextEmbeddingModelMock textEmbeddingModel) { + return new InferenceRunner() { + @Override + public void execute(InferenceAction.Request request, ActionListener listener) { + listener.onResponse(new InferenceAction.Response(textEmbeddingModel.embeddingResults(request.getInput().getFirst()))); + } + + @Override + public void executeBulk(BulkInferenceRequestIterator requests, ActionListener> listener) { + listener.onFailure( + new UnsupportedOperationException("executeBulk should not be invoked for plans without inference functions") + ); + } + }; + } + + private interface TextEmbeddingModelMock { + TextEmbeddingResults embeddingResults(String input); + + float[] embedding(String input); + } + + private static final TextEmbeddingModelMock FLOAT_EMBEDDING_MODEL = new TextEmbeddingModelMock() { + public TextEmbeddingResults embeddingResults(String input) { + TextEmbeddingFloatResults.Embedding embedding = new TextEmbeddingFloatResults.Embedding(embedding(input)); + return new TextEmbeddingFloatResults(List.of(embedding)); + } + + public float[] embedding(String input) { + String[] tokens = input.split("\\s+"); + float[] embedding = new float[tokens.length]; + for (int i = 0; i < tokens.length; i++) { + embedding[i] = tokens[i].length(); + } + return embedding; + } + }; + + private 
static final TextEmbeddingModelMock BYTES_EMBEDDING_MODEL = new TextEmbeddingModelMock() { + public TextEmbeddingResults embeddingResults(String input) { + TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input)); + return new TextEmbeddingBitResults(List.of(embedding)); + } + + private byte[] bytes(String input) { + return input.getBytes(); + } + + public float[] embedding(String input) { + return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray(); + } + }; + + private static final TextEmbeddingModelMock BIT_EMBEDDING_MODEL = new TextEmbeddingModelMock() { + public TextEmbeddingResults embeddingResults(String input) { + TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input)); + return new TextEmbeddingBitResults(List.of(embedding)); + } + + private byte[] bytes(String input) { + String[] tokens = input.split("\\s+"); + byte[] embedding = new byte[tokens.length]; + for (int i = 0; i < tokens.length; i++) { + embedding[i] = (byte) (tokens[i].length() % 2); + } + return embedding; + } + + public float[] embedding(String input) { + return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray(); + } + }; +} From 3a506c5d6af7e213a1a60be795e6a12342e07e02 Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 16 Jul 2025 18:29:50 +0200 Subject: [PATCH 12/31] Fix error. --- .../elasticsearch/xpack/esql/expression/ExpressionWritables.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index fedae9ec54d9e..15eff01ad96a1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.expression; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; From 5e556640b740a0ceac76ecd560dda981ec7d1070 Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 16 Jul 2025 21:01:31 +0200 Subject: [PATCH 13/31] Lint --- .../xpack/esql/core/tree/NodeInfo.java | 126 +++++++++--------- .../TextEmbeddingFunctionEvaluator.java | 22 ++- 2 files changed, 82 insertions(+), 66 deletions(-) diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java index bea432cfb0d13..bff5f8293d44b 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java @@ -256,12 +256,12 @@ protected void innerTransform(BiConsumer> rule, A newProps -> innerProperties().equals(newProps) ? 
node : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3) - ) + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3) + ) ) ); } @@ -317,13 +317,13 @@ protected void innerTransform(BiConsumer> rule, A newProps -> innerProperties().equals(newProps) ? node : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4) - ) + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4) + ) ) ); } @@ -383,14 +383,14 @@ protected void innerTransform(BiConsumer> rule, A newProps -> innerProperties().equals(newProps) ? node : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4), - (P6) newProps.get(5) - ) + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4), + (P6) newProps.get(5) + ) ) ); } @@ -454,15 +454,15 @@ protected void innerTransform(BiConsumer> rule, A newProps -> innerProperties().equals(newProps) ? node : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4), - (P6) newProps.get(5), - (P7) newProps.get(6) - ) + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4), + (P6) newProps.get(5), + (P7) newProps.get(6) + ) ) ); } @@ -530,16 +530,16 @@ protected void innerTransform(BiConsumer> rule, A newProps -> innerProperties().equals(newProps) ? node : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4), - (P6) newProps.get(5), - (P7) newProps.get(6), - (P8) newProps.get(7) - ) + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4), + (P6) newProps.get(5), + (P7) newProps.get(6), + (P8) newProps.get(7) + ) ) ); } @@ -611,17 +611,17 @@ protected void innerTransform(BiConsumer> rule, A newProps -> innerProperties().equals(newProps) ? node : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4), - (P6) newProps.get(5), - (P7) newProps.get(6), - (P8) newProps.get(7), - (P9) newProps.get(8) - ) + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4), + (P6) newProps.get(5), + (P7) newProps.get(6), + (P8) newProps.get(7), + (P9) newProps.get(8) + ) ) ); } @@ -697,18 +697,18 @@ protected void innerTransform(BiConsumer> rule, A newProps -> innerProperties().equals(newProps) ? 
node : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4), - (P6) newProps.get(5), - (P7) newProps.get(6), - (P8) newProps.get(7), - (P9) newProps.get(8), - (P10) newProps.get(9) - ) + node.source(), + (P1) newProps.get(0), + (P2) newProps.get(1), + (P3) newProps.get(2), + (P4) newProps.get(3), + (P5) newProps.get(4), + (P6) newProps.get(5), + (P7) newProps.get(6), + (P8) newProps.get(7), + (P9) newProps.get(8), + (P10) newProps.get(9) + ) ) ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java index 474ffb87e3b4f..a7a054fc81c15 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -22,6 +23,7 @@ import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import java.util.ArrayList; import java.util.List; public class TextEmbeddingFunctionEvaluator implements InferenceFunctionEvaluator { @@ -51,13 +53,27 @@ private InferenceAction.Request inferenceRequest(String inferenceId, String inpu } private Literal parseInferenceResponse(InferenceAction.Response response) { - float[] embeddingValues = switch (response.getResults()) { + if (response.getResults() instanceof TextEmbeddingResults textEmbeddingResults) { + return parseInferenceResponse(textEmbeddingResults); + } + throw new IllegalArgumentException("Inference response should be of type TextEmbeddingResults"); + } + + private Literal parseInferenceResponse(TextEmbeddingResults result) { + List embeddingList = new ArrayList<>(result.getFirstEmbeddingSize()); + for (float value : getEmbeddingValues(result)) { + embeddingList.add(value); + } + + return new Literal(f.source(), embeddingList, DataType.DENSE_VECTOR); + } + + private float[] getEmbeddingValues(TextEmbeddingResults result) { + return switch (result) { case TextEmbeddingFloatResults floatEmbeddingResults -> floatEmbeddingResults.embeddings().get(0).values(); case TextEmbeddingByteResults bytesEmbeddingResults -> bytesEmbeddingResults.embeddings().get(0).toFloatArray(); case TextEmbeddingBitResults bitsEmbeddingResults -> bitsEmbeddingResults.embeddings().get(0).toFloatArray(); default -> throw new IllegalArgumentException("Inference response should be of type TextEmbeddingResults"); }; - - return new Literal(f.source(), embeddingValues, DataType.DENSE_VECTOR); } } From 13360990462c4ce9bef5a5d59b60bd5edd9ab958 Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 16 Jul 2025 21:03:18 +0200 Subject: [PATCH 14/31] Introduce a new PRE_OPTIMIZED to the LogicalPlan --- 
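A minimal caller sketch for the new stage (hypothetical variable names; assumes the existing LogicalPlan.Stage enum gains a PRE_OPTIMIZED constant ordered between ANALYZED and OPTIMIZED, which is what the ordinal checks in this change rely on):

    // The plan must already be analyzed; preOptimize(...) rejects anything else.
    preOptimizer.preOptimize(analyzedPlan, listener.map(preOptimizedPlan -> {
        // The listener wrapper inside preOptimize(...) has advanced the stage.
        assert preOptimizedPlan.preOptimized();
        // Regular logical optimization is expected to run afterwards.
        return logicalPlanOptimizer.optimize(preOptimizedPlan);
    }));
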
.../xpack/esql/optimizer/PreOptimizer.java | 9 ++++++++- .../xpack/esql/plan/logical/LogicalPlan.java | 8 ++++++++ .../esql/optimizer/PreOptimizerTests.java | 18 ++++++++++++++++-- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java index 6d40d93ad2bf6..42cbfcf3071aa 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java @@ -36,7 +36,14 @@ public PreOptimizer(TransportActionServices services, FoldContext foldContext) { } public void preOptimize(LogicalPlan plan, ActionListener listener) { - inferencePreOptimizer.foldInferenceFunctions(plan, listener); + if (plan.analyzed() == false) { + throw new IllegalStateException("Expected analyzed plan"); + } + + inferencePreOptimizer.foldInferenceFunctions(plan, listener.safeMap(p -> { + p.setPreOptimized(); + return p; + })); } private static class InferencePreOptimizer { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java index 762b22389ae24..f2dd34240ffef 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java @@ -65,6 +65,14 @@ public boolean optimized() { return stage.ordinal() >= Stage.OPTIMIZED.ordinal(); } + public void setPreOptimized() { + stage = Stage.PRE_OPTIMIZED; + } + + public boolean preOptimized() { + return stage.ordinal() >= Stage.PRE_OPTIMIZED.ordinal(); + } + public void setOptimized() { stage = Stage.OPTIMIZED; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java index 9cd5bd324da4a..7483855fc3575 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; +import java.util.ArrayList; import java.util.List; import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; @@ -76,6 +77,7 @@ private void testEvalFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel relation, List.of(new Alias(Source.EMPTY, fieldName, new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)))) ); + eval.setAnalyzed(); SetOnce preOptimizedPlanHolder = new SetOnce<>(); preOptimizer.preOptimize(eval, ActionListener.wrap(preOptimizedPlanHolder::set, ESTestCase::fail)); @@ -83,11 +85,12 @@ private void testEvalFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel assertBusy(() -> { assertThat(preOptimizedPlanHolder.get(), notNullValue()); Eval preOptimizedEval = as(preOptimizedPlanHolder.get(), Eval.class); + assertThat(preOptimizedEval.preOptimized(), equalTo(true)); assertThat(preOptimizedEval.fields(), hasSize(1)); assertThat(preOptimizedEval.fields().get(0).name(), equalTo(fieldName)); Literal preOptimizedQuery = as(preOptimizedEval.fields().get(0).child(), Literal.class); 
assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); - assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embedding(query))); + assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embeddingList(query))); }); } @@ -102,6 +105,7 @@ private void testKnnFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) relation, new Knn(Source.EMPTY, getFieldAttribute("a"), new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)), of(10), null) ); + filter.setAnalyzed(); Knn knn = as(filter.condition(), Knn.class); SetOnce preOptimizedHolder = new SetOnce<>(); @@ -110,6 +114,7 @@ private void testKnnFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) assertBusy(() -> { assertThat(preOptimizedHolder.get(), notNullValue()); Filter preOptimizedFilter = as(preOptimizedHolder.get(), Filter.class); + assertThat(preOptimizedFilter.preOptimized(), equalTo(true)); Knn preOptimizedKnn = as(preOptimizedFilter.condition(), Knn.class); assertThat(preOptimizedKnn.field(), equalTo(knn.field())); assertThat(preOptimizedKnn.k(), equalTo(knn.k())); @@ -117,7 +122,7 @@ private void testKnnFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) Literal preOptimizedQuery = as(preOptimizedKnn.query(), Literal.class); assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); - assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embedding(query))); + assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embeddingList(query))); }); } @@ -141,6 +146,15 @@ private interface TextEmbeddingModelMock { TextEmbeddingResults embeddingResults(String input); float[] embedding(String input); + + default List embeddingList(String input) { + float[] embedding = embedding(input); + List embeddingList = new ArrayList<>(embedding.length); + for (float value : embedding) { + embeddingList.add(value); + } + return embeddingList; + } } private static final TextEmbeddingModelMock FLOAT_EMBEDDING_MODEL = new TextEmbeddingModelMock() { From 9c619b9ac10a2b1146cd818a99c7d1c3e73b8fab Mon Sep 17 00:00:00 2001 From: afoucret Date: Wed, 16 Jul 2025 21:03:28 +0200 Subject: [PATCH 15/31] Add basic csv tests. 
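CsvTestsDataLoader now provisions (and tears down) a "test_dense_inference" endpoint next to the existing sparse, rerank and completion ones, backed by the text_embedding_test_service test service with 10 dimensions; this amounts to a PUT _inference/text_embedding/test_dense_inference request with the service settings shown in the diff. The new text-embedding.csv-spec then checks that TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference") evaluates to a 10-dimensional dense_vector value.
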
--- .../xpack/esql/CsvTestsDataLoader.java | 23 +++++++++++++++++++ .../main/resources/text-embedding.csv-spec | 15 ++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java index c01141837fae3..976a064045928 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java @@ -403,6 +403,10 @@ public static void createInferenceEndpoints(RestClient client) throws IOExceptio createSparseEmbeddingInferenceEndpoint(client); } + if (clusterHasDenseEmbeddingInferenceEndpoint(client) == false) { + createDenseEmbeddingInferenceEndpoint(client); + } + if (clusterHasRerankInferenceEndpoint(client) == false) { createRerankInferenceEndpoint(client); } @@ -414,6 +418,7 @@ public static void createInferenceEndpoints(RestClient client) throws IOExceptio public static void deleteInferenceEndpoints(RestClient client) throws IOException { deleteSparseEmbeddingInferenceEndpoint(client); + deleteDenseEmbeddingInferenceEndpoint(client); deleteRerankInferenceEndpoint(client); deleteCompletionInferenceEndpoint(client); } @@ -437,6 +442,24 @@ public static boolean clusterHasSparseEmbeddingInferenceEndpoint(RestClient clie return clusterHasInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, "test_sparse_inference"); } + public static void createDenseEmbeddingInferenceEndpoint(RestClient client) throws IOException { + createInferenceEndpoint(client, TaskType.TEXT_EMBEDDING, "test_dense_inference", """ + { + "service": "text_embedding_test_service", + "service_settings": { "model": "my_model", "api_key": "abc64", "dimensions": 10 }, + "task_settings": { } + } + """); + } + + public static void deleteDenseEmbeddingInferenceEndpoint(RestClient client) throws IOException { + deleteInferenceEndpoint(client, "test_dense_inference"); + } + + public static boolean clusterHasDenseEmbeddingInferenceEndpoint(RestClient client) throws IOException { + return clusterHasInferenceEndpoint(client, TaskType.TEXT_EMBEDDING, "test_dense_inference"); + } + public static void createRerankInferenceEndpoint(RestClient client) throws IOException { createInferenceEndpoint(client, TaskType.RERANK, "test_reranker", """ { diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec new file mode 100644 index 0000000000000..6a4be805d9219 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec @@ -0,0 +1,15 @@ +// Note: +// The "test_completion" service returns the prompt in uppercase, making the output easy to guess. + + +completion using a ROW source operator +required_capability: text_embedding_function +required_capability: dense_vector_field_type + +ROW input="Who is Victor Hugo?" +| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference") +; + +input:keyword | embedding:dense_vector +Who is Victor Hugo? 
| [56.0, 50.0, 48.0, 50.0, 54.0, 52.0, 49.0, 51.0, 51.0, 56.0] +; From e90971dbf976f5b0561d0252de60ceeb98d3df96 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 08:43:41 +0200 Subject: [PATCH 16/31] Fix forbidden API usage --- .../elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java index 7483855fc3575..20fb66dae8326 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; @@ -180,7 +181,7 @@ public TextEmbeddingResults embeddingResults(String input) { } private byte[] bytes(String input) { - return input.getBytes(); + return input.getBytes(StandardCharsets.UTF_8); } public float[] embedding(String input) { From 567175de89a64bb043f2ce1a7c37c7723e774083 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 10:24:13 +0200 Subject: [PATCH 17/31] Exclude TEXT_EMBEDDING from CsvTests. --- .../src/test/java/org/elasticsearch/xpack/esql/CsvTests.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index d149fb012a14b..88fc7dc669aba 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -307,6 +307,10 @@ public final void test() throws Throwable { "can't use KNN function in csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V3.capabilityName()) ); + assumeFalse( + "can't use TEXT_EMBEDDING function in csv tests", + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.capabilityName()) + ); assumeFalse( "lookup join disabled for csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.JOIN_LOOKUP_V12.capabilityName()) From 0497e7edb6e290ad00d69aa2335989d0b20b0c70 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 10:54:18 +0200 Subject: [PATCH 18/31] Refactored inference folding in PreOptimizer. 
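Instead of rewriting the plan through an asynchronous node-by-node transform, the pre-optimizer now collects every InferenceFunction in the analyzed plan, evaluates each one through its InferenceFunctionEvaluator while a CountDownActionListener tracks completion, and finally substitutes the collected results with a single synchronous transformExpressionsUp pass. Plans without inference functions are returned unchanged, without any inference round-trip.
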
--- .../xpack/esql/optimizer/PreOptimizer.java | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java index 42cbfcf3071aa..c99cee417f6e9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; @@ -16,6 +17,11 @@ import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plugin.TransportActionServices; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + /** * The class is responsible for invoking any steps that need to be applied to the logical plan, * before this is being optimized. @@ -56,7 +62,33 @@ private InferencePreOptimizer(InferenceRunner inferenceRunner, FoldContext foldC } private void foldInferenceFunctions(LogicalPlan plan, ActionListener listener) { - plan.transformExpressionsUp(InferenceFunction.class, this::foldInferenceFunction, listener); + // First let's collect all the inference functions + List> inferenceFunctions = new ArrayList<>(); + plan.forEachExpressionUp(InferenceFunction.class, inferenceFunctions::add); + + if (inferenceFunctions.isEmpty()) { + // No inference functions found. Return the original plan. + listener.onResponse(plan); + return; + } + + // This is a map of inference functions to their results. + // We will use this map to replace the inference functions in the plan. + Map, Expression> inferenceFunctionsToResults = new HashMap<>(); + + // Prepare a listener that will be called when all inference functions are done. + // This listener will replace the inference functions in the plan with their results. + CountDownActionListener completionListener = new CountDownActionListener(inferenceFunctions.size(), listener.map(ignored -> + plan.transformExpressionsUp(InferenceFunction.class, f -> inferenceFunctionsToResults.getOrDefault(f, f)) + )); + + // Try to compute the result for each inference function. + for (InferenceFunction inferenceFunction : inferenceFunctions) { + foldInferenceFunction(inferenceFunction, completionListener.map(e -> { + inferenceFunctionsToResults.put(inferenceFunction, e); + return null; + })); + } } private void foldInferenceFunction(InferenceFunction inferenceFunction, ActionListener listener) { From 35c27a4cc8a56aa98f77ce4df473722701494d93 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 10:54:33 +0200 Subject: [PATCH 19/31] Revert useless node transform changes. 
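With inference folding now handled by collecting the functions up front and substituting their results in one synchronous pass (see the previous commit), the listener-based transform variants on Node, NodeInfo and QueryPlan are no longer needed; they are removed here together with NodeTransformTests and the asynchronous assertions that had been added to QueryPlanTests.
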
--- .../xpack/esql/core/tree/Node.java | 129 +---- .../xpack/esql/core/tree/NodeInfo.java | 236 -------- .../esql/core/tree/NodeTransformTests.java | 516 ------------------ .../xpack/esql/plan/QueryPlan.java | 173 ------ .../xpack/esql/plan/QueryPlanTests.java | 73 --- 5 files changed, 1 insertion(+), 1126 deletions(-) delete mode 100644 x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/tree/NodeTransformTests.java diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java index 190c4c327e099..613f5b0ae76c2 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java @@ -6,8 +6,6 @@ */ package org.elasticsearch.xpack.esql.core.tree; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; @@ -16,8 +14,6 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; @@ -191,45 +187,16 @@ public T transformDown(Function rule) { return node.transformChildren(child -> child.transformDown(rule)); } - @SuppressWarnings("unchecked") - public void transformDown(BiConsumer> rule, ActionListener listener) { - // First apply the rule to the current node (top-down) - rule.accept((T) this, listener.delegateFailureAndWrap((l, transformedNode) -> { - // Then recursively transform the children with the same rule - transformedNode.transformChildren((child, childListener) -> child.transformDown(rule, childListener), l); - })); - } - @SuppressWarnings("unchecked") public T transformDown(Class typeToken, Function rule) { return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t)); } - @SuppressWarnings("unchecked") - public void transformDown(Class typeToken, BiConsumer> rule, ActionListener listener) { - transformDown(typeToken::isInstance, rule, listener); - } - @SuppressWarnings("unchecked") public T transformDown(Predicate> nodePredicate, Function rule) { return transformDown((t) -> (nodePredicate.test(t) ? 
rule.apply((E) t) : t)); } - @SuppressWarnings("unchecked") - public void transformDown( - Predicate> nodePredicate, - BiConsumer> rule, - ActionListener listener - ) { - transformDown((T node, ActionListener l) -> { - if (nodePredicate.test(node)) { - rule.accept((E) node, l); - } else { - l.onResponse(node); - } - }, listener); - } - @SuppressWarnings("unchecked") public T transformUp(Function rule) { T transformed = transformChildren(child -> child.transformUp(rule)); @@ -238,25 +205,8 @@ public T transformUp(Function rule) { } @SuppressWarnings("unchecked") - public void transformUp(BiConsumer> rule, ActionListener listener) { - // First, recursively transform the children (depth-first, bottom-up) using the same async rule - transformChildren( - // traversal operation applied to each child - (child, childListener) -> child.transformUp(rule, childListener), - // After all children are transformed, apply the rule to the (possibly) new current node - listener.delegateFailureAndWrap((l, transformedChildrenNode) -> { - T node = transformedChildrenNode.equals(this) ? (T) this : transformedChildrenNode; - rule.accept(node, l); - }) - ); - } - public T transformUp(Class typeToken, Function rule) { - return transformUp(typeToken::isInstance, rule); - } - - public void transformUp(Class typeToken, BiConsumer> rule, ActionListener listener) { - transformUp(typeToken::isInstance, rule, listener); + return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t)); } @SuppressWarnings("unchecked") @@ -264,22 +214,6 @@ public T transformUp(Predicate> nodePredicate, Function (nodePredicate.test(t) ? rule.apply((E) t) : t)); } - @SuppressWarnings("unchecked") - public void transformUp( - Predicate> nodePredicate, - BiConsumer> rule, - ActionListener listener - ) { - transformUp((T node, ActionListener l) -> { - if (nodePredicate.test(node)) { - E typedNode = (E) node; - rule.accept((E) node, l); - } else { - l.onResponse(node); - } - }, listener); - } - @SuppressWarnings("unchecked") protected > T transformChildren(Function traversalOperation) { boolean childrenChanged = false; @@ -304,35 +238,6 @@ public void transformUp( return (childrenChanged ? replaceChildrenSameSize(transformedChildren) : (T) this); } - @SuppressWarnings("unchecked") - protected void transformChildren(BiConsumer> traversalOperation, ActionListener listener) { - if (children.isEmpty()) { - listener.onResponse((T) this); - return; - } - - final AtomicReference> transformedChildren = new AtomicReference<>(null); - - CountDownActionListener countDownListener = new CountDownActionListener( - children.size(), - listener.delegateFailureIgnoreResponseAndWrap((l) -> { - l.onResponse(transformedChildren.get() != null ? 
replaceChildren(transformedChildren.get()) : (T) this); - }) - ); - - for (int i = 0, s = children.size(); i < s; i++) { - T child = children.get(i); - final int childId = i; - traversalOperation.accept(child, countDownListener.map(next -> { - if (child.equals(next) == false) { - transformedChildren.compareAndSet(null, new ArrayList<>(children)); - transformedChildren.get().set(childId, next); - } - return null; - })); - } - } - public final T replaceChildrenSameSize(List newChildren) { if (newChildren.size() != children.size()) { throw new QlIllegalArgumentException( @@ -352,38 +257,14 @@ public T transformPropertiesOnly(Class typeToken, Function void transformPropertiesOnly( - Class typeToken, - BiConsumer> rule, - ActionListener listener - ) { - transformNodeProps(typeToken, rule, listener); - } - public T transformPropertiesDown(Class typeToken, Function rule) { return transformDown(t -> t.transformNodeProps(typeToken, rule)); } - public void transformPropertiesDown( - Class typeToken, - BiConsumer> rule, - ActionListener listener - ) { - transformDown((t, l) -> t.transformNodeProps(typeToken, rule, l), listener); - } - public T transformPropertiesUp(Class typeToken, Function rule) { return transformUp(t -> t.transformNodeProps(typeToken, rule)); } - public void transformPropertiesUp( - Class typeToken, - BiConsumer> rule, - ActionListener listener - ) { - transformUp((t, l) -> t.transformNodeProps(typeToken, rule, l), listener); - } - /** * Transform this node's properties. *

@@ -396,14 +277,6 @@ protected final T transformNodeProps(Class typeToken, Function void transformNodeProps( - Class typeToken, - BiConsumer> rule, - ActionListener listener - ) { - info().transform(rule, typeToken, listener); - } - /** * Return the information about this node. *

diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java index bff5f8293d44b..28e4e739085d4 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/NodeInfo.java @@ -6,15 +6,9 @@ */ package org.elasticsearch.xpack.esql.core.tree; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.CountDownActionListener; - -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; @@ -66,27 +60,8 @@ final T transform(Function rule, Class typeToken) return innerTransform(realRule); } - public void transform( - BiConsumer> rule, - Class typeToken, - ActionListener listener - ) { - List children = node.children(); - BiConsumer> realRule = (p, l) -> { - if (p != children && (p == null || typeToken.isInstance(p)) && false == children.contains(p)) { - rule.accept(typeToken.cast(p), l); - } else { - l.onResponse(p); - } - }; - - innerTransform(realRule, listener); - } - protected abstract T innerTransform(Function rule); - protected abstract void innerTransform(BiConsumer> rule, ActionListener listener); - /** * Builds a {@link NodeInfo} for Nodes without any properties. */ @@ -100,10 +75,6 @@ protected List innerProperties() { protected T innerTransform(Function rule) { return node; } - - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - listener.onResponse(node); - } }; } @@ -123,16 +94,6 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1); } - - @SuppressWarnings("unchecked") - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - transformProperties( - rule, - listener.safeMap( - newProps -> innerProperties().equals(newProps) ? node : ctor.apply(node.source(), (P1) newProps.get(0)) - ) - ); - } }; } @@ -155,18 +116,6 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2); } - - @SuppressWarnings("unchecked") - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - transformProperties( - rule, - listener.safeMap( - newProps -> innerProperties().equals(newProps) - ? node - : ctor.apply(node.source(), (P1) newProps.get(0), (P2) newProps.get(1)) - ) - ); - } }; } @@ -196,18 +145,6 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3); } - - @SuppressWarnings("unchecked") - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - transformProperties( - rule, - listener.safeMap( - newProps -> innerProperties().equals(newProps) - ? node - : ctor.apply(node.source(), (P1) newProps.get(0), (P2) newProps.get(1), (P3) newProps.get(2)) - ) - ); - } }; } @@ -247,24 +184,6 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4); } - - @SuppressWarnings("unchecked") - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - transformProperties( - rule, - listener.safeMap( - newProps -> innerProperties().equals(newProps) - ? 
node - : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3) - ) - ) - ); - } }; } @@ -308,25 +227,6 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5); } - - @SuppressWarnings("unchecked") - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - transformProperties( - rule, - listener.safeMap( - newProps -> innerProperties().equals(newProps) - ? node - : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4) - ) - ) - ); - } }; } @@ -374,26 +274,6 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6); } - - @SuppressWarnings("unchecked") - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - transformProperties( - rule, - listener.safeMap( - newProps -> innerProperties().equals(newProps) - ? node - : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4), - (P6) newProps.get(5) - ) - ) - ); - } }; } @@ -445,27 +325,6 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6, newP7); } - - @SuppressWarnings("unchecked") - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - transformProperties( - rule, - listener.safeMap( - newProps -> innerProperties().equals(newProps) - ? node - : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4), - (P6) newProps.get(5), - (P7) newProps.get(6) - ) - ) - ); - } }; } @@ -521,28 +380,6 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6, newP7, newP8); } - - @SuppressWarnings("unchecked") - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - transformProperties( - rule, - listener.safeMap( - newProps -> innerProperties().equals(newProps) - ? node - : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4), - (P6) newProps.get(5), - (P7) newProps.get(6), - (P8) newProps.get(7) - ) - ) - ); - } }; } @@ -602,29 +439,6 @@ protected T innerTransform(Function rule) { return same ? node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6, newP7, newP8, newP9); } - - @SuppressWarnings("unchecked") - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - transformProperties( - rule, - listener.safeMap( - newProps -> innerProperties().equals(newProps) - ? node - : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4), - (P6) newProps.get(5), - (P7) newProps.get(6), - (P8) newProps.get(7), - (P9) newProps.get(8) - ) - ) - ); - } }; } @@ -688,60 +502,10 @@ protected T innerTransform(Function rule) { return same ? 
node : ctor.apply(node.source(), newP1, newP2, newP3, newP4, newP5, newP6, newP7, newP8, newP9, newP10); } - - @SuppressWarnings("unchecked") - protected void innerTransform(BiConsumer> rule, ActionListener listener) { - transformProperties( - rule, - listener.safeMap( - newProps -> innerProperties().equals(newProps) - ? node - : ctor.apply( - node.source(), - (P1) newProps.get(0), - (P2) newProps.get(1), - (P3) newProps.get(2), - (P4) newProps.get(3), - (P5) newProps.get(4), - (P6) newProps.get(5), - (P7) newProps.get(6), - (P8) newProps.get(7), - (P9) newProps.get(8), - (P10) newProps.get(9) - ) - ) - ); - } }; } public interface NodeCtor10 { T apply(Source l, P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6, P7 p7, P8 p8, P9 p9, P10 p10); } - - protected void transformProperties(BiConsumer> rule, ActionListener> listener) { - List properties = innerProperties(); - if (properties.isEmpty()) { - listener.onResponse(properties); - return; - } - - AtomicReference> transformedProperties = new AtomicReference<>(null); - CountDownActionListener completionListener = new CountDownActionListener(properties.size(), ActionListener.wrap(ignored -> { - List result = transformedProperties.get() != null ? transformedProperties.get() : properties; - listener.onResponse(result); - }, listener::onFailure)); - - for (int i = 0, size = properties.size(); i < size; i++) { - final int currentIndex = i; - Object property = properties.get(currentIndex); - rule.accept(property, completionListener.delegateFailureAndWrap((l, transformed) -> { - if (properties.get(currentIndex).equals(transformed) == false) { - transformedProperties.compareAndSet(null, new ArrayList<>(properties)); - transformedProperties.get().set(currentIndex, transformed); - } - l.onResponse(null); - })); - } - } } diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/tree/NodeTransformTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/tree/NodeTransformTests.java deleted file mode 100644 index 6f22009ee39bd..0000000000000 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/tree/NodeTransformTests.java +++ /dev/null @@ -1,516 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.esql.core.tree; - -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.test.ESTestCase; - -import java.util.List; -import java.util.function.Function; -import java.util.function.Predicate; - -import static org.elasticsearch.xpack.esql.core.tree.SourceTests.randomSource; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; - -public class NodeTransformTests extends ESTestCase { - // Transform Up Tests - public void testTransformUpSimpleLeafTransformation() throws Exception { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - Function transformer = createLeafTransformer(); - - NodeTests.Dummy result = tree.transformUp(transformer); - - assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); - NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; - assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); - assertThat(transformed.children().get(1).thing(), equalTo("leaf2_transformed")); - - // Verify async version matches - assertAsyncTransformMatches(tree, transformer, result); - } - - public void testTransformUpWithTypeToken() throws Exception { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - Function transformer = n -> new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); - - NodeTests.Dummy result = tree.transformUp(NodeTests.NoChildren.class, transformer); - - assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); - NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; - assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); - assertThat(transformed.children().get(1).thing(), equalTo("leaf2_transformed")); - - // Verify async version matches - SetOnce asyncResult = new SetOnce<>(); - tree.transformUp( - NodeTests.NoChildren.class, - (n, listener) -> listener.onResponse(transformer.apply(n)), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); - } - - public void testTransformUpWithPredicate() throws Exception { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - Predicate> predicate = n -> n instanceof NodeTests.NoChildren && ((NodeTests.NoChildren) n).thing().equals("leaf1"); - Function transformer = n -> new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); - - NodeTests.Dummy result = tree.transformUp(predicate, transformer); - - assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); - NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; - assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); - assertThat(transformed.children().get(1).thing(), equalTo("leaf2")); // Not transformed - - // Verify async version matches - SetOnce asyncResult = new SetOnce<>(); - tree.transformUp( - predicate, - (n, listener) -> listener.onResponse(transformer.apply(n)), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); - } - - public void testTransformUpErrorHandling() { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - - RuntimeException e = expectThrows( - RuntimeException.class, - () -> tree.transformUp(n -> { throw new RuntimeException("test error"); }) - ); - assertThat(e.getMessage(), equalTo("test error")); - } - - 
public void testTransformUpAsyncErrorHandling() throws Exception { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - SetOnce exception = new SetOnce<>(); - - tree.transformUp( - (n, l) -> l.onFailure(new RuntimeException("test error")), - ActionListener.wrap(r -> fail("should not be called"), e -> exception.set(asInstanceOf(RuntimeException.class, e))) - ); - - assertBusy(() -> assertThat(exception.get().getMessage(), equalTo("test error"))); - } - - public void testTransformUpNestedStructures() throws Exception { - NodeTests.ChildrenAreAProperty tree = createNestedTree(); - Function transformer = createAllNodesTransformer(); - - NodeTests.Dummy result = tree.transformUp(transformer); - - assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); - NodeTests.ChildrenAreAProperty transformedOuter = (NodeTests.ChildrenAreAProperty) result; - assertThat(transformedOuter.thing(), equalTo("outer_transformed")); - - NodeTests.Dummy innerResult = transformedOuter.children().get(0); - assertThat(innerResult, instanceOf(NodeTests.ChildrenAreAProperty.class)); - NodeTests.ChildrenAreAProperty transformedInner = (NodeTests.ChildrenAreAProperty) innerResult; - assertThat(transformedInner.thing(), equalTo("inner_transformed")); - assertThat(transformedInner.children().get(0).thing(), equalTo("leaf1_transformed")); - assertThat(transformedInner.children().get(1).thing(), equalTo("leaf2_transformed")); - - // Verify async version matches - assertAsyncTransformMatches(tree, transformer, result); - } - - // Transform Down Tests - public void testTransformDownSimpleLeafTransformation() throws Exception { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - Function transformer = createLeafTransformer(); - - NodeTests.Dummy result = tree.transformDown(transformer); - - assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); - NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; - assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); - assertThat(transformed.children().get(1).thing(), equalTo("leaf2_transformed")); - - // Verify async version matches - assertAsyncTransformDownMatches(tree, transformer, result); - } - - public void testTransformDownWithTypeToken() throws Exception { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - Function transformer = n -> new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); - - NodeTests.Dummy result = tree.transformDown(NodeTests.NoChildren.class, transformer); - - assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); - NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; - assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); - assertThat(transformed.children().get(1).thing(), equalTo("leaf2_transformed")); - - // Verify async version matches - SetOnce asyncResult = new SetOnce<>(); - tree.transformDown( - NodeTests.NoChildren.class, - (n, listener) -> listener.onResponse(transformer.apply(n)), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); - } - - public void testTransformDownWithPredicate() throws Exception { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - Predicate> predicate = n -> n instanceof NodeTests.NoChildren && ((NodeTests.NoChildren) n).thing().equals("leaf1"); - Function transformer = n -> new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); - - 
NodeTests.Dummy result = tree.transformDown(predicate, transformer); - - assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); - NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; - assertThat(transformed.children().get(0).thing(), equalTo("leaf1_transformed")); - assertThat(transformed.children().get(1).thing(), equalTo("leaf2")); // Not transformed - - // Verify async version matches - SetOnce asyncResult = new SetOnce<>(); - tree.transformDown( - predicate, - (n, listener) -> listener.onResponse(transformer.apply(n)), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); - } - - public void testTransformDownErrorHandling() { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - - RuntimeException e = expectThrows( - RuntimeException.class, - () -> tree.transformDown(n -> { throw new RuntimeException("test error"); }) - ); - assertThat(e.getMessage(), equalTo("test error")); - } - - public void testTransformDownAsyncErrorHandling() throws Exception { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - SetOnce exception = new SetOnce<>(); - - tree.transformDown((n, listener) -> { - if (n instanceof NodeTests.NoChildren) { - listener.onFailure(new RuntimeException("test error")); - } else { - listener.onResponse(n); - } - }, ActionListener.wrap(r -> fail("should not be called"), e -> exception.set(asInstanceOf(RuntimeException.class, e)))); - - assertBusy(() -> { - assertNotNull(exception.get()); - assertThat(exception.get().getMessage(), equalTo("test error")); - }); - } - - public void testTransformDownNestedStructures() throws Exception { - NodeTests.ChildrenAreAProperty tree = createNestedTree(); - Function transformer = createAllNodesTransformer(); - - NodeTests.Dummy result = tree.transformDown(transformer); - - assertThat(result, instanceOf(NodeTests.ChildrenAreAProperty.class)); - NodeTests.ChildrenAreAProperty transformedOuter = (NodeTests.ChildrenAreAProperty) result; - assertThat(transformedOuter.thing(), equalTo("outer_transformed")); - - NodeTests.Dummy innerResult = transformedOuter.children().get(0); - assertThat(innerResult, instanceOf(NodeTests.ChildrenAreAProperty.class)); - NodeTests.ChildrenAreAProperty transformedInner = (NodeTests.ChildrenAreAProperty) innerResult; - assertThat(transformedInner.thing(), equalTo("inner_transformed")); - assertThat(transformedInner.children().get(0).thing(), equalTo("leaf1_transformed")); - assertThat(transformedInner.children().get(1).thing(), equalTo("leaf2_transformed")); - - // Verify async version matches - assertAsyncTransformDownMatches(tree, transformer, result); - } - - public void testTransformPropertiesOnly() throws Exception { - NodeTests.ChildrenAreAProperty tree = createSimpleTree(); - - // Sync transformation: change the root mode property only - NodeTests.Dummy result = tree.transformPropertiesOnly(String.class, s -> s + "_changed"); - NodeTests.ChildrenAreAProperty transformed = (NodeTests.ChildrenAreAProperty) result; - assertThat(transformed.thing(), equalTo("node_changed")); - // Children should remain unchanged because transformPropertiesOnly does not traverse - assertThat(transformed.children().get(0).thing(), equalTo("leaf1")); - assertThat(transformed.children().get(1).thing(), equalTo("leaf2")); - - // Async variant should yield identical result - SetOnce asyncResult = new SetOnce<>(); - tree.transformPropertiesOnly( - String.class, - (s, l) -> l.onResponse(s + 
"_changed"), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); - } - - public void testTransformPropertiesDown() throws Exception { - NodeTests.ChildrenAreAProperty tree = createNestedTree(); - - // Sync transformation: propagate "_changed" to every "thing" property top-down - NodeTests.Dummy result = tree.transformPropertiesDown(String.class, s -> s + "_changed"); - // Root, inner, and leaves should all have suffix - NodeTests.ChildrenAreAProperty outer = (NodeTests.ChildrenAreAProperty) result; - assertThat(outer.thing(), equalTo("outer_changed")); - NodeTests.ChildrenAreAProperty inner = (NodeTests.ChildrenAreAProperty) outer.children().get(0); - assertThat(inner.thing(), equalTo("inner_changed")); - assertThat(inner.children().get(0).thing(), equalTo("leaf1_changed")); - assertThat(inner.children().get(1).thing(), equalTo("leaf2_changed")); - - // Async variant should yield identical result - SetOnce asyncResult = new SetOnce<>(); - tree.transformPropertiesDown( - String.class, - (s, l) -> l.onResponse(s + "_changed"), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); - } - - public void testTransformPropertiesUp() throws Exception { - NodeTests.ChildrenAreAProperty tree = createNestedTree(); - - // Sync transformation: propagate "_changed" to every "thing" property bottom-up - NodeTests.Dummy result = tree.transformPropertiesUp(String.class, s -> s + "_changed"); - NodeTests.ChildrenAreAProperty outer = (NodeTests.ChildrenAreAProperty) result; - assertThat(outer.thing(), equalTo("outer_changed")); - NodeTests.ChildrenAreAProperty inner = (NodeTests.ChildrenAreAProperty) outer.children().get(0); - assertThat(inner.thing(), equalTo("inner_changed")); - assertThat(inner.children().get(0).thing(), equalTo("leaf1_changed")); - assertThat(inner.children().get(1).thing(), equalTo("leaf2_changed")); - - // Async variant should yield identical result - SetOnce asyncResult = new SetOnce<>(); - tree.transformPropertiesUp( - String.class, - (s, l) -> l.onResponse(s + "_changed"), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(result))); - } - - // Tests demonstrating behavioral differences between transformUp and transformDown - public void testTransformUpVsDownOrderDependentTransformation() { - NodeTests.NoChildren leaf1 = new NodeTests.NoChildren(randomSource(), "leaf"); - NodeTests.NoChildren leaf2 = new NodeTests.NoChildren(randomSource(), "leaf"); - NodeTests.ChildrenAreAProperty innerNode = new NodeTests.ChildrenAreAProperty(randomSource(), List.of(leaf1, leaf2), "inner"); - NodeTests.ChildrenAreAProperty outerNode = new NodeTests.ChildrenAreAProperty(randomSource(), List.of(innerNode), "outer"); - - Function transformerDown = n -> { - if (n instanceof NodeTests.ChildrenAreAProperty) { - NodeTests.ChildrenAreAProperty cn = (NodeTests.ChildrenAreAProperty) n; - return new NodeTests.ChildrenAreAProperty(cn.source(), cn.children(), cn.thing() + "_DOWN"); - } - return n; - }; - - Function transformerUp = n -> { - if (n instanceof NodeTests.ChildrenAreAProperty) { - NodeTests.ChildrenAreAProperty cn = (NodeTests.ChildrenAreAProperty) n; - return new NodeTests.ChildrenAreAProperty(cn.source(), cn.children(), cn.thing() + "_UP"); - } - return n; - }; - - // Transform down: parent first, then children - NodeTests.Dummy resultDown = 
outerNode.transformDown(transformerDown); - NodeTests.ChildrenAreAProperty outerDown = (NodeTests.ChildrenAreAProperty) resultDown; - NodeTests.ChildrenAreAProperty innerDown = (NodeTests.ChildrenAreAProperty) outerDown.children().get(0); - - // Transform up: children first, then parent - NodeTests.Dummy resultUp = outerNode.transformUp(transformerUp); - NodeTests.ChildrenAreAProperty outerUp = (NodeTests.ChildrenAreAProperty) resultUp; - NodeTests.ChildrenAreAProperty innerUp = (NodeTests.ChildrenAreAProperty) outerUp.children().get(0); - - // Verify transformation order is reflected in results - assertThat(outerDown.thing(), equalTo("outer_DOWN")); - assertThat(innerDown.thing(), equalTo("inner_DOWN")); - assertThat(outerUp.thing(), equalTo("outer_UP")); - assertThat(innerUp.thing(), equalTo("inner_UP")); - } - - public void testTransformUpVsDownChildDependentLogic() { - NodeTests.NoChildren leaf1 = new NodeTests.NoChildren(randomSource(), "A"); - NodeTests.NoChildren leaf2 = new NodeTests.NoChildren(randomSource(), "B"); - NodeTests.ChildrenAreAProperty node = new NodeTests.ChildrenAreAProperty(randomSource(), List.of(leaf1, leaf2), "parent"); - - // Transformer that changes parent based on children's current state - Function transformer = n -> { - if (n instanceof NodeTests.ChildrenAreAProperty) { - NodeTests.ChildrenAreAProperty cn = (NodeTests.ChildrenAreAProperty) n; - // Count how many children have "transformed" in their name - long transformedChildrenCount = cn.children().stream().filter(child -> child.thing().contains("transformed")).count(); - return new NodeTests.ChildrenAreAProperty( - cn.source(), - cn.children(), - cn.thing() + "_sees_" + transformedChildrenCount + "_transformed_children" - ); - } else if (n instanceof NodeTests.NoChildren) { - return new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); - } - return n; - }; - - // Transform down: parent sees children in original state - NodeTests.Dummy resultDown = node.transformDown(transformer); - NodeTests.ChildrenAreAProperty parentDown = (NodeTests.ChildrenAreAProperty) resultDown; - - // Transform up: parent sees children after they've been transformed - NodeTests.Dummy resultUp = node.transformUp(transformer); - NodeTests.ChildrenAreAProperty parentUp = (NodeTests.ChildrenAreAProperty) resultUp; - - // Key difference: transformDown parent sees 0 transformed children, - // transformUp parent sees 2 transformed children - assertThat(parentDown.thing(), equalTo("parent_sees_0_transformed_children")); - assertThat(parentUp.thing(), equalTo("parent_sees_2_transformed_children")); - - // Both should have transformed children - assertThat(parentDown.children().get(0).thing(), equalTo("A_transformed")); - assertThat(parentDown.children().get(1).thing(), equalTo("B_transformed")); - assertThat(parentUp.children().get(0).thing(), equalTo("A_transformed")); - assertThat(parentUp.children().get(1).thing(), equalTo("B_transformed")); - } - - public void testTransformUpVsDownConditionalTransformation() { - NodeTests.NoChildren leaf1 = new NodeTests.NoChildren(randomSource(), "child1"); - NodeTests.NoChildren leaf2 = new NodeTests.NoChildren(randomSource(), "child2"); - NodeTests.ChildrenAreAProperty node = new NodeTests.ChildrenAreAProperty(randomSource(), List.of(leaf1, leaf2), "STOP"); - - // Transformer that stops transformation if parent has "STOP" in name - Function transformer = n -> { - if (n instanceof NodeTests.ChildrenAreAProperty) { - NodeTests.ChildrenAreAProperty cn = (NodeTests.ChildrenAreAProperty) n; - 
if (cn.thing().contains("STOP")) { - // Return node unchanged - return cn; - } else { - return new NodeTests.ChildrenAreAProperty(cn.source(), cn.children(), cn.thing() + "_processed"); - } - } else if (n instanceof NodeTests.NoChildren) { - return new NodeTests.NoChildren(n.source(), n.thing() + "_transformed"); - } - return n; - }; - - NodeTests.Dummy resultDown = node.transformDown(transformer); - NodeTests.ChildrenAreAProperty parentDown = (NodeTests.ChildrenAreAProperty) resultDown; - - NodeTests.Dummy resultUp = node.transformUp(transformer); - NodeTests.ChildrenAreAProperty parentUp = (NodeTests.ChildrenAreAProperty) resultUp; - - // Both parents should remain unchanged (contain "STOP") - assertThat(parentDown.thing(), equalTo("STOP")); - assertThat(parentUp.thing(), equalTo("STOP")); - - // Both should have transformed children - assertThat(parentDown.children().get(0).thing(), equalTo("child1_transformed")); - assertThat(parentUp.children().get(0).thing(), equalTo("child1_transformed")); - } - - public void testTransformUpVsDownAccumulativeChanges() { - NodeTests.NoChildren leaf = new NodeTests.NoChildren(randomSource(), "0"); - NodeTests.AChildIsAProperty innerNode = new NodeTests.AChildIsAProperty(randomSource(), leaf, "0"); - NodeTests.AChildIsAProperty outerNode = new NodeTests.AChildIsAProperty(randomSource(), innerNode, "0"); - - // Transformer that increments numeric values - Function transformer = n -> { - try { - int currentValue = Integer.parseInt(n.thing()); - String newValue = String.valueOf(currentValue + 1); - - if (n instanceof NodeTests.NoChildren) { - return new NodeTests.NoChildren(n.source(), newValue); - } else if (n instanceof NodeTests.AChildIsAProperty) { - NodeTests.AChildIsAProperty an = (NodeTests.AChildIsAProperty) n; - return new NodeTests.AChildIsAProperty(an.source(), an.child(), newValue); - } - } catch (NumberFormatException e) { - // If not a number, leave unchanged - } - return n; - }; - - NodeTests.Dummy resultDown = outerNode.transformDown(transformer); - NodeTests.Dummy resultUp = outerNode.transformUp(transformer); - - // Extract the final values - NodeTests.AChildIsAProperty outerDown = (NodeTests.AChildIsAProperty) resultDown; - NodeTests.AChildIsAProperty innerDown = (NodeTests.AChildIsAProperty) outerDown.child(); - NodeTests.NoChildren leafDown = (NodeTests.NoChildren) innerDown.child(); - - NodeTests.AChildIsAProperty outerUp = (NodeTests.AChildIsAProperty) resultUp; - NodeTests.AChildIsAProperty innerUp = (NodeTests.AChildIsAProperty) outerUp.child(); - NodeTests.NoChildren leafUp = (NodeTests.NoChildren) innerUp.child(); - - // All nodes should be incremented to "1" - assertThat(leafDown.thing(), equalTo("1")); - assertThat(leafUp.thing(), equalTo("1")); - assertThat(innerDown.thing(), equalTo("1")); - assertThat(innerUp.thing(), equalTo("1")); - assertThat(outerDown.thing(), equalTo("1")); - assertThat(outerUp.thing(), equalTo("1")); - } - - // Helper methods for transform tests - private NodeTests.ChildrenAreAProperty createSimpleTree() { - NodeTests.NoChildren leaf1 = new NodeTests.NoChildren(randomSource(), "leaf1"); - NodeTests.NoChildren leaf2 = new NodeTests.NoChildren(randomSource(), "leaf2"); - return new NodeTests.ChildrenAreAProperty(randomSource(), List.of(leaf1, leaf2), "node"); - } - - private NodeTests.ChildrenAreAProperty createNestedTree() { - NodeTests.NoChildren leaf1 = new NodeTests.NoChildren(randomSource(), "leaf1"); - NodeTests.NoChildren leaf2 = new NodeTests.NoChildren(randomSource(), "leaf2"); - 
NodeTests.ChildrenAreAProperty innerNode = new NodeTests.ChildrenAreAProperty(randomSource(), List.of(leaf1, leaf2), "inner"); - return new NodeTests.ChildrenAreAProperty(randomSource(), List.of(innerNode), "outer"); - } - - private Function createLeafTransformer() { - return n -> n instanceof NodeTests.NoChildren ? new NodeTests.NoChildren(n.source(), n.thing() + "_transformed") : n; - } - - private Function createAllNodesTransformer() { - return n -> { - if (n instanceof NodeTests.NoChildren) { - return new NodeTests.NoChildren(n.source(), ((NodeTests.NoChildren) n).thing() + "_transformed"); - } else if (n instanceof NodeTests.ChildrenAreAProperty) { - NodeTests.ChildrenAreAProperty cn = (NodeTests.ChildrenAreAProperty) n; - return new NodeTests.ChildrenAreAProperty(cn.source(), cn.children(), cn.thing() + "_transformed"); - } - return n; - }; - } - - private void assertAsyncTransformMatches( - NodeTests.Dummy node, - Function transformer, - NodeTests.Dummy expectedResult - ) throws Exception { - SetOnce asyncResult = new SetOnce<>(); - ((Node) node).transformUp( - (n, listener) -> listener.onResponse(transformer.apply(n)), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(expectedResult))); - } - - private void assertAsyncTransformDownMatches( - NodeTests.Dummy node, - Function transformer, - NodeTests.Dummy expectedResult - ) throws Exception { - SetOnce asyncResult = new SetOnce<>(); - ((Node) node).transformDown( - (n, listener) -> listener.onResponse(transformer.apply(n)), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(expectedResult))); - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java index 0e127776c30e6..81a89950b0a02 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java @@ -6,8 +6,6 @@ */ package org.elasticsearch.xpack.esql.plan; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -18,8 +16,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; @@ -106,78 +102,18 @@ public PlanType transformExpressionsOnly(Function doTransformExpression(e, exp -> exp.transformDown(rule))); } - public void transformExpressionsOnly(BiConsumer> rule, ActionListener listener) { - transformPropertiesOnly( - Object.class, - (prop, propListener) -> doTransformExpression( - prop, - (expr, exprListener) -> expr.transformDown(rule, exprListener), - propListener - ), - listener - ); - } - public PlanType transformExpressionsOnly(Class typeToken, Function rule) { return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule))); } - public void transformExpressionsOnly( - Class typeToken, - BiConsumer> rule, - ActionListener listener - ) { - transformPropertiesOnly( - 
Object.class, - (prop, propListener) -> doTransformExpression( - prop, - (expr, exprListener) -> expr.transformDown(typeToken, rule, exprListener), - propListener - ), - listener - ); - } - public PlanType transformExpressionsOnlyUp(Class typeToken, Function rule) { return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule))); } - public void transformExpressionsOnlyUp( - Class typeToken, - BiConsumer> rule, - ActionListener listener - ) { - transformPropertiesOnly( - Object.class, - (prop, propListener) -> doTransformExpression( - prop, - (expr, exprListener) -> expr.transformUp(typeToken, rule, exprListener), - propListener - ), - listener - ); - } - public PlanType transformExpressionsDown(Class typeToken, Function rule) { return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule))); } - public void transformExpressionsDown( - Class typeToken, - BiConsumer> rule, - ActionListener listener - ) { - transformPropertiesDown( - Object.class, - (prop, propListener) -> doTransformExpression( - prop, - (expr, exprListener) -> expr.transformDown(typeToken, rule, exprListener), - propListener - ), - listener - ); - } - public PlanType transformExpressionsDown( Predicate> shouldVisit, Class typeToken, @@ -189,47 +125,10 @@ public PlanType transformExpressionsDown( ); } - public void transformExpressionsDown( - Predicate> shouldVisit, - Class typeToken, - BiConsumer> rule, - ActionListener listener - ) { - transformDown( - shouldVisit, - (plan, planListener) -> plan.transformNodeProps( - Object.class, - (prop, propListener) -> doTransformExpression( - prop, - (expr, exprListener) -> expr.transformDown(typeToken, rule, exprListener), - propListener - ), - planListener - ), - listener - ); - } - public PlanType transformExpressionsUp(Class typeToken, Function rule) { return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule))); } - public void transformExpressionsUp( - Class typeToken, - BiConsumer> rule, - ActionListener listener - ) { - transformPropertiesUp( - Object.class, - (prop, propListener) -> doTransformExpression( - prop, - (expr, exprListener) -> expr.transformUp(typeToken, rule, exprListener), - propListener - ), - listener - ); - } - public PlanType transformExpressionsUp( Predicate> shouldVisit, Class typeToken, @@ -241,27 +140,6 @@ public PlanType transformExpressionsUp( ); } - public void transformExpressionsUp( - Predicate> shouldVisit, - Class typeToken, - BiConsumer> rule, - ActionListener listener - ) { - transformUp( - shouldVisit, - (plan, planListener) -> plan.transformNodeProps( - Object.class, - (prop, propListener) -> doTransformExpression( - prop, - (expr, exprListener) -> expr.transformUp(typeToken, rule, exprListener), - propListener - ), - planListener - ), - listener - ); - } - @SuppressWarnings("unchecked") private static Object doTransformExpression(Object arg, Function traversal) { if (arg instanceof Expression exp) { @@ -339,55 +217,4 @@ private static void doForEachExpression(Object arg, Consumer travers } } } - - private static void doTransformExpression( - Object arg, - BiConsumer> traversal, - ActionListener listener - ) { - if (arg instanceof Expression exp) { - traversal.accept(exp, listener.map(r -> (Object) r)); - } else if (arg instanceof Collection c && c.isEmpty()) { - listener.onResponse(arg); - } else if (arg instanceof List list) { - AtomicReference> transformed = new 
AtomicReference<>(null); - CountDownActionListener completionListener = new CountDownActionListener( - list.size(), - listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(transformed.get() != null ? transformed.get() : arg)) - ); - for (int i = 0; i < list.size(); i++) { - final int idx = i; - Object el = list.get(i); - doTransformExpression(el, traversal, completionListener.map(next -> { - if (el.equals(next) == false) { - transformed.compareAndSet(null, new ArrayList<>(list)); - transformed.get().set(idx, next); - } - return null; - })); - } - return; - } else if (arg instanceof Collection c) { - AtomicReference> transformed = new AtomicReference<>(null); - CountDownActionListener completionListener = new CountDownActionListener( - c.size(), - listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(transformed.get() != null ? transformed.get() : arg)) - ); - int i = 0; - for (Object el : c) { - final int idx = i++; - doTransformExpression(el, traversal, completionListener.map(next -> { - if (next.equals(el) == false) { - if (el.equals(next) == false) { - transformed.compareAndSet(null, new ArrayList<>(c)); - transformed.get().set(idx, next); - } - } - return null; - })); - } - } else { - listener.onResponse(arg); - } - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java index fdc3817b6ef9d..dadcd12b31030 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.esql.plan; -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expressions; @@ -16,7 +14,6 @@ import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.tree.Node; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Limit; @@ -26,7 +23,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.function.Predicate; import static java.util.Arrays.asList; import static java.util.Collections.emptyList; @@ -38,7 +34,6 @@ import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.equalTo; public class QueryPlanTests extends ESTestCase { @@ -49,15 +44,6 @@ public void testTransformWithExpressionTopLevel() throws Exception { assertEquals(Limit.class, transformed.getClass()); Limit l = (Limit) transformed; assertEquals(24, l.limit().fold(FoldContext.small())); - - // Test async version is returning the same result as the sync version - SetOnce asyncResultHolder = new SetOnce<>(); - limit.transformExpressionsOnly( - Literal.class, - (e, listener) -> listener.onResponse(of(24)), - ActionListener.wrap(asyncResultHolder::set, ESTestCase::fail) - ); - assertBusy(() -> { assertThat(asyncResultHolder.get(), equalTo(transformed)); }); } public 
void testTransformWithExpressionTree() throws Exception { @@ -69,15 +55,6 @@ public void testTransformWithExpressionTree() throws Exception { OrderBy order = (OrderBy) transformed; assertEquals(Limit.class, order.child().getClass()); assertEquals(24, ((Limit) order.child()).limit().fold(FoldContext.small())); - - // Test async version is returning the same result as the sync version - SetOnce asyncResultHolder = new SetOnce<>(); - o.transformExpressionsDown( - Literal.class, - (e, listener) -> listener.onResponse(of(24)), - ActionListener.wrap(asyncResultHolder::set, ESTestCase::fail) - ); - assertBusy(() -> { assertThat(asyncResultHolder.get(), equalTo(transformed)); }); } public void testTransformWithExpressionTopLevelInCollection() throws Exception { @@ -97,56 +74,6 @@ public void testTransformWithExpressionTopLevelInCollection() throws Exception { NamedExpression o = p.projections().get(0); assertEquals("changed", o.name()); - - // Test async version is returning the same result as the sync version - SetOnce asyncResultHolder = new SetOnce<>(); - project.transformExpressionsOnly( - NamedExpression.class, - (n, listener) -> listener.onResponse(n.name().equals("one") ? new FieldAttribute(EMPTY, "changed", one.field()) : n), - ActionListener.wrap(asyncResultHolder::set, ESTestCase::fail) - ); - assertBusy(() -> { assertThat(asyncResultHolder.get(), equalTo(transformed)); }); - } - - public void testTransformExpressionsUpTree() throws Exception { - Limit limit = new Limit(EMPTY, of(42), relation()); - OrderBy order = new OrderBy(EMPTY, limit, emptyList()); - - LogicalPlan transformed = order.transformExpressionsUp(Literal.class, l -> of(24)); - - assertEquals(OrderBy.class, transformed.getClass()); - OrderBy out = (OrderBy) transformed; - assertEquals(24, ((Limit) out.child()).limit().fold(FoldContext.small())); - - // Test async version is returning the same result as the sync version - SetOnce asyncResult = new SetOnce<>(); - order.transformExpressionsUp( - Literal.class, - (lit, listener) -> listener.onResponse(of(24)), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(transformed))); - } - - public void testTransformExpressionsDownWithPredicate() throws Exception { - Limit limit = new Limit(EMPTY, of(42), relation()); - OrderBy outer = new OrderBy(EMPTY, limit, emptyList()); - - Predicate> onlyLimit = n -> n instanceof Limit; - - LogicalPlan transformed = outer.transformExpressionsDown(onlyLimit, Literal.class, lit -> of(24)); - - assertEquals(24, ((Limit) ((OrderBy) transformed).child()).limit().fold(FoldContext.small())); - - // Test async version is returning the same result as the sync version - SetOnce asyncResult = new SetOnce<>(); - outer.transformExpressionsDown( - onlyLimit, - Literal.class, - (lit, listener) -> listener.onResponse(of(24)), - ActionListener.wrap(asyncResult::set, ESTestCase::fail) - ); - assertBusy(() -> assertThat(asyncResult.get(), equalTo(transformed))); } public void testForEachWithExpressionTopLevel() throws Exception { From 24e242fcd0760f517678624a6d7c93102146a280 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 10:55:32 +0200 Subject: [PATCH 20/31] Fix typo --- .../testFixtures/src/main/resources/text-embedding.csv-spec | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec index 
6a4be805d9219..2a771cab88274 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec @@ -1,8 +1,4 @@ -// Note: -// The "test_completion" service returns the prompt in uppercase, making the output easy to guess. - - -completion using a ROW source operator +text_embedding using a ROW source operator required_capability: text_embedding_function required_capability: dense_vector_field_type From b459ef37791be1b203a70edb8cde9b6bad03ff8d Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 11:26:35 +0200 Subject: [PATCH 21/31] Renamed PreOptimizer into LogicalPlanPreOptimizer --- .../optimizer/LogicalPlanPreOptimizer.java | 68 +++++- .../xpack/esql/optimizer/PreOptimizer.java | 98 -------- .../xpack/esql/plan/logical/LogicalPlan.java | 8 - .../LogicalPlanPreOptimizerTests.java | 201 ++++++++++++++++- .../esql/optimizer/PreOptimizerTests.java | 211 ------------------ 5 files changed, 260 insertions(+), 326 deletions(-) delete mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java delete mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java index fdd8e1318f636..4e9eba6cec19a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java @@ -8,7 +8,19 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; +import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plugin.TransportActionServices; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** * The class is responsible for invoking any steps that need to be applied to the logical plan, @@ -19,10 +31,10 @@ */ public class LogicalPlanPreOptimizer { - private final LogicalPreOptimizerContext preOptimizerContext; + private final InferenceFunctionFolding inferenceFunctionFolding; - public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) { - this.preOptimizerContext = preOptimizerContext; + public LogicalPlanPreOptimizer(TransportActionServices services, LogicalPreOptimizerContext preOptimizerContext) { + this.inferenceFunctionFolding = new InferenceFunctionFolding(services.inferenceRunner(), preOptimizerContext.foldCtx()); } /** @@ -44,7 +56,53 @@ public void preOptimize(LogicalPlan plan, ActionListener listener) } private void doPreOptimize(LogicalPlan plan, ActionListener listener) { - // this is where we will be executing async tasks - listener.onResponse(plan); + inferenceFunctionFolding.foldInferenceFunctions(plan, listener); + } + + private static class InferenceFunctionFolding 
{ + private final InferenceRunner inferenceRunner; + private final FoldContext foldContext; + + private InferenceFunctionFolding(InferenceRunner inferenceRunner, FoldContext foldContext) { + this.inferenceRunner = inferenceRunner; + this.foldContext = foldContext; + } + + private void foldInferenceFunctions(LogicalPlan plan, ActionListener listener) { + // First let's collect all the inference functions + List> inferenceFunctions = new ArrayList<>(); + plan.forEachExpressionUp(InferenceFunction.class, inferenceFunctions::add); + + if (inferenceFunctions.isEmpty()) { + // No inference functions found. Return the original plan. + listener.onResponse(plan); + return; + } + + // This is a map of inference functions to their results. + // We will use this map to replace the inference functions in the plan. + Map, Expression> inferenceFunctionsToResults = new HashMap<>(); + + // Prepare a listener that will be called when all inference functions are done. + // This listener will replace the inference functions in the plan with their results. + CountDownActionListener completionListener = new CountDownActionListener( + inferenceFunctions.size(), + listener.map( + ignored -> plan.transformExpressionsUp(InferenceFunction.class, f -> inferenceFunctionsToResults.getOrDefault(f, f)) + ) + ); + + // Try to compute the result for each inference function. + for (InferenceFunction inferenceFunction : inferenceFunctions) { + foldInferenceFunction(inferenceFunction, completionListener.map(e -> { + inferenceFunctionsToResults.put(inferenceFunction, e); + return null; + })); + } + } + + private void foldInferenceFunction(InferenceFunction inferenceFunction, ActionListener listener) { + InferenceFunctionEvaluator.get(inferenceFunction, inferenceRunner).eval(foldContext, listener); + } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java deleted file mode 100644 index c99cee417f6e9..0000000000000 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.optimizer; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.CountDownActionListener; -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; -import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; -import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.plugin.TransportActionServices; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * The class is responsible for invoking any steps that need to be applied to the logical plan, - * before this is being optimized. - *
- * This is useful, especially if you need to execute some async tasks before the plan is optimized. - *
- */ -public class PreOptimizer { - - private final InferencePreOptimizer inferencePreOptimizer; - - public PreOptimizer(TransportActionServices services, FoldContext foldContext) { - this(services.inferenceRunner(), foldContext); - } - - PreOptimizer(InferenceRunner inferenceRunner, FoldContext foldContext) { - this.inferencePreOptimizer = new InferencePreOptimizer(inferenceRunner, foldContext); - } - - public void preOptimize(LogicalPlan plan, ActionListener listener) { - if (plan.analyzed() == false) { - throw new IllegalStateException("Expected analyzed plan"); - } - - inferencePreOptimizer.foldInferenceFunctions(plan, listener.safeMap(p -> { - p.setPreOptimized(); - return p; - })); - } - - private static class InferencePreOptimizer { - private final InferenceRunner inferenceRunner; - private final FoldContext foldContext; - - private InferencePreOptimizer(InferenceRunner inferenceRunner, FoldContext foldContext) { - this.inferenceRunner = inferenceRunner; - this.foldContext = foldContext; - } - - private void foldInferenceFunctions(LogicalPlan plan, ActionListener listener) { - // First let's collect all the inference functions - List> inferenceFunctions = new ArrayList<>(); - plan.forEachExpressionUp(InferenceFunction.class, inferenceFunctions::add); - - if (inferenceFunctions.isEmpty()) { - // No inference functions found. Return the original plan. - listener.onResponse(plan); - return; - } - - // This is a map of inference functions to their results. - // We will use this map to replace the inference functions in the plan. - Map, Expression> inferenceFunctionsToResults = new HashMap<>(); - - // Prepare a listener that will be called when all inference functions are done. - // This listener will replace the inference functions in the plan with their results. - CountDownActionListener completionListener = new CountDownActionListener(inferenceFunctions.size(), listener.map(ignored -> - plan.transformExpressionsUp(InferenceFunction.class, f -> inferenceFunctionsToResults.getOrDefault(f, f)) - )); - - // Try to compute the result for each inference function. 
- for (InferenceFunction inferenceFunction : inferenceFunctions) { - foldInferenceFunction(inferenceFunction, completionListener.map(e -> { - inferenceFunctionsToResults.put(inferenceFunction, e); - return null; - })); - } - } - - private void foldInferenceFunction(InferenceFunction inferenceFunction, ActionListener listener) { - InferenceFunctionEvaluator.get(inferenceFunction, inferenceRunner).eval(foldContext, listener); - } - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java index f2dd34240ffef..762b22389ae24 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java @@ -65,14 +65,6 @@ public boolean optimized() { return stage.ordinal() >= Stage.OPTIMIZED.ordinal(); } - public void setPreOptimized() { - stage = Stage.PRE_OPTIMIZED; - } - - public boolean preOptimized() { - return stage.ordinal() >= Stage.PRE_OPTIMIZED.ordinal(); - } - public void setOptimized() { stage = Stage.OPTIMIZED; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java index 8e573dd1cf3c9..845fbc380d092 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java @@ -10,26 +10,47 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.InferenceServices; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; +import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Limit; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; +import org.elasticsearch.xpack.esql.plugin.TransportActionServices; +import 
java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.List; import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; import static org.elasticsearch.xpack.esql.EsqlTestUtils.fieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class LogicalPlanPreOptimizerTests extends ESTestCase { @@ -52,14 +73,42 @@ public void testPreOptimizeFailsIfPlanIsNotAnalyzed() throws Exception { }); } + public void testEvalFunctionEmbeddingBytes() throws Exception { + testEvalFunctionEmbedding(BYTES_EMBEDDING_MODEL); + } + + public void testEvalFunctionEmbeddingBits() throws Exception { + testEvalFunctionEmbedding(BIT_EMBEDDING_MODEL); + } + + public void testEvalFunctionEmbeddingFloats() throws Exception { + testEvalFunctionEmbedding(FLOAT_EMBEDDING_MODEL); + } + + public void testKnnFunctionEmbeddingBytes() throws Exception { + testKnnFunctionEmbedding(BYTES_EMBEDDING_MODEL); + } + + public void testKnnFunctionEmbeddingBits() throws Exception { + testKnnFunctionEmbedding(BIT_EMBEDDING_MODEL); + } + + public void testKnnFunctionEmbeddingFloats() throws Exception { + testKnnFunctionEmbedding(FLOAT_EMBEDDING_MODEL); + } + public LogicalPlan preOptimizedPlan(LogicalPlan plan) throws Exception { + return preOptimizedPlan(preOptimizer(), plan); + } + + public LogicalPlan preOptimizedPlan(LogicalPlanPreOptimizer preOptimizer, LogicalPlan plan) throws Exception { // set plan as analyzed plan.setPreOptimized(); SetOnce resultHolder = new SetOnce<>(); SetOnce exceptionHolder = new SetOnce<>(); - preOptimizer().preOptimize(plan, ActionListener.wrap(resultHolder::set, exceptionHolder::set)); + preOptimizer.preOptimize(plan, ActionListener.wrap(resultHolder::set, exceptionHolder::set)); if (exceptionHolder.get() != null) { throw exceptionHolder.get(); @@ -71,9 +120,64 @@ public LogicalPlan preOptimizedPlan(LogicalPlan plan) throws Exception { return resultHolder.get(); } - private LogicalPlanPreOptimizer preOptimizer() { + private void testEvalFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) throws Exception { + String inferenceId = randomUUID(); + String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); + String fieldName = randomIdentifier(); + + LogicalPlanPreOptimizer preOptimizer = preOptimizer(textEmbeddingModel); + EsRelation relation = relation(); + Eval eval = new Eval( + Source.EMPTY, + relation, + List.of(new Alias(Source.EMPTY, fieldName, new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)))) + ); + eval.setAnalyzed(); + + Eval preOptimizedEval = as(preOptimizedPlan(preOptimizer, eval), Eval.class); + assertThat(preOptimizedEval.fields(), hasSize(1)); + assertThat(preOptimizedEval.fields().get(0).name(), equalTo(fieldName)); + Literal preOptimizedQuery = as(preOptimizedEval.fields().get(0).child(), Literal.class); + assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); + assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embeddingList(query))); + + } + + private void 
testKnnFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) throws Exception { + String inferenceId = randomUUID(); + String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); + + LogicalPlanPreOptimizer preOptimizer = preOptimizer(textEmbeddingModel); + EsRelation relation = relation(); + Filter filter = new Filter( + Source.EMPTY, + relation, + new Knn(Source.EMPTY, getFieldAttribute("a"), new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)), of(10), null) + ); + Knn knn = as(filter.condition(), Knn.class); + + Filter preOptimizedFilter = as(preOptimizedPlan(preOptimizer, filter), Filter.class); + Knn preOptimizedKnn = as(preOptimizedFilter.condition(), Knn.class); + assertThat(preOptimizedKnn.field(), equalTo(knn.field())); + assertThat(preOptimizedKnn.k(), equalTo(knn.k())); + assertThat(preOptimizedKnn.options(), equalTo(knn.options())); + + Literal preOptimizedQuery = as(preOptimizedKnn.query(), Literal.class); + assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); + assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embeddingList(query))); + } + + private static LogicalPlanPreOptimizer preOptimizer() { + return preOptimizer(randomFrom(FLOAT_EMBEDDING_MODEL, BYTES_EMBEDDING_MODEL, BIT_EMBEDDING_MODEL)); + } + + private static LogicalPlanPreOptimizer preOptimizer(TextEmbeddingModelMock textEmbeddingModel) { + return preOptimizer(mockInferenceRunner(textEmbeddingModel)); + } + + public static LogicalPlanPreOptimizer preOptimizer(InferenceRunner inferenceRunner) { LogicalPreOptimizerContext preOptimizerContext = new LogicalPreOptimizerContext(FoldContext.small()); - return new LogicalPlanPreOptimizer(preOptimizerContext); + return new LogicalPlanPreOptimizer(mockTransportActionServices(inferenceRunner), preOptimizerContext); } private LogicalPlan randomPlan() { @@ -92,10 +196,11 @@ private LogicalPlan randomPlan() { } private Expression randomExpression() { - return switch (randomInt(3)) { + return switch (randomInt(4)) { case 0 -> of(randomInt()); case 1 -> of(randomIdentifier()); case 2 -> new Add(Source.EMPTY, of(randomInt()), of(randomDouble())); + case 3 -> new TextEmbedding(Source.EMPTY, of(randomIdentifier()), of(randomIdentifier())); default -> new Concat(Source.EMPTY, of(randomIdentifier()), randomList(1, 10, () -> of(randomIdentifier()))); }; } @@ -107,4 +212,92 @@ private Expression randomCondition() { return EsqlTestUtils.greaterThanOf(randomExpression(), randomExpression()); } + + private static TransportActionServices mockTransportActionServices(InferenceRunner inferenceRunner) { + InferenceServices inferenceServices = mock(InferenceServices.class); + when(inferenceServices.inferenceRunner()).thenReturn(inferenceRunner); + return new TransportActionServices(null, null, null, null, null, null, null, inferenceServices); + } + + private static InferenceRunner mockInferenceRunner(TextEmbeddingModelMock textEmbeddingModel) { + return new InferenceRunner() { + @Override + public void execute(InferenceAction.Request request, ActionListener listener) { + listener.onResponse(new InferenceAction.Response(textEmbeddingModel.embeddingResults(request.getInput().getFirst()))); + } + + @Override + public void executeBulk(BulkInferenceRequestIterator requests, ActionListener> listener) { + listener.onFailure( + new UnsupportedOperationException("executeBulk should not be invoked for plans without inference functions") + ); + } + }; + } + + private interface TextEmbeddingModelMock { + 
TextEmbeddingResults embeddingResults(String input); + + float[] embedding(String input); + + default List embeddingList(String input) { + float[] embedding = embedding(input); + List embeddingList = new ArrayList<>(embedding.length); + for (float value : embedding) { + embeddingList.add(value); + } + return embeddingList; + } + } + + private static final TextEmbeddingModelMock FLOAT_EMBEDDING_MODEL = new TextEmbeddingModelMock() { + public TextEmbeddingResults embeddingResults(String input) { + TextEmbeddingFloatResults.Embedding embedding = new TextEmbeddingFloatResults.Embedding(embedding(input)); + return new TextEmbeddingFloatResults(List.of(embedding)); + } + + public float[] embedding(String input) { + String[] tokens = input.split("\\s+"); + float[] embedding = new float[tokens.length]; + for (int i = 0; i < tokens.length; i++) { + embedding[i] = tokens[i].length(); + } + return embedding; + } + }; + + private static final TextEmbeddingModelMock BYTES_EMBEDDING_MODEL = new TextEmbeddingModelMock() { + public TextEmbeddingResults embeddingResults(String input) { + TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input)); + return new TextEmbeddingBitResults(List.of(embedding)); + } + + private byte[] bytes(String input) { + return input.getBytes(StandardCharsets.UTF_8); + } + + public float[] embedding(String input) { + return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray(); + } + }; + + private static final TextEmbeddingModelMock BIT_EMBEDDING_MODEL = new TextEmbeddingModelMock() { + public TextEmbeddingResults embeddingResults(String input) { + TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input)); + return new TextEmbeddingBitResults(List.of(embedding)); + } + + private byte[] bytes(String input) { + String[] tokens = input.split("\\s+"); + byte[] embedding = new byte[tokens.length]; + for (int i = 0; i < tokens.length; i++) { + embedding[i] = (byte) (tokens[i].length() % 2); + } + return embedding; + } + + public float[] embedding(String input) { + return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray(); + } + }; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java deleted file mode 100644 index 20fb66dae8326..0000000000000 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizerTests.java +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.esql.optimizer; - -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; -import org.elasticsearch.xpack.esql.core.expression.Alias; -import org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; -import org.elasticsearch.xpack.esql.expression.function.vector.Knn; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; -import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; -import org.elasticsearch.xpack.esql.plan.logical.EsRelation; -import org.elasticsearch.xpack.esql.plan.logical.Eval; -import org.elasticsearch.xpack.esql.plan.logical.Filter; - -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; - -import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation; -import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.notNullValue; - -public class PreOptimizerTests extends ESTestCase { - - public void testEvalFunctionEmbeddingBytes() throws Exception { - testEvalFunctionEmbedding(BYTES_EMBEDDING_MODEL); - } - - public void testEvalFunctionEmbeddingBits() throws Exception { - testEvalFunctionEmbedding(BIT_EMBEDDING_MODEL); - } - - public void testEvalFunctionEmbeddingFloats() throws Exception { - testEvalFunctionEmbedding(FLOAT_EMBEDDING_MODEL); - } - - public void testKnnFunctionEmbeddingBytes() throws Exception { - testKnnFunctionEmbedding(BYTES_EMBEDDING_MODEL); - } - - public void testKnnFunctionEmbeddingBits() throws Exception { - testKnnFunctionEmbedding(BIT_EMBEDDING_MODEL); - } - - public void testKnnFunctionEmbeddingFloats() throws Exception { - testKnnFunctionEmbedding(FLOAT_EMBEDDING_MODEL); - } - - private void testEvalFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) throws Exception { - String inferenceId = randomUUID(); - String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); - String fieldName = randomIdentifier(); - - PreOptimizer preOptimizer = new PreOptimizer(mockInferenceRunner(textEmbeddingModel), FoldContext.small()); - EsRelation relation = relation(); - Eval eval = new Eval( - Source.EMPTY, - relation, - List.of(new Alias(Source.EMPTY, fieldName, new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)))) - ); - eval.setAnalyzed(); - - SetOnce preOptimizedPlanHolder = new SetOnce<>(); - preOptimizer.preOptimize(eval, ActionListener.wrap(preOptimizedPlanHolder::set, ESTestCase::fail)); - - assertBusy(() -> { - assertThat(preOptimizedPlanHolder.get(), notNullValue()); - 
Eval preOptimizedEval = as(preOptimizedPlanHolder.get(), Eval.class); - assertThat(preOptimizedEval.preOptimized(), equalTo(true)); - assertThat(preOptimizedEval.fields(), hasSize(1)); - assertThat(preOptimizedEval.fields().get(0).name(), equalTo(fieldName)); - Literal preOptimizedQuery = as(preOptimizedEval.fields().get(0).child(), Literal.class); - assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); - assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embeddingList(query))); - }); - } - - private void testKnnFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) throws Exception { - String inferenceId = randomUUID(); - String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); - - PreOptimizer preOptimizer = new PreOptimizer(mockInferenceRunner(textEmbeddingModel), FoldContext.small()); - EsRelation relation = relation(); - Filter filter = new Filter( - Source.EMPTY, - relation, - new Knn(Source.EMPTY, getFieldAttribute("a"), new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)), of(10), null) - ); - filter.setAnalyzed(); - Knn knn = as(filter.condition(), Knn.class); - - SetOnce preOptimizedHolder = new SetOnce<>(); - preOptimizer.preOptimize(filter, ActionListener.wrap(preOptimizedHolder::set, ESTestCase::fail)); - - assertBusy(() -> { - assertThat(preOptimizedHolder.get(), notNullValue()); - Filter preOptimizedFilter = as(preOptimizedHolder.get(), Filter.class); - assertThat(preOptimizedFilter.preOptimized(), equalTo(true)); - Knn preOptimizedKnn = as(preOptimizedFilter.condition(), Knn.class); - assertThat(preOptimizedKnn.field(), equalTo(knn.field())); - assertThat(preOptimizedKnn.k(), equalTo(knn.k())); - assertThat(preOptimizedKnn.options(), equalTo(knn.options())); - - Literal preOptimizedQuery = as(preOptimizedKnn.query(), Literal.class); - assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); - assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embeddingList(query))); - }); - } - - private InferenceRunner mockInferenceRunner(TextEmbeddingModelMock textEmbeddingModel) { - return new InferenceRunner() { - @Override - public void execute(InferenceAction.Request request, ActionListener listener) { - listener.onResponse(new InferenceAction.Response(textEmbeddingModel.embeddingResults(request.getInput().getFirst()))); - } - - @Override - public void executeBulk(BulkInferenceRequestIterator requests, ActionListener> listener) { - listener.onFailure( - new UnsupportedOperationException("executeBulk should not be invoked for plans without inference functions") - ); - } - }; - } - - private interface TextEmbeddingModelMock { - TextEmbeddingResults embeddingResults(String input); - - float[] embedding(String input); - - default List embeddingList(String input) { - float[] embedding = embedding(input); - List embeddingList = new ArrayList<>(embedding.length); - for (float value : embedding) { - embeddingList.add(value); - } - return embeddingList; - } - } - - private static final TextEmbeddingModelMock FLOAT_EMBEDDING_MODEL = new TextEmbeddingModelMock() { - public TextEmbeddingResults embeddingResults(String input) { - TextEmbeddingFloatResults.Embedding embedding = new TextEmbeddingFloatResults.Embedding(embedding(input)); - return new TextEmbeddingFloatResults(List.of(embedding)); - } - - public float[] embedding(String input) { - String[] tokens = input.split("\\s+"); - float[] embedding = new float[tokens.length]; - for (int i = 0; i < tokens.length; i++) { - 
embedding[i] = tokens[i].length(); - } - return embedding; - } - }; - - private static final TextEmbeddingModelMock BYTES_EMBEDDING_MODEL = new TextEmbeddingModelMock() { - public TextEmbeddingResults embeddingResults(String input) { - TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input)); - return new TextEmbeddingBitResults(List.of(embedding)); - } - - private byte[] bytes(String input) { - return input.getBytes(StandardCharsets.UTF_8); - } - - public float[] embedding(String input) { - return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray(); - } - }; - - private static final TextEmbeddingModelMock BIT_EMBEDDING_MODEL = new TextEmbeddingModelMock() { - public TextEmbeddingResults embeddingResults(String input) { - TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input)); - return new TextEmbeddingBitResults(List.of(embedding)); - } - - private byte[] bytes(String input) { - String[] tokens = input.split("\\s+"); - byte[] embedding = new byte[tokens.length]; - for (int i = 0; i < tokens.length; i++) { - embedding[i] = (byte) (tokens[i].length() % 2); - } - return embedding; - } - - public float[] embedding(String input) { - return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray(); - } - }; -} From 40402926eab85c3c089a1a1bee16dbaf04d4b545 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 11:27:47 +0200 Subject: [PATCH 22/31] Add TEXT_EMBEDDING_FUNCTION capability to EsqlSpecTestCase::requiresInferenceEndpoint --- .../xpack/esql/qa/rest/EsqlSpecTestCase.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java index 7e0bd6031f455..6813f0106f411 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java @@ -77,6 +77,7 @@ import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SOURCE_FIELD_MAPPING; +import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION; import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.assertNotPartial; import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.hasCapabilities; @@ -211,8 +212,12 @@ protected boolean supportsInferenceTestService() { } protected boolean requiresInferenceEndpoint() { - return Stream.of(SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName(), COMPLETION.capabilityName()) - .anyMatch(testCase.requiredCapabilities::contains); + return Stream.of( + SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), + RERANK.capabilityName(), + COMPLETION.capabilityName(), + TEXT_EMBEDDING_FUNCTION.capabilityName() + ).anyMatch(testCase.requiredCapabilities::contains); } protected boolean supportsIndexModeLookup() throws IOException { From 0fda8485844f95041e8231e1d5f850f5712568df Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 15:39:10 +0200 Subject: [PATCH 23/31] Restore previous implementation of foldable for TextEmbedding function. 
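TEXT_EMBEDDING is only evaluated ahead of time when all of its arguments are constants, so foldability has to depend on the arguments rather than being hard-coded to false on InferenceFunction. A rough sketch of the intended behaviour (illustrative only, not part of the patch), reusing the EsqlTestUtils helpers that appear in the tests later in this series, where of() builds a literal and fieldAttribute() a field reference:

    // Both arguments are literals, so the pre-optimizer may fold the call.
    new TextEmbedding(Source.EMPTY, of("some query"), of("my-endpoint")).foldable();  // true
    // The input text is a field reference, so the call cannot be folded.
    new TextEmbedding(Source.EMPTY, fieldAttribute(), of("my-endpoint")).foldable();  // false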
--- .../expression/function/inference/InferenceFunction.java | 6 ------ .../esql/expression/function/inference/TextEmbedding.java | 5 +++++ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java index 331d688f7b0f9..d32587dc71058 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -36,10 +36,4 @@ protected InferenceFunction(Source source, List children) { public abstract TaskType taskType(); public abstract PlanType withInferenceResolutionError(String inferenceId, String error); - - @Override - public boolean foldable() { - // Inference functions are not foldable and need to be evaluated using an async inference call. - return false; - } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java index 0ba6d138a13f7..14e1cad72253e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java @@ -91,6 +91,11 @@ public Expression inferenceId() { return inferenceId; } + @Override + public boolean foldable() { + return inferenceId.foldable() && inputText.foldable(); + } + @Override public DataType dataType() { return DataType.DENSE_VECTOR; From cd033113437ffaf0a10ec0b4a5711e1e69824540 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 17:12:18 +0200 Subject: [PATCH 24/31] Fix CsvTests --- .../src/test/java/org/elasticsearch/xpack/esql/CsvTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 88fc7dc669aba..9c3b411ac814a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -588,7 +588,7 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception { null, null, null, - new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldCtx)), + new LogicalPlanPreOptimizer(EsqlTestUtils.MOCK_TRANSPORT_ACTION_SERVICES, new LogicalPreOptimizerContext(foldCtx)), functionRegistry, new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldCtx)), mapper, From f92e1f5412a5361f81057e728fa49f2ffbe2feb0 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 17 Jul 2025 18:08:58 +0200 Subject: [PATCH 25/31] Get rid of the InferenceServices class. 
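With the wrapper gone, tests reach the inference runner straight through TransportActionServices. The mock setup used in the test changes below (Mockito, as already used elsewhere in these tests) boils down to:

    TransportActionServices services = mock(TransportActionServices.class);
    when(services.inferenceRunner()).thenReturn(inferenceRunner);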
--- .../xpack/esql/inference/InferenceResolver.java | 5 +++-- .../xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java | 7 +++---- .../xpack/esql/optimizer/PhysicalPlanOptimizerTests.java | 1 + .../xpack/esql/planner/LocalExecutionPlannerTests.java | 1 + 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java index f7d349281e004..6813dc845fcce 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; @@ -155,8 +156,8 @@ private Factory(Client client) { this.client = client; } - public InferenceResolver create() { - return new InferenceResolver(client); + public InferenceResolver create(EsqlFunctionRegistry functionRegistry) { + return new InferenceResolver(client, functionRegistry); } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java index 845fbc380d092..73f9d9848be7f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java @@ -26,7 +26,6 @@ import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.inference.InferenceRunner; -import org.elasticsearch.xpack.esql.inference.InferenceServices; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; @@ -214,9 +213,9 @@ private Expression randomCondition() { } private static TransportActionServices mockTransportActionServices(InferenceRunner inferenceRunner) { - InferenceServices inferenceServices = mock(InferenceServices.class); - when(inferenceServices.inferenceRunner()).thenReturn(inferenceRunner); - return new TransportActionServices(null, null, null, null, null, null, null, inferenceServices); + TransportActionServices services = mock(TransportActionServices.class); + when(services.inferenceRunner()).thenReturn(inferenceRunner); + return services; } private static InferenceRunner mockInferenceRunner(TextEmbeddingModelMock textEmbeddingModel) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index 3464e2e74217a..0a0ce69afd3c9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -7888,6 +7888,7 @@ private LocalExecutionPlanner.LocalExecutionPlan physicalOperationsFromPhysicalP TestBlockFactory.getNonBreakingInstance(), Settings.EMPTY, config, + null, new ExchangeSourceHandler(10, null)::createExchangeSource, () -> exchangeSinkHandler.createExchangeSink(() -> {}), null, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index b56f4a3a4898b..971c4719d2c99 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -318,6 +318,7 @@ private LocalExecutionPlanner planner() throws IOException { null, null, null, + null, esPhysicalOperationProviders(shardContexts), shardContexts ); From f3c9324e851470de4d121e2c3d01575a45525ef9 Mon Sep 17 00:00:00 2001 From: afoucret Date: Fri, 18 Jul 2025 14:51:58 +0200 Subject: [PATCH 26/31] Improved the inference pre-optimization. --- .../function/inference/InferenceFunction.java | 16 + .../function/inference/TextEmbedding.java | 7 + .../inference/InferenceFunctionEvaluator.java | 10 +- .../optimizer/LogicalPlanPreOptimizer.java | 71 +--- .../InferenceFunctionConstantFolding.java | 144 ++++++++ .../preoptimizer/PreOptimizerRule.java | 25 ++ .../AbstractLogicalPlanPreOptimizerTests.java | 315 ++++++++++++++++++ .../LogicalPlanPreOptimizerTests.java | 303 +++-------------- ...InferenceFunctionConstantFoldingTests.java | 111 ++++++ 9 files changed, 678 insertions(+), 324 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFolding.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanPreOptimizerTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFoldingTests.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java index d32587dc71058..a39a893eb1816 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; import java.util.List; @@ -35,5 +36,20 @@ protected InferenceFunction(Source source, List children) { */ public abstract TaskType taskType(); + /** + * Returns a new instance of the function with the specified inference resolution error. 
+ */ public abstract PlanType withInferenceResolutionError(String inferenceId, String error); + + /** + * Returns the inference function evaluator factory. + */ + public abstract InferenceFunctionEvaluator.Factory inferenceEvaluatorFactory(); + + /** + * Returns true if the function has a nested inference function. + */ + public boolean hasNestedInferenceFunction() { + return anyMatch(e -> e instanceof InferenceFunction && e != this); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java index 14e1cad72253e..cbd9f7aac0cf1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java @@ -20,6 +20,8 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; +import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingFunctionEvaluator; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import java.io.IOException; @@ -129,6 +131,11 @@ public TaskType taskType() { return TaskType.TEXT_EMBEDDING; } + @Override + public InferenceFunctionEvaluator.Factory inferenceEvaluatorFactory() { + return inferenceRunner -> new TextEmbeddingFunctionEvaluator(this, inferenceRunner); + } + @Override public TextEmbedding withInferenceResolutionError(String inferenceId, String error) { return new TextEmbedding(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error)); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java index 49a4d4cf8c4bf..38db294e5ddf5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java @@ -10,18 +10,12 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; -import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; -import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingFunctionEvaluator; public interface InferenceFunctionEvaluator { void eval(FoldContext foldContext, ActionListener listener); - static InferenceFunctionEvaluator get(InferenceFunction inferenceFunction, InferenceRunner inferenceRunner) { - return switch (inferenceFunction) { - case TextEmbedding textEmbedding -> new TextEmbeddingFunctionEvaluator(textEmbedding, inferenceRunner); - default -> throw new IllegalArgumentException("Unsupported inference function: " + inferenceFunction.getClass()); - }; + interface Factory { + InferenceFunctionEvaluator get(InferenceRunner inferenceRunner); } } diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java index 4e9eba6cec19a..2a9057957f06a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java @@ -8,19 +8,13 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.CountDownActionListener; -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; -import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.InferenceFunctionConstantFolding; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.PreOptimizerRule; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plugin.TransportActionServices; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; /** * The class is responsible for invoking any steps that need to be applied to the logical plan, @@ -31,10 +25,10 @@ */ public class LogicalPlanPreOptimizer { - private final InferenceFunctionFolding inferenceFunctionFolding; + private final List rules; public LogicalPlanPreOptimizer(TransportActionServices services, LogicalPreOptimizerContext preOptimizerContext) { - this.inferenceFunctionFolding = new InferenceFunctionFolding(services.inferenceRunner(), preOptimizerContext.foldCtx()); + rules = List.of(new InferenceFunctionConstantFolding(services.inferenceRunner(), preOptimizerContext.foldCtx())); } /** @@ -55,54 +49,19 @@ public void preOptimize(LogicalPlan plan, ActionListener listener) })); } + /** + * Loop over the rules and apply them to the logical plan. + * + * @param plan the analyzed logical plan to pre-optimize + * @param listener the listener returning the pre-optimized plan when pre-optimization is complete + */ private void doPreOptimize(LogicalPlan plan, ActionListener listener) { - inferenceFunctionFolding.foldInferenceFunctions(plan, listener); - } - - private static class InferenceFunctionFolding { - private final InferenceRunner inferenceRunner; - private final FoldContext foldContext; + SubscribableListener rulesListener = SubscribableListener.newSucceeded(plan); - private InferenceFunctionFolding(InferenceRunner inferenceRunner, FoldContext foldContext) { - this.inferenceRunner = inferenceRunner; - this.foldContext = foldContext; + for (PreOptimizerRule rule : rules) { + rulesListener = rulesListener.andThen((l, p) -> rule.apply(p, l)); } - private void foldInferenceFunctions(LogicalPlan plan, ActionListener listener) { - // First let's collect all the inference functions - List> inferenceFunctions = new ArrayList<>(); - plan.forEachExpressionUp(InferenceFunction.class, inferenceFunctions::add); - - if (inferenceFunctions.isEmpty()) { - // No inference functions found. Return the original plan. - listener.onResponse(plan); - return; - } - - // This is a map of inference functions to their results. 
- // We will use this map to replace the inference functions in the plan. - Map, Expression> inferenceFunctionsToResults = new HashMap<>(); - - // Prepare a listener that will be called when all inference functions are done. - // This listener will replace the inference functions in the plan with their results. - CountDownActionListener completionListener = new CountDownActionListener( - inferenceFunctions.size(), - listener.map( - ignored -> plan.transformExpressionsUp(InferenceFunction.class, f -> inferenceFunctionsToResults.getOrDefault(f, f)) - ) - ); - - // Try to compute the result for each inference function. - for (InferenceFunction inferenceFunction : inferenceFunctions) { - foldInferenceFunction(inferenceFunction, completionListener.map(e -> { - inferenceFunctionsToResults.put(inferenceFunction, e); - return null; - })); - } - } - - private void foldInferenceFunction(InferenceFunction inferenceFunction, ActionListener listener) { - InferenceFunctionEvaluator.get(inferenceFunction, inferenceRunner).eval(foldContext, listener); - } + rulesListener.addListener(listener); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFolding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFolding.java new file mode 100644 index 0000000000000..3dc788fb8e955 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFolding.java @@ -0,0 +1,144 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Pre-optimizer rule that evaluates inference functions (like TEXT_EMBEDDING) into constant values. + *
+ * This rule identifies foldable inference functions in the logical plan, executes them using the + * inference runner, and replaces them with their computed results. This enables downstream + * optimizations to work with the actual embedding values rather than the function calls. + *
+ * The rule processes inference functions recursively, handling newly revealed functions that might + * appear after the first round of folding. + */ +public class InferenceFunctionConstantFolding implements PreOptimizerRule { + private final InferenceRunner inferenceRunner; + private final FoldContext foldContext; + + /** + * Creates a new instance of the InferenceFunctionConstantFolding rule. + * + * @param inferenceRunner the inference runner to use for evaluating inference functions + * @param foldContext the fold context to use for evaluating inference functions + */ + public InferenceFunctionConstantFolding(InferenceRunner inferenceRunner, FoldContext foldContext) { + this.inferenceRunner = inferenceRunner; + this.foldContext = foldContext; + } + + /** + * Applies the InferenceFunctionConstantFolding rule to the given logical plan. + * + * @param plan the logical plan to apply the rule to + * @param listener the listener to notify when the rule has been applied + */ + @Override + public void apply(LogicalPlan plan, ActionListener listener) { + foldInferenceFunctions(plan, listener); + } + + /** + * Recursively folds inference functions in the logical plan. + *
+ * This method collects all foldable inference functions, evaluates them in parallel, + * and then replaces them with their results. If new inference functions are revealed + * after the first round of folding, it recursively processes them as well. + * + * @param plan the logical plan to fold inference functions in + * @param listener the listener to notify when the folding is complete + */ + private void foldInferenceFunctions(LogicalPlan plan, ActionListener listener) { + // First let's collect all the inference foldable inference functions + List> inferenceFunctions = collectFoldableInferenceFunctions(plan); + + if (inferenceFunctions.isEmpty()) { + // No inference functions that can be evaluated at this time found. Return the original plan. + listener.onResponse(plan); + return; + } + + // This is a map of inference functions to their results. + // We will use this map to replace the inference functions in the plan. + Map, Expression> inferenceFunctionsToResults = new HashMap<>(); + + // Prepare a listener that will be called when all inference functions are done. + // This listener will replace the inference functions in the plan with their results and then recursively fold the remaining + // inference functions. + CountDownActionListener completionListener = new CountDownActionListener( + inferenceFunctions.size(), + listener.delegateFailureIgnoreResponseAndWrap(l -> { + // Replace the inference functions in the plan with their results + LogicalPlan next = plan.transformExpressionsUp( + InferenceFunction.class, + f -> inferenceFunctionsToResults.getOrDefault(f, f) + ); + + // Recursively fold the remaining inference functions + foldInferenceFunctions(next, l); + }) + ); + + // Try to compute the result for each inference function. + for (InferenceFunction inferenceFunction : inferenceFunctions) { + foldInferenceFunction(inferenceFunction, completionListener.map(e -> { + inferenceFunctionsToResults.put(inferenceFunction, e); + return null; + })); + } + } + + /** + * Collects all foldable inference functions from the logical plan. + *
+ * A function is considered foldable if: + * 1. It's an instance of InferenceFunction + * 2. It's marked as foldable (all parameters are constants) + * 3. It doesn't contain nested inference functions + * + * @param plan the logical plan to collect inference functions from + * @return a list of foldable inference functions + */ + private List> collectFoldableInferenceFunctions(LogicalPlan plan) { + List> inferenceFunctions = new ArrayList<>(); + + plan.forEachExpressionUp(InferenceFunction.class, f -> { + if (f.foldable() && f.hasNestedInferenceFunction() == false) { + inferenceFunctions.add(f); + } + }); + + return inferenceFunctions; + } + + /** + * Evaluates a single inference function asynchronously. + *
+ * Uses the inference function's evaluator factory to create an evaluator + * that can process the function with the given inference runner. + * + * @param inferenceFunction the inference function to evaluate + * @param listener the listener to notify when the evaluation is complete + */ + private void foldInferenceFunction(InferenceFunction inferenceFunction, ActionListener listener) { + inferenceFunction.inferenceEvaluatorFactory().get(inferenceRunner).eval(foldContext, listener); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java new file mode 100644 index 0000000000000..437778bdca6a5 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +/** + * A rule that can be applied to a logical plan before it is optimized. + */ +public interface PreOptimizerRule { + + /** + * Apply the rule to the logical plan. + * + * @param plan the analyzed logical plan to pre-optimize + * @param listener the listener returning the pre-optimized plan when pre-optimization is complete + */ + void apply(LogicalPlan plan, ActionListener listener); +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanPreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanPreOptimizerTests.java new file mode 100644 index 0000000000000..63bb7c50ad09d --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanPreOptimizerTests.java @@ -0,0 +1,315 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.optimizer; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; +import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.Project; +import org.elasticsearch.xpack.esql.plugin.TransportActionServices; +import org.junit.After; +import org.junit.Before; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.core.TimeValue.timeValueNanos; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.fieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Base test class for LogicalPlanPreOptimizer tests. + *
+ * Provides common infrastructure for testing pre-optimization rules, including: + * - Thread pool management for async operations + * - inference model implementations for inference function testing + * - Mock services and runners for inference execution + * - Helper methods for plan creation and manipulation + */ +public class AbstractLogicalPlanPreOptimizerTests extends ESTestCase { + + // + // Embedding model types and implementations + // + + /** + * Available textembedding model types for testing. + */ + public enum TestEmbeddingModel { + FLOAT_EMBEDDING_MODEL, + BYTES_EMBEDDING_MODEL, + BITS_EMBEDDING_MODEL + } + + /** + * Interface for embedding model implementations. + */ + protected interface TextEmbeddingModelMock { + /** + * Returns embedding results for the given input text. + */ + TextEmbeddingResults embeddingResults(String input); + + /** + * Returns embedding values as a float array for the given input text. + */ + float[] embedding(String input); + + /** + * Returns embedding values as a list of floats for the given input text. + * Default implementation converts the float array to a list. + */ + default List embeddingList(String input) { + float[] embedding = embedding(input); + List embeddingList = new ArrayList<>(embedding.length); + for (float value : embedding) { + embeddingList.add(value); + } + return embeddingList; + } + } + + /** + * Map of embedding model implementations by type. + */ + private final static Map TEST_EMBEDDING_MODELS = Map.ofEntries( + // Float embedding model implementation + Map.entry(TestEmbeddingModel.FLOAT_EMBEDDING_MODEL, new TextEmbeddingModelMock() { + @Override + public TextEmbeddingResults embeddingResults(String input) { + TextEmbeddingFloatResults.Embedding embedding = new TextEmbeddingFloatResults.Embedding(embedding(input)); + return new TextEmbeddingFloatResults(List.of(embedding)); + } + + @Override + public float[] embedding(String input) { + String[] tokens = input.split("\\s+"); + float[] embedding = new float[tokens.length]; + for (int i = 0; i < tokens.length; i++) { + embedding[i] = tokens[i].length(); + } + return embedding; + } + }), + + // Byte embedding model implementation + Map.entry(TestEmbeddingModel.BYTES_EMBEDDING_MODEL, new TextEmbeddingModelMock() { + @Override + public TextEmbeddingResults embeddingResults(String input) { + TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input)); + return new TextEmbeddingBitResults(List.of(embedding)); + } + + private byte[] bytes(String input) { + return input.getBytes(StandardCharsets.UTF_8); + } + + @Override + public float[] embedding(String input) { + return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray(); + } + }), + + // Bit embedding model implementation + Map.entry(TestEmbeddingModel.BITS_EMBEDDING_MODEL, new TextEmbeddingModelMock() { + @Override + public TextEmbeddingResults embeddingResults(String input) { + TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input)); + return new TextEmbeddingBitResults(List.of(embedding)); + } + + private byte[] bytes(String input) { + String[] tokens = input.split("\\s+"); + byte[] embedding = new byte[tokens.length]; + for (int i = 0; i < tokens.length; i++) { + embedding[i] = (byte) (tokens[i].length() % 2); + } + return embedding; + } + + @Override + public float[] embedding(String input) { + return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray(); + } + }) + ); + + // + // Thread pool management for 
async testing + // + + private ThreadPool threadPool; + + @Before + public void setupThreadPool() { + threadPool = createThreadPool(); + } + + @After + public void shutdownThreadPool() { + terminate(threadPool); + } + + /** + * Runs the given runnable with a random delay to simulate async behavior. + * Uses the thread pool to execute the runnable. + * + * @param runnable the runnable to execute + */ + private void runWithDelay(Runnable runnable) { + if (randomBoolean()) { + threadPool.schedule(runnable, timeValueNanos(between(0, 5000)), threadPool.generic()); + } else { + threadPool.generic().execute(runnable); + } + } + + // + // Pre-optimizer and inference runner setup + // + + /** + * Creates a LogicalPlanPreOptimizer with the specified embedding model. + * + * @param textEmbeddingModel the embedding model to use + * @return a pre-optimizer configured with the specified model + */ + protected LogicalPlanPreOptimizer preOptimizer(TestEmbeddingModel textEmbeddingModel) { + LogicalPreOptimizerContext preOptimizerContext = new LogicalPreOptimizerContext(FoldContext.small()); + return new LogicalPlanPreOptimizer(mockTransportActionServices(textEmbeddingModel), preOptimizerContext); + } + + /** + * Creates a mock inference runner that uses the specified embedding model. + * + * @param textEmbeddingModel the embedding model to use + * @return a mock inference runner + */ + protected InferenceRunner mockedInferenceRunner(TestEmbeddingModel textEmbeddingModel) { + return new InferenceRunner() { + @Override + public void execute(InferenceAction.Request request, ActionListener listener) { + try { + runWithDelay( + () -> listener.onResponse( + new InferenceAction.Response( + TEST_EMBEDDING_MODELS.get(textEmbeddingModel).embeddingResults(request.getInput().getFirst()) + ) + ) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + @Override + public void executeBulk(BulkInferenceRequestIterator requests, ActionListener> listener) { + listener.onFailure(new UnsupportedOperationException("executeBulk is not supported in this test")); + } + }; + } + + /** + * Creates mock transport action services with the specified embedding model. + * + * @param textEmbeddingModel the embedding model to use + * @return mock transport action services + */ + private TransportActionServices mockTransportActionServices(TestEmbeddingModel textEmbeddingModel) { + TransportActionServices services = mock(TransportActionServices.class); + when(services.inferenceRunner()).thenReturn(mockedInferenceRunner(textEmbeddingModel)); + return services; + } + + /** + * Gets the embedding list for the given model and input text. + * + * @param textEmbeddingModel the embedding model to use + * @param input the input text + * @return the embedding as a list of floats + */ + protected static List embedding(TestEmbeddingModel textEmbeddingModel, String input) { + return TEST_EMBEDDING_MODELS.get(textEmbeddingModel).embeddingList(input); + } + + // + // Plan and expression generation helpers + // + + /** + * Creates a random logical plan for testing. + * The plan consists of a relation with random commands applied to it. 
+ * + * @return a random logical plan + */ + protected LogicalPlan randomPlan() { + LogicalPlan plan = EsqlTestUtils.relation(); + int numCommands = between(0, 100); + + for (int i = 0; i < numCommands; i++) { + plan = switch (randomInt(3)) { + case 0 -> new Eval(Source.EMPTY, plan, List.of(new Alias(Source.EMPTY, randomIdentifier(), randomExpression()))); + case 1 -> new Limit(Source.EMPTY, of(randomInt()), plan); + case 2 -> new Filter(Source.EMPTY, plan, randomCondition()); + default -> new Project(Source.EMPTY, plan, List.of(new Alias(Source.EMPTY, randomIdentifier(), fieldAttribute()))); + }; + } + return plan; + } + + /** + * Creates a random expression for testing. + * + * @return a random expression + */ + protected Expression randomExpression() { + return switch (randomInt(4)) { + case 0 -> of(randomInt()); + case 1 -> of(randomIdentifier()); + case 2 -> new Add(Source.EMPTY, of(randomInt()), of(randomDouble())); + case 3 -> new TextEmbedding(Source.EMPTY, of(randomIdentifier()), of(randomIdentifier())); + default -> new Concat(Source.EMPTY, of(randomIdentifier()), randomList(1, 10, () -> of(randomIdentifier()))); + }; + } + + /** + * Creates a random condition expression for testing. + * + * @return a random condition expression + */ + protected Expression randomCondition() { + if (randomBoolean()) { + return EsqlTestUtils.equalsOf(randomExpression(), randomExpression()); + } + + return EsqlTestUtils.greaterThanOf(randomExpression(), randomExpression()); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java index 73f9d9848be7f..4f572fa2531dc 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java @@ -7,64 +7,57 @@ package org.elasticsearch.xpack.esql.optimizer; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.xpack.esql.EsqlTestUtils; -import org.elasticsearch.xpack.esql.core.expression.Alias; -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FoldContext; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; -import org.elasticsearch.xpack.esql.expression.function.vector.Knn; -import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; -import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; -import 
org.elasticsearch.xpack.esql.plan.logical.EsRelation; -import org.elasticsearch.xpack.esql.plan.logical.Eval; -import org.elasticsearch.xpack.esql.plan.logical.Filter; -import org.elasticsearch.xpack.esql.plan.logical.Limit; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.plan.logical.Project; -import org.elasticsearch.xpack.esql.plugin.TransportActionServices; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; +import java.util.Arrays; import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.fieldAttribute; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation; -import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.notNullValue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; -public class LogicalPlanPreOptimizerTests extends ESTestCase { +public class LogicalPlanPreOptimizerTests extends AbstractLogicalPlanPreOptimizerTests { + + private final TestEmbeddingModel embeddingModel; + + public LogicalPlanPreOptimizerTests(TestEmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + } + + @ParametersFactory(argumentFormatting = "%1$s") + public static Iterable parameters() { + return Arrays.stream(TestEmbeddingModel.values()).map(textEmbeddingModel -> new Object[] { textEmbeddingModel }).toList(); + } + /** + * Tests that the pre-optimizer correctly marks plans as pre-optimized. + */ public void testPlanIsMarkedAsPreOptimized() throws Exception { for (int round = 0; round < 100; round++) { - // We want to make sure that the pre-optimizer woks for a wide range of plans - preOptimizedPlan(randomPlan()); + // Create a random plan for testing + LogicalPlan plan = randomPlan(); + plan.setAnalyzed(); + + // Apply pre-optimization + preOptimizedPlan(preOptimizer(embeddingModel), plan); } } - public void testPreOptimizeFailsIfPlanIsNotAnalyzed() throws Exception { + /** + * Tests that the pre-optimizer fails when given a plan that is not analyzed. 
+ */ + public void testPreOptimizeFailsPlanIsNotAnalyzed() throws Exception { + // Create a plan that is not marked as analyzed LogicalPlan plan = EsqlTestUtils.relation(); SetOnce exceptionHolder = new SetOnce<>(); - preOptimizer().preOptimize(plan, ActionListener.wrap(r -> fail("Should have failed"), exceptionHolder::set)); + // Apply pre-optimization and expect failure + preOptimizer(embeddingModel).preOptimize(plan, ActionListener.wrap(r -> fail("Should have failed"), exceptionHolder::set)); assertBusy(() -> { assertThat(exceptionHolder.get(), notNullValue()); IllegalStateException e = as(exceptionHolder.get(), IllegalStateException.class); @@ -72,231 +65,21 @@ public void testPreOptimizeFailsIfPlanIsNotAnalyzed() throws Exception { }); } - public void testEvalFunctionEmbeddingBytes() throws Exception { - testEvalFunctionEmbedding(BYTES_EMBEDDING_MODEL); - } - - public void testEvalFunctionEmbeddingBits() throws Exception { - testEvalFunctionEmbedding(BIT_EMBEDDING_MODEL); - } - - public void testEvalFunctionEmbeddingFloats() throws Exception { - testEvalFunctionEmbedding(FLOAT_EMBEDDING_MODEL); - } - - public void testKnnFunctionEmbeddingBytes() throws Exception { - testKnnFunctionEmbedding(BYTES_EMBEDDING_MODEL); - } - - public void testKnnFunctionEmbeddingBits() throws Exception { - testKnnFunctionEmbedding(BIT_EMBEDDING_MODEL); - } - - public void testKnnFunctionEmbeddingFloats() throws Exception { - testKnnFunctionEmbedding(FLOAT_EMBEDDING_MODEL); - } - - public LogicalPlan preOptimizedPlan(LogicalPlan plan) throws Exception { - return preOptimizedPlan(preOptimizer(), plan); - } - - public LogicalPlan preOptimizedPlan(LogicalPlanPreOptimizer preOptimizer, LogicalPlan plan) throws Exception { - // set plan as analyzed + /** + * Executes pre-optimization on the given plan and returns the result. 
+ */ + protected LogicalPlan preOptimizedPlan(LogicalPlanPreOptimizer preOptimizer, LogicalPlan plan) throws Exception { + // set plan as analyzed to meet pre-optimizer requirements plan.setPreOptimized(); - SetOnce resultHolder = new SetOnce<>(); - SetOnce exceptionHolder = new SetOnce<>(); - - preOptimizer.preOptimize(plan, ActionListener.wrap(resultHolder::set, exceptionHolder::set)); - - if (exceptionHolder.get() != null) { - throw exceptionHolder.get(); - } - - assertThat(resultHolder.get(), notNullValue()); - assertThat(resultHolder.get().preOptimized(), equalTo(true)); - - return resultHolder.get(); - } + PlainActionFuture logicalPlanFuture = new PlainActionFuture<>(); + preOptimizer.preOptimize(plan, logicalPlanFuture); - private void testEvalFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) throws Exception { - String inferenceId = randomUUID(); - String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); - String fieldName = randomIdentifier(); + LogicalPlan preOptimized = logicalPlanFuture.get(); - LogicalPlanPreOptimizer preOptimizer = preOptimizer(textEmbeddingModel); - EsRelation relation = relation(); - Eval eval = new Eval( - Source.EMPTY, - relation, - List.of(new Alias(Source.EMPTY, fieldName, new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)))) - ); - eval.setAnalyzed(); - - Eval preOptimizedEval = as(preOptimizedPlan(preOptimizer, eval), Eval.class); - assertThat(preOptimizedEval.fields(), hasSize(1)); - assertThat(preOptimizedEval.fields().get(0).name(), equalTo(fieldName)); - Literal preOptimizedQuery = as(preOptimizedEval.fields().get(0).child(), Literal.class); - assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); - assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embeddingList(query))); + assertThat(preOptimized, notNullValue()); + assertThat(preOptimized.preOptimized(), equalTo(true)); + return preOptimized; } - - private void testKnnFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) throws Exception { - String inferenceId = randomUUID(); - String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); - - LogicalPlanPreOptimizer preOptimizer = preOptimizer(textEmbeddingModel); - EsRelation relation = relation(); - Filter filter = new Filter( - Source.EMPTY, - relation, - new Knn(Source.EMPTY, getFieldAttribute("a"), new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)), of(10), null) - ); - Knn knn = as(filter.condition(), Knn.class); - - Filter preOptimizedFilter = as(preOptimizedPlan(preOptimizer, filter), Filter.class); - Knn preOptimizedKnn = as(preOptimizedFilter.condition(), Knn.class); - assertThat(preOptimizedKnn.field(), equalTo(knn.field())); - assertThat(preOptimizedKnn.k(), equalTo(knn.k())); - assertThat(preOptimizedKnn.options(), equalTo(knn.options())); - - Literal preOptimizedQuery = as(preOptimizedKnn.query(), Literal.class); - assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); - assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embeddingList(query))); - } - - private static LogicalPlanPreOptimizer preOptimizer() { - return preOptimizer(randomFrom(FLOAT_EMBEDDING_MODEL, BYTES_EMBEDDING_MODEL, BIT_EMBEDDING_MODEL)); - } - - private static LogicalPlanPreOptimizer preOptimizer(TextEmbeddingModelMock textEmbeddingModel) { - return preOptimizer(mockInferenceRunner(textEmbeddingModel)); - } - - public static LogicalPlanPreOptimizer 
preOptimizer(InferenceRunner inferenceRunner) { - LogicalPreOptimizerContext preOptimizerContext = new LogicalPreOptimizerContext(FoldContext.small()); - return new LogicalPlanPreOptimizer(mockTransportActionServices(inferenceRunner), preOptimizerContext); - } - - private LogicalPlan randomPlan() { - LogicalPlan plan = EsqlTestUtils.relation(); - int numCommands = between(0, 100); - - for (int i = 0; i < numCommands; i++) { - plan = switch (randomInt(3)) { - case 0 -> new Eval(Source.EMPTY, plan, List.of(new Alias(Source.EMPTY, randomIdentifier(), randomExpression()))); - case 1 -> new Limit(Source.EMPTY, of(randomInt()), plan); - case 2 -> new Filter(Source.EMPTY, plan, randomCondition()); - default -> new Project(Source.EMPTY, plan, List.of(new Alias(Source.EMPTY, randomIdentifier(), fieldAttribute()))); - }; - } - return plan; - } - - private Expression randomExpression() { - return switch (randomInt(4)) { - case 0 -> of(randomInt()); - case 1 -> of(randomIdentifier()); - case 2 -> new Add(Source.EMPTY, of(randomInt()), of(randomDouble())); - case 3 -> new TextEmbedding(Source.EMPTY, of(randomIdentifier()), of(randomIdentifier())); - default -> new Concat(Source.EMPTY, of(randomIdentifier()), randomList(1, 10, () -> of(randomIdentifier()))); - }; - } - - private Expression randomCondition() { - if (randomBoolean()) { - return EsqlTestUtils.equalsOf(randomExpression(), randomExpression()); - } - - return EsqlTestUtils.greaterThanOf(randomExpression(), randomExpression()); - } - - private static TransportActionServices mockTransportActionServices(InferenceRunner inferenceRunner) { - TransportActionServices services = mock(TransportActionServices.class); - when(services.inferenceRunner()).thenReturn(inferenceRunner); - return services; - } - - private static InferenceRunner mockInferenceRunner(TextEmbeddingModelMock textEmbeddingModel) { - return new InferenceRunner() { - @Override - public void execute(InferenceAction.Request request, ActionListener listener) { - listener.onResponse(new InferenceAction.Response(textEmbeddingModel.embeddingResults(request.getInput().getFirst()))); - } - - @Override - public void executeBulk(BulkInferenceRequestIterator requests, ActionListener> listener) { - listener.onFailure( - new UnsupportedOperationException("executeBulk should not be invoked for plans without inference functions") - ); - } - }; - } - - private interface TextEmbeddingModelMock { - TextEmbeddingResults embeddingResults(String input); - - float[] embedding(String input); - - default List embeddingList(String input) { - float[] embedding = embedding(input); - List embeddingList = new ArrayList<>(embedding.length); - for (float value : embedding) { - embeddingList.add(value); - } - return embeddingList; - } - } - - private static final TextEmbeddingModelMock FLOAT_EMBEDDING_MODEL = new TextEmbeddingModelMock() { - public TextEmbeddingResults embeddingResults(String input) { - TextEmbeddingFloatResults.Embedding embedding = new TextEmbeddingFloatResults.Embedding(embedding(input)); - return new TextEmbeddingFloatResults(List.of(embedding)); - } - - public float[] embedding(String input) { - String[] tokens = input.split("\\s+"); - float[] embedding = new float[tokens.length]; - for (int i = 0; i < tokens.length; i++) { - embedding[i] = tokens[i].length(); - } - return embedding; - } - }; - - private static final TextEmbeddingModelMock BYTES_EMBEDDING_MODEL = new TextEmbeddingModelMock() { - public TextEmbeddingResults embeddingResults(String input) { - TextEmbeddingByteResults.Embedding 
embedding = new TextEmbeddingByteResults.Embedding(bytes(input)); - return new TextEmbeddingBitResults(List.of(embedding)); - } - - private byte[] bytes(String input) { - return input.getBytes(StandardCharsets.UTF_8); - } - - public float[] embedding(String input) { - return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray(); - } - }; - - private static final TextEmbeddingModelMock BIT_EMBEDDING_MODEL = new TextEmbeddingModelMock() { - public TextEmbeddingResults embeddingResults(String input) { - TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input)); - return new TextEmbeddingBitResults(List.of(embedding)); - } - - private byte[] bytes(String input) { - String[] tokens = input.split("\\s+"); - byte[] embedding = new byte[tokens.length]; - for (int i = 0; i < tokens.length; i++) { - embedding[i] = (byte) (tokens[i].length() % 2); - } - return embedding; - } - - public float[] embedding(String input) { - return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray(); - } - }; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFoldingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFoldingTests.java new file mode 100644 index 0000000000000..940e613a3154d --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFoldingTests.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.optimizer.AbstractLogicalPlanPreOptimizerTests; +import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.util.Arrays; +import java.util.List; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class InferenceFunctionConstantFoldingTests extends AbstractLogicalPlanPreOptimizerTests { + + private final TestEmbeddingModel embeddingModel; + + public InferenceFunctionConstantFoldingTests(TestEmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + } + + @ParametersFactory(argumentFormatting = "textEmbeddingType=%1$s") + public static Iterable parameters() { + return Arrays.stream(TestEmbeddingModel.values()).map(textEmbeddingModel -> new Object[] { textEmbeddingModel }).toList(); + } + + /** + * Tests that the rule correctly evaluates TEXT_EMBEDDING functions in Eval nodes. + */ + public void testEvalFunctionEmbedding() throws Exception { + // Setup: Create a plan with an Eval node containing a TEXT_EMBEDDING function + String inferenceId = randomUUID(); + String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); + String fieldName = randomIdentifier(); + + EsRelation relation = relation(); + Eval eval = new Eval( + Source.EMPTY, + relation, + List.of(new Alias(Source.EMPTY, fieldName, new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)))) + ); + eval.setAnalyzed(); + + Eval preOptimizedEval = as(inferenceFunctionConstantFolding(eval), Eval.class); + + assertThat(preOptimizedEval.fields(), hasSize(1)); + assertThat(preOptimizedEval.fields().get(0).name(), equalTo(fieldName)); + Literal preOptimizedQuery = as(preOptimizedEval.fields().get(0).child(), Literal.class); + assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); + assertThat(preOptimizedQuery.value(), equalTo(embedding(embeddingModel, query))); + } + + /** + * Tests that the rule correctly evaluates TEXT_EMBEDDING functions as KNN query. 
+ */ + public void testKnnFunctionEmbedding() throws Exception { + // Setup: Create a plan with a Filter node containing a KNN predicate with a TEXT_EMBEDDING function + String inferenceId = randomUUID(); + String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); + + EsRelation relation = relation(); + Filter filter = new Filter( + Source.EMPTY, + relation, + new Knn(Source.EMPTY, getFieldAttribute("a"), new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)), of(10), null) + ); + Knn knn = as(filter.condition(), Knn.class); + + Filter preOptimizedFilter = as(inferenceFunctionConstantFolding(filter), Filter.class); + + Knn preOptimizedKnn = as(preOptimizedFilter.condition(), Knn.class); + assertThat(preOptimizedKnn.field(), equalTo(knn.field())); + assertThat(preOptimizedKnn.k(), equalTo(knn.k())); + assertThat(preOptimizedKnn.options(), equalTo(knn.options())); + + Literal preOptimizedQuery = as(preOptimizedKnn.query(), Literal.class); + assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); + assertThat(preOptimizedQuery.value(), equalTo(embedding(embeddingModel, query))); + } + + /** + * Helper method to apply the InferenceFunctionConstantFolding rule to a logical plan. + */ + private LogicalPlan inferenceFunctionConstantFolding(LogicalPlan plan) { + PlainActionFuture preOptimized = new PlainActionFuture<>(); + new InferenceFunctionConstantFolding(mockedInferenceRunner(embeddingModel), FoldContext.small()).apply(plan, preOptimized); + return preOptimized.actionGet(); + } +} From 3824f4fdfb9cd4de405da6309909a28474c3c4fe Mon Sep 17 00:00:00 2001 From: afoucret Date: Fri, 18 Jul 2025 15:09:59 +0200 Subject: [PATCH 27/31] Checkstyle! --- .../esql/optimizer/AbstractLogicalPlanPreOptimizerTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanPreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanPreOptimizerTests.java index 63bb7c50ad09d..119aa3cf67779 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanPreOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanPreOptimizerTests.java @@ -100,7 +100,7 @@ default List embeddingList(String input) { /** * Map of embedding model implementations by type. 
*/ - private final static Map TEST_EMBEDDING_MODELS = Map.ofEntries( + private static final Map TEST_EMBEDDING_MODELS = Map.ofEntries( // Float embedding model implementation Map.entry(TestEmbeddingModel.FLOAT_EMBEDDING_MODEL, new TextEmbeddingModelMock() { @Override From 70819e10bb2653a6efbf88d69d20c6cc0cad10ef Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 24 Jul 2025 19:52:56 +0200 Subject: [PATCH 28/31] Huge simplification of the bul inference runner & operator --- .../esql/inference/InferenceFunctionEvaluator.java | 3 ++- .../TextEmbeddingFunctionEvaluator.java | 10 +++++----- .../esql/optimizer/LogicalPlanPreOptimizer.java | 2 +- .../InferenceFunctionConstantFolding.java | 14 +++++++------- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java index 38db294e5ddf5..baf7e85e0f220 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java @@ -10,12 +10,13 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; public interface InferenceFunctionEvaluator { void eval(FoldContext foldContext, ActionListener listener); interface Factory { - InferenceFunctionEvaluator get(InferenceRunner inferenceRunner); + InferenceFunctionEvaluator get(BulkInferenceRunner inferenceRunner); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java index a7a054fc81c15..2740cd80d9067 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java @@ -21,20 +21,20 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; import java.util.ArrayList; import java.util.List; public class TextEmbeddingFunctionEvaluator implements InferenceFunctionEvaluator { - private final InferenceRunner inferenceRunner; + private final BulkInferenceRunner bulkInferenceRunner; private final TextEmbedding f; - public TextEmbeddingFunctionEvaluator(TextEmbedding f, InferenceRunner inferenceRunner) { + public TextEmbeddingFunctionEvaluator(TextEmbedding f, BulkInferenceRunner bulkInferenceRunner) { this.f = f; - this.inferenceRunner = inferenceRunner; + this.bulkInferenceRunner = bulkInferenceRunner; } @Override @@ -45,7 +45,7 @@ public void eval(FoldContext foldContext, ActionListener listener) { String inferenceId = BytesRefs.toString(f.inferenceId().fold(foldContext)); String inputText = BytesRefs.toString(f.inputText().fold(foldContext)); - 
inferenceRunner.execute(inferenceRequest(inferenceId, inputText), listener.map(this::parseInferenceResponse)); + //bulkInferenceRunner.executeBulk(inferenceRequest(inferenceId, inputText), listener.map(this::parseInferenceResponse)); } private InferenceAction.Request inferenceRequest(String inferenceId, String inputText) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java index 2a9057957f06a..c0ee08f830d72 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java @@ -28,7 +28,7 @@ public class LogicalPlanPreOptimizer { private final List rules; public LogicalPlanPreOptimizer(TransportActionServices services, LogicalPreOptimizerContext preOptimizerContext) { - rules = List.of(new InferenceFunctionConstantFolding(services.inferenceRunner(), preOptimizerContext.foldCtx())); + rules = List.of(new InferenceFunctionConstantFolding(services.bulkInferenceRunner(), preOptimizerContext.foldCtx())); } /** diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFolding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFolding.java index 3dc788fb8e955..531a1493d7831 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFolding.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFolding.java @@ -12,7 +12,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; -import org.elasticsearch.xpack.esql.inference.InferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.ArrayList; @@ -31,17 +31,17 @@ * appear after the first round of folding. */ public class InferenceFunctionConstantFolding implements PreOptimizerRule { - private final InferenceRunner inferenceRunner; + private final BulkInferenceRunner bulkInferenceRunner; private final FoldContext foldContext; /** * Creates a new instance of the InferenceFunctionConstantFolding rule. 
* - * @param inferenceRunner the inference runner to use for evaluating inference functions - * @param foldContext the fold context to use for evaluating inference functions + * @param bulkInferenceRunner the inference runner to use for evaluating inference functions + * @param foldContext the fold context to use for evaluating inference functions */ - public InferenceFunctionConstantFolding(InferenceRunner inferenceRunner, FoldContext foldContext) { - this.inferenceRunner = inferenceRunner; + public InferenceFunctionConstantFolding(BulkInferenceRunner bulkInferenceRunner, FoldContext foldContext) { + this.bulkInferenceRunner = bulkInferenceRunner; this.foldContext = foldContext; } @@ -139,6 +139,6 @@ private List> collectFoldableInferenceFunctions(LogicalPlan * @param listener the listener to notify when the evaluation is complete */ private void foldInferenceFunction(InferenceFunction inferenceFunction, ActionListener listener) { - inferenceFunction.inferenceEvaluatorFactory().get(inferenceRunner).eval(foldContext, listener); + inferenceFunction.inferenceEvaluatorFactory().get(bulkInferenceRunner).eval(foldContext, listener); } } From c09800a803b2b7474e3001a91d1eaccef6113027 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 24 Jul 2025 19:59:39 +0200 Subject: [PATCH 29/31] Lint --- .../TextEmbeddingFunctionEvaluator.java | 36 +++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java index 2740cd80d9067..dcc7f00ad6916 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingFunctionEvaluator.java @@ -21,9 +21,11 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; public class TextEmbeddingFunctionEvaluator implements InferenceFunctionEvaluator { @@ -42,18 +44,40 @@ public void eval(FoldContext foldContext, ActionListener listener) { assert f.inferenceId() != null && f.inferenceId().foldable() : "inferenceId should not be null and be foldable"; assert f.inputText() != null && f.inputText().foldable() : "inputText should not be null and be foldable"; - String inferenceId = BytesRefs.toString(f.inferenceId().fold(foldContext)); - String inputText = BytesRefs.toString(f.inputText().fold(foldContext)); + final String inferenceId = BytesRefs.toString(f.inferenceId().fold(foldContext)); + final String inputText = BytesRefs.toString(f.inputText().fold(foldContext)); - //bulkInferenceRunner.executeBulk(inferenceRequest(inferenceId, inputText), listener.map(this::parseInferenceResponse)); + bulkInferenceRunner.executeBulk(new BulkInferenceRequestIterator() { + private final Iterator it = List.of(inferenceRequest(inferenceId, inputText)).iterator(); + + @Override + public void close() { + + } + + @Override + public boolean hasNext() { + return it.hasNext(); + } + + 
@Override + public InferenceAction.Request next() { + return it.next(); + } + + @Override + public int estimatedSize() { + return 1; + } + }, listener.map(this::parseInferenceResponse)); } - private InferenceAction.Request inferenceRequest(String inferenceId, String inputText) { + private static InferenceAction.Request inferenceRequest(String inferenceId, String inputText) { return InferenceAction.Request.builder(inferenceId, TaskType.TEXT_EMBEDDING).setInput(List.of(inputText)).build(); } - private Literal parseInferenceResponse(InferenceAction.Response response) { - if (response.getResults() instanceof TextEmbeddingResults textEmbeddingResults) { + private Literal parseInferenceResponse(List responses) { + if (responses.getFirst().getResults() instanceof TextEmbeddingResults textEmbeddingResults) { return parseInferenceResponse(textEmbeddingResults); } throw new IllegalArgumentException("Inference response should be of type TextEmbeddingResults"); From cab92a1b7e0482d36b1a9eb4e2098d3c09539a8e Mon Sep 17 00:00:00 2001 From: afoucret Date: Mon, 28 Jul 2025 14:44:52 +0200 Subject: [PATCH 30/31] Fix tests after rebasing. --- .../xpack/esql/analysis/Analyzer.java | 31 ++++++- .../esql/inference/InferenceResolver.java | 74 ++++++++++++++++- .../esql/inference/InferenceService.java | 5 +- .../optimizer/LogicalPlanPreOptimizer.java | 4 +- .../xpack/esql/session/EsqlSession.java | 2 +- .../xpack/esql/analysis/AnalyzerTests.java | 40 ++++----- .../inference/InferenceOperatorTestCase.java | 4 +- .../inference/InferenceResolverTests.java | 3 +- .../AbstractLogicalPlanPreOptimizerTests.java | 82 ++++++++++--------- .../optimizer/PhysicalPlanOptimizerTests.java | 1 - ...InferenceFunctionConstantFoldingTests.java | 2 +- .../planner/LocalExecutionPlannerTests.java | 1 - 12 files changed, 176 insertions(+), 73 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index a1a13b72c6ab1..f6ecef8316bd5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -66,6 +66,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; @@ -1311,7 +1312,35 @@ private static class ResolveInference extends ParameterizedRule resolveInferencePlan(p, context)); + return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context)) + .transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context)); + } + + private InferenceFunction resolveInferenceFunction(InferenceFunction inferenceFunction, AnalyzerContext context) { + assert inferenceFunction.inferenceId().resolved() && inferenceFunction.inferenceId().foldable(); + + String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small())); + ResolvedInference resolvedInference = 
context.inferenceResolution().getResolvedInference(inferenceId); + + if (resolvedInference == null) { + String error = context.inferenceResolution().getError(inferenceId); + return inferenceFunction.withInferenceResolutionError(inferenceId, error); + } + + if (resolvedInference.taskType() != inferenceFunction.taskType()) { + String error = "cannot use inference endpoint [" + + inferenceId + + "] with task type [" + + resolvedInference.taskType() + + "] within a " + + context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass()) + + " function. Only inference endpoints with the task type [" + + inferenceFunction.taskType() + + "] are supported."; + return inferenceFunction.withInferenceResolutionError(inferenceId, error); + } + + return inferenceFunction; } private LogicalPlan resolveInferencePlan(InferencePlan plan, AnalyzerContext context) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java index 6813dc845fcce..9870b7bb22722 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java @@ -16,6 +16,9 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; +import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition; +import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; @@ -31,13 +34,16 @@ public class InferenceResolver { private final Client client; + private final EsqlFunctionRegistry functionRegistry; + /** * Constructs a new {@code InferenceResolver}. * * @param client The Elasticsearch client for executing inference deployment lookups */ - public InferenceResolver(Client client) { + public InferenceResolver(Client client, EsqlFunctionRegistry functionRegistry) { this.client = client; + this.functionRegistry = functionRegistry; } /** @@ -72,6 +78,7 @@ public void resolveInferenceIds(LogicalPlan plan, ActionListener c) { collectInferenceIdsFromInferencePlans(plan, c); + collectInferenceIdsFromInferenceFunctions(plan, c); } /** @@ -131,6 +138,38 @@ private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer c.accept(inferenceId(inferencePlan))); } + /** + * Collects inference IDs from inference function calls within the logical plan. + *

+ * This method scans the logical plan for {@link UnresolvedFunction} instances that represent + * inference functions (e.g., TEXT_EMBEDDING). For each inference function found:
+ * <ol>
+ *   <li>Resolves the function definition through the registry and checks if the function implements {@link InferenceFunction}</li>
+ *   <li>Extracts the inference deployment ID from the function arguments</li>
+ * </ol>
+ * <p>
+ * This operates during pre-analysis when functions are still unresolved, allowing early + * validation of inference deployments before query optimization. + * + * @param plan The logical plan to scan for inference function calls + * @param c Consumer function to receive each discovered inference ID + */ + private void collectInferenceIdsFromInferenceFunctions(LogicalPlan plan, Consumer c) { + EsqlFunctionRegistry snapshotRegistry = functionRegistry.snapshotRegistry(); + plan.forEachExpressionUp(UnresolvedFunction.class, f -> { + String functionName = snapshotRegistry.resolveAlias(f.name()); + if (snapshotRegistry.functionExists(functionName)) { + FunctionDefinition def = snapshotRegistry.resolveFunction(functionName); + if (InferenceFunction.class.isAssignableFrom(def.clazz())) { + String inferenceId = inferenceId(f, def); + if (inferenceId != null) { + c.accept(inferenceId); + } + } + } + }); + } + /** * Extracts the inference ID from an InferencePlan object. * @@ -141,10 +180,43 @@ private static String inferenceId(InferencePlan plan) { return inferenceId(plan.inferenceId()); } + /** + * Extracts the inference ID from an Expression (expect the expression to be a constant). + */ private static String inferenceId(Expression e) { return BytesRefs.toString(e.fold(FoldContext.small())); } + /** + * Extracts the inference ID from an {@link UnresolvedFunction} instance. + *

+ * This method inspects the function's arguments to find the inference ID. + * Currently, it only supports positional parameters named "inference_id". + * + * @param f The unresolved function to extract the ID from + * @param def The function definition + * @return The inference ID as a string, or null if not found + */ + public String inferenceId(UnresolvedFunction f, FunctionDefinition def) { + EsqlFunctionRegistry.FunctionDescription functionDescription = EsqlFunctionRegistry.description(def); + + for (int i = 0; i < functionDescription.args().size(); i++) { + EsqlFunctionRegistry.ArgSignature arg = functionDescription.args().get(i); + + if (arg.name().equals(InferenceFunction.INFERENCE_ID_PARAMETER_NAME)) { + // Found a positional parameter named "inference_id", so use its value + Expression argValue = f.arguments().get(i); + if (argValue != null && argValue.foldable()) { + return inferenceId(argValue); + } + } + + // TODO: support inference ID as an optional named parameter + } + + return null; + } + public static Factory factory(Client client) { return new Factory(client); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java index 37c163beaecda..af6df9457a6eb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.inference; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig; @@ -35,8 +36,8 @@ private InferenceService(InferenceResolver.Factory inferenceResolverFactory, Bul * * @return a new inference resolver instance */ - public InferenceResolver inferenceResolver() { - return inferenceResolverFactory.create(); + public InferenceResolver inferenceResolver(EsqlFunctionRegistry functionRegistry) { + return inferenceResolverFactory.create(functionRegistry); } public BulkInferenceRunner bulkInferenceRunner() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java index c0ee08f830d72..81dfd59484a1f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java @@ -28,7 +28,9 @@ public class LogicalPlanPreOptimizer { private final List rules; public LogicalPlanPreOptimizer(TransportActionServices services, LogicalPreOptimizerContext preOptimizerContext) { - rules = List.of(new InferenceFunctionConstantFolding(services.bulkInferenceRunner(), preOptimizerContext.foldCtx())); + rules = List.of( + new InferenceFunctionConstantFolding(services.inferenceService().bulkInferenceRunner(), preOptimizerContext.foldCtx()) + ); } /** diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index 34ece5661c8d7..175f5bac87a0a 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -751,7 +751,7 @@ private static void analyzeAndMaybeRetry( } private void resolveInferences(LogicalPlan plan, PreAnalysisResult preAnalysisResult, ActionListener l) { - inferenceService.inferenceResolver().resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution)); + inferenceService.inferenceResolver(functionRegistry).resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution)); } private PhysicalPlan logicalPlanToPhysicalPlan(LogicalPlan optimizedPlan, EsqlQueryRequest request) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 17725a63aa777..a32d9f61122f2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -50,7 +50,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; -import org.elasticsearch.xpack.esql.expression.function.inference.EmbedText; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; @@ -3843,13 +3843,13 @@ public void testResolveCompletionOutputFieldOverwriteInputField() { public void testResolveEmbedTextInferenceId() { LogicalPlan plan = analyze(""" FROM books METADATA _score - | EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id") + | EVAL embedding = TEXT_EMBEDDING("description", "text-embedding-inference-id") """, "mapping-books.json"); var limit = as(plan, Limit.class); var eval = as(limit.child(), Eval.class); var embedTextAlias = as(eval.fields().get(0), Alias.class); - var embedText = as(embedTextAlias.child(), EmbedText.class); + var embedText = as(embedTextAlias.child(), TextEmbedding.class); assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); assertThat(embedText.inputText(), equalTo(string("description"))); @@ -3859,11 +3859,11 @@ public void testResolveEmbedTextInferenceIdInvalidTaskType() { assertError( """ FROM books METADATA _score - | EVAL embedding = EMBED_TEXT(description, "completion-inference-id") + | EVAL embedding = TEXT_EMBEDDING("blue", "completion-inference-id") """, "mapping-books.json", new QueryParams(), - "cannot use inference endpoint [completion-inference-id] with task type [completion] within a embed_text function." + "cannot use inference endpoint [completion-inference-id] with task type [completion] within a text_embedding function." 
+ " Only inference endpoints with the task type [text_embedding] are supported" ); } @@ -3871,28 +3871,28 @@ public void testResolveEmbedTextInferenceIdInvalidTaskType() { public void testResolveEmbedTextInferenceMissingInferenceId() { assertError(""" FROM books METADATA _score - | EVAL embedding = EMBED_TEXT(description, "unknown-inference-id") + | EVAL embedding = TEXT_EMBEDDING("blue", "unknown-inference-id") """, "mapping-books.json", new QueryParams(), "unresolved inference [unknown-inference-id]"); } public void testResolveEmbedTextInferenceIdResolutionError() { assertError(""" FROM books METADATA _score - | EVAL embedding = EMBED_TEXT(description, "error-inference-id") + | EVAL embedding = TEXT_EMBEDDING("blue", "error-inference-id") """, "mapping-books.json", new QueryParams(), "error with inference resolution"); } public void testResolveEmbedTextInNestedExpression() { LogicalPlan plan = analyze(""" FROM colors METADATA _score - | WHERE KNN(rgb_vector, EMBED_TEXT("blue", "text-embedding-inference-id"), 10) + | WHERE KNN(rgb_vector, TEXT_EMBEDDING("blue", "text-embedding-inference-id"), 10) """, "mapping-colors.json"); var limit = as(plan, Limit.class); var filter = as(limit.child(), Filter.class); - // Navigate to the EMBED_TEXT function within the KNN function - filter.condition().forEachDown(EmbedText.class, embedText -> { + // Navigate to the TEXT_EMBEDDING function within the KNN function + filter.condition().forEachDown(TextEmbedding.class, embedText -> { assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); assertThat(embedText.inputText(), equalTo(string("blue"))); }); @@ -3901,37 +3901,37 @@ public void testResolveEmbedTextInNestedExpression() { public void testResolveEmbedTextDataType() { LogicalPlan plan = analyze(""" FROM books METADATA _score - | EVAL embedding = EMBED_TEXT("description", "text-embedding-inference-id") + | EVAL embedding = TEXT_EMBEDDING("description", "text-embedding-inference-id") """, "mapping-books.json"); var limit = as(plan, Limit.class); var eval = as(limit.child(), Eval.class); var embedTextAlias = as(eval.fields().get(0), Alias.class); - var embedText = as(embedTextAlias.child(), EmbedText.class); + var embedText = as(embedTextAlias.child(), TextEmbedding.class); assertThat(embedText.dataType(), equalTo(DataType.DENSE_VECTOR)); } public void testResolveEmbedTextInvalidParameters() { assertError( - "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description, \"text-embedding-inference-id\")", + "FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(description, \"text-embedding-inference-id\")", "mapping-books.json", new QueryParams(), - "first argument of [EMBED_TEXT(description, \"text-embedding-inference-id\")] must be a constant, received [description]" + "first argument of [TEXT_EMBEDDING(description, \"text-embedding-inference-id\")] must be a constant, received [description]" ); assertError( - "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(description)", + "FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(description)", "mapping-books.json", new QueryParams(), - "error building [embed_text]: function [embed_text] expects exactly two arguments, it received 1", + "error building [text_embedding]: function [text_embedding] expects exactly two arguments, it received 1", ParsingException.class ); } public void testResolveEmbedTextWithPositionalQueryParams() { LogicalPlan plan = analyze( - "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?, ?)", + "FROM books METADATA 
_score| EVAL embedding = TEXT_EMBEDDING(?, ?)", "mapping-books.json", new QueryParams(List.of(paramAsConstant(null, "description"), paramAsConstant(null, "text-embedding-inference-id"))) ); @@ -3939,7 +3939,7 @@ public void testResolveEmbedTextWithPositionalQueryParams() { var limit = as(plan, Limit.class); var eval = as(limit.child(), Eval.class); var embedTextAlias = as(eval.fields().get(0), Alias.class); - var embedText = as(embedTextAlias.child(), EmbedText.class); + var embedText = as(embedTextAlias.child(), TextEmbedding.class); assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); assertThat(embedText.inputText(), equalTo(string("description"))); @@ -3947,7 +3947,7 @@ public void testResolveEmbedTextWithPositionalQueryParams() { public void testResolveEmbedTextWithNamedQueryParams() { LogicalPlan plan = analyze( - "FROM books METADATA _score| EVAL embedding = EMBED_TEXT(?inputText, ?inferenceId)", + "FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING(?inputText, ?inferenceId)", "mapping-books.json", new QueryParams( List.of(paramAsConstant("inputText", "description"), paramAsConstant("inferenceId", "text-embedding-inference-id")) @@ -3957,7 +3957,7 @@ public void testResolveEmbedTextWithNamedQueryParams() { var limit = as(plan, Limit.class); var eval = as(limit.child(), Eval.class); var embedTextAlias = as(eval.fields().get(0), Alias.class); - var embedText = as(embedTextAlias.child(), EmbedText.class); + var embedText = as(embedTextAlias.child(), TextEmbedding.class); assertThat(embedText.inferenceId(), equalTo(string("text-embedding-inference-id"))); assertThat(embedText.inputText(), equalTo(string("description"))); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java index e72eecccf5ab8..d68cc212b0137 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java @@ -36,7 +36,6 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.threadpool.FixedExecutorBuilder; -import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; @@ -56,8 +55,7 @@ public abstract class InferenceOperatorTestCase listener) { - try { - runWithDelay( - () -> listener.onResponse( - new InferenceAction.Response( - TEST_EMBEDDING_MODELS.get(textEmbeddingModel).embeddingResults(request.getInput().getFirst()) - ) - ) - ); - } catch (Exception e) { - listener.onFailure(e); + @SuppressWarnings("unchecked") + protected void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + if (action instanceof InferenceAction && request instanceof InferenceAction.Request inferenceRequest) { + TextEmbeddingResults inferenceResult = TEST_EMBEDDING_MODELS.get(textEmbeddingModel) + .embeddingResults(inferenceRequest.getInput().getFirst()); + listener.onResponse((Response) new InferenceAction.Response(inferenceResult)); + return; + } - } - @Override - public void executeBulk(BulkInferenceRequestIterator requests, ActionListener> listener) { - listener.onFailure(new 
UnsupportedOperationException("executeBulk is not supported in this test")); + listener.onFailure(new UnsupportedOperationException("Unexpected action: " + action)); } }; + return new BulkInferenceRunner(mockClient, between(1, 10)); } /** @@ -244,9 +245,10 @@ public void executeBulk(BulkInferenceRequestIterator requests, ActionListener

  • exchangeSinkHandler.createExchangeSink(() -> {}), null, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFoldingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFoldingTests.java index 940e613a3154d..f34db2efae229 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFoldingTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFoldingTests.java @@ -105,7 +105,7 @@ public void testKnnFunctionEmbedding() throws Exception { */ private LogicalPlan inferenceFunctionConstantFolding(LogicalPlan plan) { PlainActionFuture preOptimized = new PlainActionFuture<>(); - new InferenceFunctionConstantFolding(mockedInferenceRunner(embeddingModel), FoldContext.small()).apply(plan, preOptimized); + new InferenceFunctionConstantFolding(mockBulkInferenceRunner(embeddingModel), FoldContext.small()).apply(plan, preOptimized); return preOptimized.actionGet(); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index 971c4719d2c99..b56f4a3a4898b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -318,7 +318,6 @@ private LocalExecutionPlanner planner() throws IOException { null, null, null, - null, esPhysicalOperationProviders(shardContexts), shardContexts ); From 71d591b2494ba7029069646dd0bb3362df9ed63d Mon Sep 17 00:00:00 2001 From: afoucret Date: Fri, 1 Aug 2025 14:37:54 +0200 Subject: [PATCH 31/31] Update from main. --- .../org/elasticsearch/xpack/esql/execution/PlanExecutor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java index 414e1f372ea3f..d42cd3f106997 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java @@ -87,7 +87,7 @@ public void esql( indexResolver, enrichPolicyResolver, preAnalyzer, - new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext)), + new LogicalPlanPreOptimizer(services, new LogicalPreOptimizerContext(foldContext)), functionRegistry, new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)), mapper,