From 59087fb9189471f06f237fdf636eb799dbe2cce0 Mon Sep 17 00:00:00 2001 From: afoucret Date: Tue, 9 Sep 2025 10:28:12 +0200 Subject: [PATCH 01/26] Add text embedding function definition. --- .../esql/images/functions/text_embedding.svg | 1 + .../definition/functions/text_embedding.json | 9 + .../kibana/docs/functions/text_embedding.md | 4 + .../xpack/esql/action/EsqlCapabilities.java | 5 + .../esql/expression/ExpressionWritables.java | 6 + .../function/EsqlFunctionRegistry.java | 4 +- .../function/inference/InferenceFunction.java | 41 +++++ .../inference/InferenceWritables.java | 36 ++++ .../function/inference/TextEmbedding.java | 164 ++++++++++++++++++ .../inference/TextEmbeddingErrorTests.java | 74 ++++++++ .../TextEmbeddingSerializationTests.java | 46 +++++ .../inference/TextEmbeddingTests.java | 72 ++++++++ 12 files changed, 461 insertions(+), 1 deletion(-) create mode 100644 docs/reference/query-languages/esql/images/functions/text_embedding.svg create mode 100644 docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json create mode 100644 docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceWritables.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingSerializationTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java diff --git a/docs/reference/query-languages/esql/images/functions/text_embedding.svg b/docs/reference/query-languages/esql/images/functions/text_embedding.svg new file mode 100644 index 0000000000000..dab58c5e5bda0 --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/text_embedding.svg @@ -0,0 +1 @@ +TEXT_EMBEDDING(text,inference_id) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json new file mode 100644 index 0000000000000..473572136ade7 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json @@ -0,0 +1,9 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "text_embedding", + "description" : "Generates dense vector embeddings for text using a specified inference emdpoint.", + "signatures" : [ ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md b/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md new file mode 100644 index 0000000000000..01cf89a5397f5 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md @@ -0,0 +1,4 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +### TEXT EMBEDDING +Generates dense vector embeddings for text using a specified inference emdpoint. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index a9dad8c322b12..07cf2f0bb8557 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1303,6 +1303,11 @@ public enum Cap { */ KNN_FUNCTION_V5(Build.current().isSnapshot()), + /** + * Support for the {@code TEXT_EMBEDDING} function for generating dense vector embeddings. + */ + TEXT_EMBEDDING_FUNCTION(Build.current().isSnapshot()), + /** * Support for the LIKE operator with a list of wildcards. */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index 20de89a53780d..29d6b27f22078 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextWritables; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceWritables; import org.elasticsearch.xpack.esql.expression.function.scalar.ScalarFunctionWritables; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromBase64; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble; @@ -118,6 +119,7 @@ public static List getNamedWriteables() { entries.addAll(fullText()); entries.addAll(unaryScalars()); entries.addAll(vector()); + entries.addAll(inference()); return entries; } @@ -260,4 +262,8 @@ private static List fullText() { private static List vector() { return VectorWritables.getNamedWritables(); } + + private static List inference() { + return InferenceWritables.getNamedWritables(); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index f4d20dcafd1a0..2ce93a2ee54c0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -193,6 +193,7 @@ import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm; import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm; import org.elasticsearch.xpack.esql.expression.function.vector.Magnitude; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.session.Configuration; @@ -519,7 +520,8 @@ private static FunctionDefinition[][] snapshotFunctions() { def(Magnitude.class, Magnitude::new, "v_magnitude"), def(Hamming.class, Hamming::new, "v_hamming"), def(UrlEncode.class, UrlEncode::new, "url_encode"), - def(UrlDecode.class, 
UrlDecode::new, "url_decode") } }; + def(UrlDecode.class, UrlDecode::new, "url_decode"), + def(TextEmbedding.class, bi(TextEmbedding::new), "text_embedding") } }; } public EsqlFunctionRegistry snapshotRegistry() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java new file mode 100644 index 0000000000000..d2d6d9b6e2af7 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.List; + +/** + * Base class for ESQL functions that use inference endpoints (e.g., TEXT_EMBEDDING). + */ +public abstract class InferenceFunction> extends Function { + + public static final String INFERENCE_ID_PARAMETER_NAME = "inference_id"; + + protected InferenceFunction(Source source, List children) { + super(source, children); + } + + /** The inference endpoint identifier expression. */ + public abstract Expression inferenceId(); + + /** The task type required by this function (e.g., TEXT_EMBEDDING). */ + public abstract TaskType taskType(); + + /** Returns a copy with inference resolution error for display to user. */ + public abstract PlanType withInferenceResolutionError(String inferenceId, String error); + + /** True if this function contains nested inference function calls. */ + public boolean hasNestedInferenceFunction() { + return anyMatch(e -> e instanceof InferenceFunction && e != this); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceWritables.java new file mode 100644 index 0000000000000..9809ef0d46b66 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceWritables.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Defines the named writables for inference functions in ESQL. 
+ */ +public final class InferenceWritables { + + private InferenceWritables() { + // Utility class + throw new UnsupportedOperationException(); + } + + public static List getNamedWritables() { + List entries = new ArrayList<>(); + + if (EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()) { + entries.add(TextEmbedding.ENTRY); + } + + return Collections.unmodifiableList(entries); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java new file mode 100644 index 0000000000000..a5ef509df1dff --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java @@ -0,0 +1,164 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; + +/** + * TEXT_EMBEDDING function converts text to dense vector embeddings using an inference endpoint. 
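+ *
+ * <p>Usage sketch (the endpoint id below is illustrative, not a bundled endpoint):
+ * {@code FROM books | EVAL embedding = TEXT_EMBEDDING("italian food recipe", "my-text-embedding-endpoint")}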
+ */ +public class TextEmbedding extends InferenceFunction { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "TextEmbedding", + TextEmbedding::new + ); + + private final Expression inferenceId; + private final Expression inputText; + + @FunctionInfo( + returnType = "dense_vector", + description = "Generates dense vector embeddings for text using a specified inference endpoint.", + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }, + preview = true + ) + public TextEmbedding( + Source source, + @Param(name = "text", type = { "keyword", "text" }, description = "Text to embed") Expression inputText, + @Param( + name = InferenceFunction.INFERENCE_ID_PARAMETER_NAME, + type = { "keyword", "text" }, + description = "Identifier of the inference endpoint" + ) Expression inferenceId + ) { + super(source, List.of(inputText, inferenceId)); + this.inferenceId = inferenceId; + this.inputText = inputText; + } + + private TextEmbedding(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(inputText); + out.writeNamedWriteable(inferenceId); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + public Expression inputText() { + return inputText; + } + + @Override + public Expression inferenceId() { + return inferenceId; + } + + @Override + public boolean foldable() { + return inferenceId.foldable() && inputText.foldable(); + } + + @Override + public DataType dataType() { + return DataType.DENSE_VECTOR; + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + TypeResolution textResolution = isNotNull(inputText, sourceText(), FIRST).and(isFoldable(inputText, sourceText(), FIRST)) + .and(isString(inputText, sourceText(), FIRST)); + + if (textResolution.unresolved()) { + return textResolution; + } + + TypeResolution inferenceIdResolution = isNotNull(inferenceId, sourceText(), SECOND).and(isString(inferenceId, sourceText(), SECOND)) + .and(isFoldable(inferenceId, sourceText(), SECOND)); + + if (inferenceIdResolution.unresolved()) { + return inferenceIdResolution; + } + + return TypeResolution.TYPE_RESOLVED; + } + + @Override + public TaskType taskType() { + return TaskType.TEXT_EMBEDDING; + } + + @Override + public TextEmbedding withInferenceResolutionError(String inferenceId, String error) { + return new TextEmbedding(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error)); + } + + @Override + public Expression replaceChildren(List newChildren) { + return new TextEmbedding(source(), newChildren.get(0), newChildren.get(1)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, TextEmbedding::new, inputText, inferenceId); + } + + @Override + public String toString() { + return "TEXT_EMBEDDING(" + inputText + ", " + inferenceId + ")"; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + TextEmbedding textEmbedding = (TextEmbedding) o; + return Objects.equals(inferenceId, textEmbedding.inferenceId) && Objects.equals(inputText, textEmbedding.inputText); + } + + @Override + public 
int hashCode() { + return Objects.hash(super.hashCode(), inferenceId, inputText); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java new file mode 100644 index 0000000000000..9af017bd5207f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingErrorTests.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; +import org.junit.Before; + +import java.util.List; +import java.util.Locale; +import java.util.Set; + +import static org.hamcrest.Matchers.equalTo; + +/** Tests error conditions and type validation for TEXT_EMBEDDING function. */ +public class TextEmbeddingErrorTests extends ErrorsForCasesWithoutExamplesTestCase { + + @Before + public void checkCapability() { + assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + } + + @Override + protected List cases() { + return paramsToSuppliers(TextEmbeddingTests.parameters()); + } + + @Override + protected Expression build(Source source, List args) { + return new TextEmbedding(source, args.get(0), args.get(1)); + } + + @Override + protected Matcher expectedTypeErrorMatcher(List> validPerPosition, List signature) { + return equalTo(typeErrorMessage(true, validPerPosition, signature, (v, p) -> "string")); + } + + protected static String typeErrorMessage( + boolean includeOrdinal, + List> validPerPosition, + List signature, + AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier + ) { + for (int i = 0; i < signature.size(); i++) { + if (signature.get(i) == DataType.NULL) { + String ordinal = includeOrdinal ? 
TypeResolutions.ParamOrdinal.fromIndex(i).name().toLowerCase(Locale.ROOT) + " " : ""; + return ordinal + "argument of [" + sourceForSignature(signature) + "] cannot be null, received []"; + } + + if (validPerPosition.get(i).contains(signature.get(i)) == false) { + break; + } + } + + return ErrorsForCasesWithoutExamplesTestCase.typeErrorMessage( + includeOrdinal, + validPerPosition, + signature, + positionalErrorMessageSupplier + ); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingSerializationTests.java new file mode 100644 index 0000000000000..5d7e1dfa4301a --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingSerializationTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.junit.Before; + +import java.io.IOException; + +/** Tests serialization/deserialization of TEXT_EMBEDDING function instances. */ +public class TextEmbeddingSerializationTests extends AbstractExpressionSerializationTests { + + @Before + public void checkCapability() { + assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + } + + @Override + protected TextEmbedding createTestInstance() { + Source source = randomSource(); + Expression inputText = randomChild(); + Expression inferenceId = randomChild(); + return new TextEmbedding(source, inputText, inferenceId); + } + + @Override + protected TextEmbedding mutateInstance(TextEmbedding instance) throws IOException { + Source source = instance.source(); + Expression inputText = instance.inputText(); + Expression inferenceId = instance.inferenceId(); + if (randomBoolean()) { + inputText = randomValueOtherThan(inputText, AbstractExpressionSerializationTests::randomChild); + } else { + inferenceId = randomValueOtherThan(inferenceId, AbstractExpressionSerializationTests::randomChild); + } + return new TextEmbedding(source, inputText, inferenceId); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java new file mode 100644 index 0000000000000..ef5a570278c3b --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.inference; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.FunctionName; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matchers; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.hamcrest.Matchers.equalTo; + +@FunctionName("text_embedding") +public class TextEmbeddingTests extends AbstractFunctionTestCase { + @Before + public void checkCapability() { + assumeTrue("TEXT_EMBEDDING is not enabled", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + } + + public TextEmbeddingTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List suppliers = new ArrayList<>(); + + // Test all string type combinations for text input and inference endpoint ID + for (DataType inputTextDataType : DataType.stringTypes()) { + for (DataType inferenceIdDataType : DataType.stringTypes()) { + suppliers.add( + new TestCaseSupplier( + List.of(inputTextDataType, inferenceIdDataType), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), inputTextDataType, "inputText"), + new TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), inferenceIdDataType, "inference_id") + ), + Matchers.blankOrNullString(), + DENSE_VECTOR, + equalTo(true) + ) + ) + ); + } + } + + return parameterSuppliersFromTypedData(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new TextEmbedding(source, args.get(0), args.get(1)); + } +} From 8f4c409c75e1b0e5823c1c9f4f715a50c9773efb Mon Sep 17 00:00:00 2001 From: afoucret Date: Tue, 9 Sep 2025 10:53:43 +0200 Subject: [PATCH 02/26] Fix tests. --- .../esql/kibana/definition/functions/text_embedding.json | 2 +- .../esql/kibana/docs/functions/text_embedding.md | 2 +- .../src/test/java/org/elasticsearch/xpack/esql/CsvTests.java | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json index 473572136ade7..9e4967b92c367 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/text_embedding.json @@ -2,7 +2,7 @@ "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. 
See ../README.md for how to regenerate it.", "type" : "scalar", "name" : "text_embedding", - "description" : "Generates dense vector embeddings for text using a specified inference emdpoint.", + "description" : "Generates dense vector embeddings for text using a specified inference endpoint.", "signatures" : [ ], "preview" : true, "snapshot_only" : true diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md b/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md index 01cf89a5397f5..bb3e74fc116cd 100644 --- a/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md +++ b/docs/reference/query-languages/esql/kibana/docs/functions/text_embedding.md @@ -1,4 +1,4 @@ % This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. ### TEXT EMBEDDING -Generates dense vector embeddings for text using a specified inference emdpoint. +Generates dense vector embeddings for text using a specified inference endpoint. diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 2ff6ce71be516..c635a9c9d5c00 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -303,6 +303,10 @@ public final void test() throws Throwable { "can't use KQL function in csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KQL_FUNCTION.capabilityName()) ); + assumeFalse( + "can't use TEXT_EMBEDDING function in csv tests", + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.capabilityName()) + ); assumeFalse( "can't use KNN function in csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V5.capabilityName()) From 847c99831638746123e04bb70853d7719e12dc5c Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 11 Sep 2025 19:41:26 +0200 Subject: [PATCH 03/26] InferenceResolution for text embedding function. 
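
The resolver now also collects inference ids referenced by inference functions,
in addition to the ids used by inference commands, so the endpoints they name
can be verified during analysis. A sketch of the query shape this covers,
mirroring the new resolver test (the endpoint id is illustrative):

    FROM books METADATA _score
    | EVAL embedding = TEXT_EMBEDDING("description", "text-embedding-inference-id")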
--- .../function/EsqlFunctionRegistry.java | 2 +- .../esql/inference/InferenceResolver.java | 53 +++++++++++++++++-- .../esql/inference/InferenceService.java | 7 ++- .../xpack/esql/session/EsqlSession.java | 2 +- .../inference/InferenceResolverTests.java | 25 +++++++-- 5 files changed, 79 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 2ce93a2ee54c0..337ff8d6e5891 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -56,6 +56,7 @@ import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least; @@ -193,7 +194,6 @@ import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm; import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm; import org.elasticsearch.xpack.esql.expression.function.vector.Magnitude; -import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.session.Configuration; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java index f7d349281e004..c48b797ed67e0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceResolver.java @@ -15,6 +15,11 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; +import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition; +import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan; @@ -29,14 +34,16 @@ public class InferenceResolver { private final Client client; + private final EsqlFunctionRegistry functionRegistry; /** * Constructs a new {@code InferenceResolver}. 
* * @param client The Elasticsearch client for executing inference deployment lookups */ - public InferenceResolver(Client client) { + public InferenceResolver(Client client, EsqlFunctionRegistry functionRegistry) { this.client = client; + this.functionRegistry = functionRegistry; } /** @@ -71,6 +78,7 @@ public void resolveInferenceIds(LogicalPlan plan, ActionListener c) { collectInferenceIdsFromInferencePlans(plan, c); + collectInferenceIdsFromInferenceFunctions(plan, c); } /** @@ -130,6 +138,28 @@ private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer c.accept(inferenceId(inferencePlan))); } + /** + * Collects inference IDs from function expressions within the logical plan. + * + * @param plan The logical plan to scan for function expressions + * @param c Consumer function to receive each discovered inference ID + */ + private void collectInferenceIdsFromInferenceFunctions(LogicalPlan plan, Consumer c) { + EsqlFunctionRegistry snapshotRegistry = functionRegistry.snapshotRegistry(); + plan.forEachExpressionUp(UnresolvedFunction.class, f -> { + String functionName = snapshotRegistry.resolveAlias(f.name()); + if (snapshotRegistry.functionExists(functionName)) { + FunctionDefinition def = snapshotRegistry.resolveFunction(functionName); + if (InferenceFunction.class.isAssignableFrom(def.clazz())) { + String inferenceId = inferenceId(f, def); + if (inferenceId != null) { + c.accept(inferenceId); + } + } + } + }); + } + /** * Extracts the inference ID from an InferencePlan object. * @@ -144,6 +174,23 @@ private static String inferenceId(Expression e) { return BytesRefs.toString(e.fold(FoldContext.small())); } + public String inferenceId(UnresolvedFunction f, FunctionDefinition def) { + EsqlFunctionRegistry.FunctionDescription functionDescription = EsqlFunctionRegistry.description(def); + + for (int i = 0; i < functionDescription.args().size(); i++) { + EsqlFunctionRegistry.ArgSignature arg = functionDescription.args().get(i); + + if (arg.name().equals(InferenceFunction.INFERENCE_ID_PARAMETER_NAME)) { + Expression argValue = f.arguments().get(i); + if (argValue != null && argValue.foldable() && DataType.isString(argValue.dataType())) { + return inferenceId(argValue); + } + } + } + + return null; + } + public static Factory factory(Client client) { return new Factory(client); } @@ -155,8 +202,8 @@ private Factory(Client client) { this.client = client; } - public InferenceResolver create() { - return new InferenceResolver(client); + public InferenceResolver create(EsqlFunctionRegistry functionRegistry) { + return new InferenceResolver(client, functionRegistry); } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java index 37c163beaecda..630477a20f447 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceService.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.inference; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig; @@ -33,10 +34,12 @@ private InferenceService(InferenceResolver.Factory inferenceResolverFactory, Bul /** * Creates an inference resolver 
for resolving inference IDs in logical plans. * + * @param functionRegistry the function registry to resolve functions + * * @return a new inference resolver instance */ - public InferenceResolver inferenceResolver() { - return inferenceResolverFactory.create(); + public InferenceResolver inferenceResolver(EsqlFunctionRegistry functionRegistry) { + return inferenceResolverFactory.create(functionRegistry); } public BulkInferenceRunner bulkInferenceRunner() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index a8078feec4f8b..42d5958ddfe61 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -731,7 +731,7 @@ private void analyzeWithRetry( } private void resolveInferences(LogicalPlan plan, PreAnalysisResult preAnalysisResult, ActionListener l) { - inferenceService.inferenceResolver().resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution)); + inferenceService.inferenceResolver(functionRegistry).resolveInferenceIds(plan, l.map(preAnalysisResult::withInferenceResolution)); } private PhysicalPlan logicalPlanToPhysicalPlan(LogicalPlan optimizedPlan, EsqlQueryRequest request) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java index 8666eedbaeaaa..67147ac1896dd 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceResolverTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.junit.After; @@ -44,6 +45,7 @@ public class InferenceResolverTests extends ESTestCase { private TestThreadPool threadPool; + private EsqlFunctionRegistry functionRegistry; @Before public void setThreadPool() { @@ -60,6 +62,11 @@ public void setThreadPool() { ); } + @Before + public void setUpFunctionRegistry() { + functionRegistry = new EsqlFunctionRegistry(); + } + @After public void shutdownThreadPool() { terminate(threadPool); @@ -78,6 +85,18 @@ public void testCollectInferenceIds() { List.of("completion-inference-id") ); + // Test inference ID collection from an inference function + assertCollectInferenceIds( + "FROM books METADATA _score | EVAL embedding = TEXT_EMBEDDING(\"description\", \"text-embedding-inference-id\")", + List.of("text-embedding-inference-id") + ); + + // Test inference ID collection with nested functions + assertCollectInferenceIds( + "FROM books METADATA _score | EVAL embedding = TEXT_EMBEDDING(TEXT_EMBEDDING(\"nested\", \"nested-id\"), \"outer-id\")", + List.of("nested-id", "outer-id") + ); + // Multiple inference plans assertCollectInferenceIds(""" FROM books METADATA _score @@ -139,7 +158,7 @@ public void testResolveMultipleInferenceIds() throws Exception { public void testResolveMissingInferenceIds() throws Exception { InferenceResolver 
inferenceResolver = inferenceResolver(); - List inferenceIds = List.of("missing-plan"); + List inferenceIds = List.of("missing-inference-id"); SetOnce inferenceResolutionSetOnce = new SetOnce<>(); @@ -153,7 +172,7 @@ public void testResolveMissingInferenceIds() throws Exception { assertThat(inferenceResolution.resolvedInferences(), empty()); assertThat(inferenceResolution.hasError(), equalTo(true)); - assertThat(inferenceResolution.getError("missing-plan"), equalTo("inference endpoint not found")); + assertThat(inferenceResolution.getError("missing-inference-id"), equalTo("inference endpoint not found")); }); } @@ -205,7 +224,7 @@ private static ActionResponse getInferenceModelResponse(GetInferenceModelAction. } private InferenceResolver inferenceResolver() { - return new InferenceResolver(mockClient()); + return new InferenceResolver(mockClient(), functionRegistry); } private static ModelConfigurations mockModelConfig(String inferenceId, TaskType taskType) { From 4fe169ad19ea1db7b16044900ec415a7499a8ba7 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 11 Sep 2025 19:41:49 +0200 Subject: [PATCH 04/26] Text embedding analysis and verification. --- .../xpack/esql/analysis/Analyzer.java | 34 +++++- .../esql/analysis/AnalyzerTestUtils.java | 33 +++++- .../xpack/esql/analysis/AnalyzerTests.java | 112 ++++++++++++++++++ .../xpack/esql/analysis/VerifierTests.java | 48 ++++++++ 4 files changed, 223 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index df83feeac9f13..2f8cfa3511329 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -68,6 +68,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.SummationMode; import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; @@ -1329,7 +1330,8 @@ private static class ResolveInference extends ParameterizedRule resolveInferencePlan(p, context)); + return plan.transformDown(InferencePlan.class, p -> resolveInferencePlan(p, context)) + .transformExpressionsOnly(InferenceFunction.class, f -> resolveInferenceFunction(f, context)); } private LogicalPlan resolveInferencePlan(InferencePlan plan, AnalyzerContext context) { @@ -1358,6 +1360,36 @@ private LogicalPlan resolveInferencePlan(InferencePlan plan, AnalyzerContext return plan; } + + private InferenceFunction resolveInferenceFunction(InferenceFunction inferenceFunction, AnalyzerContext context) { + if (inferenceFunction.inferenceId().resolved() + && inferenceFunction.inferenceId().foldable() + && DataType.isString(inferenceFunction.inferenceId().dataType())) { + + String inferenceId = BytesRefs.toString(inferenceFunction.inferenceId().fold(FoldContext.small())); + ResolvedInference resolvedInference = context.inferenceResolution().getResolvedInference(inferenceId); + + if (resolvedInference == null) { + String error = 
context.inferenceResolution().getError(inferenceId); + return inferenceFunction.withInferenceResolutionError(inferenceId, error); + } + + if (resolvedInference.taskType() != inferenceFunction.taskType()) { + String error = "cannot use inference endpoint [" + + inferenceId + + "] with task type [" + + resolvedInference.taskType() + + "] within a " + + context.functionRegistry().snapshotRegistry().functionName(inferenceFunction.getClass()) + + " function. Only inference endpoints with the task type [" + + inferenceFunction.taskType() + + "] are supported."; + return inferenceFunction.withInferenceResolutionError(inferenceId, error); + } + } + + return inferenceFunction; + } } private static class AddImplicitLimit extends ParameterizedRule { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java index fbfa18dccc477..78a8ca5483246 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java @@ -9,6 +9,7 @@ import org.elasticsearch.index.IndexMode; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.enrich.EnrichPolicy; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.type.EsField; @@ -26,6 +27,7 @@ import org.elasticsearch.xpack.esql.session.Configuration; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -195,14 +197,39 @@ public static EnrichResolution defaultEnrichResolution() { return enrichResolution; } + public static final String RERANKING_INFERENCE_ID = "reranking-inference-id"; + public static final String COMPLETION_INFERENCE_ID = "completion-inference-id"; + public static final String TEXT_EMBEDDING_INFERENCE_ID = "text-embedding-inference-id"; + public static final String CHAT_COMPLETION_INFERENCE_ID = "chat-completion-inference-id"; + public static final String SPARSE_EMBEDDING_INFERENCE_ID = "sparse-embedding-inference-id"; + public static final List VALID_INFERENCE_IDS = List.of( + RERANKING_INFERENCE_ID, + COMPLETION_INFERENCE_ID, + TEXT_EMBEDDING_INFERENCE_ID, + CHAT_COMPLETION_INFERENCE_ID, + SPARSE_EMBEDDING_INFERENCE_ID + ); + public static final String ERROR_INFERENCE_ID = "error-inference-id"; + public static InferenceResolution defaultInferenceResolution() { return InferenceResolution.builder() - .withResolvedInference(new ResolvedInference("reranking-inference-id", TaskType.RERANK)) - .withResolvedInference(new ResolvedInference("completion-inference-id", TaskType.COMPLETION)) - .withError("error-inference-id", "error with inference resolution") + .withResolvedInference(new ResolvedInference(RERANKING_INFERENCE_ID, TaskType.RERANK)) + .withResolvedInference(new ResolvedInference(COMPLETION_INFERENCE_ID, TaskType.COMPLETION)) + .withResolvedInference(new ResolvedInference(TEXT_EMBEDDING_INFERENCE_ID, TaskType.TEXT_EMBEDDING)) + .withResolvedInference(new ResolvedInference(CHAT_COMPLETION_INFERENCE_ID, TaskType.CHAT_COMPLETION)) + .withResolvedInference(new ResolvedInference(SPARSE_EMBEDDING_INFERENCE_ID, TaskType.SPARSE_EMBEDDING)) + .withError(ERROR_INFERENCE_ID, "error with inference resolution") .build(); } + public static String randomInferenceId() { + return 
ESTestCase.randomFrom(VALID_INFERENCE_IDS); + } + + public static String randomInferenceId(String... excludes) { + return ESTestCase.randomValueOtherThanMany(Arrays.asList(excludes)::contains, AnalyzerTestUtils::randomInferenceId); + } + public static void loadEnrichPolicyResolution( EnrichResolution enrich, String policyType, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 439c1bc189e3b..8d898411b0ab1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -55,6 +55,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; @@ -121,6 +122,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.TEXT_EMBEDDING_INFERENCE_ID; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzer; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzerDefaultMapping; @@ -128,6 +130,7 @@ import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultInferenceResolution; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.indexWithDateDateNanosUnionType; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.randomInferenceId; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.tsdbIndexResolution; import static org.elasticsearch.xpack.esql.core.plugin.EsqlCorePlugin.DENSE_VECTOR_FEATURE_FLAG; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; @@ -3629,6 +3632,115 @@ private void assertEmptyEsRelation(LogicalPlan plan) { assertThat(esRelation.output(), equalTo(NO_FIELDS)); } + public void testTextEmbeddingResolveInferenceId() { + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + LogicalPlan plan = analyze( + """ + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted( + TEXT_EMBEDDING_INFERENCE_ID + ), + "mapping-books.json" + ); + + Eval eval = as(as(plan, Limit.class).child(), Eval.class); + assertThat(eval.fields(), hasSize(1)); + Alias alias = as(eval.fields().get(0), Alias.class); + assertThat(alias.name(), equalTo("embedding")); + TextEmbedding function = as(alias.child(), TextEmbedding.class); + + assertThat(function.inputText(), equalTo(string("italian food recipe"))); + assertThat(function.inferenceId(), equalTo(string(TEXT_EMBEDDING_INFERENCE_ID))); + } + + public void 
testTextEmbeddingFunctionResolveType() { + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + LogicalPlan plan = analyze( + """ + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted( + TEXT_EMBEDDING_INFERENCE_ID + ), + "mapping-books.json" + ); + + Eval eval = as(as(plan, Limit.class).child(), Eval.class); + assertThat(eval.fields(), hasSize(1)); + Alias alias = as(eval.fields().get(0), Alias.class); + assertThat(alias.name(), equalTo("embedding")); + + TextEmbedding function = as(alias.child(), TextEmbedding.class); + + assertThat(function.foldable(), equalTo(true)); + assertThat(function.dataType(), equalTo(DENSE_VECTOR)); + } + + public void testTextEmbeddingFunctionMissingInferenceIdError() { + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + VerificationException ve = expectThrows( + VerificationException.class, + () -> analyze( + """ + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted( + "unknow-inference-id" + ), + "mapping-books.json" + ) + ); + + assertThat(ve.getMessage(), containsString("unresolved inference [unknow-inference-id]")); + } + + public void testTextEmbeddingFunctionInvalidInferenceIdError() { + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + String inferenceId = randomInferenceId(TEXT_EMBEDDING_INFERENCE_ID); + VerificationException ve = expectThrows( + VerificationException.class, + () -> analyze( + """ + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted(inferenceId), + "mapping-books.json" + ) + ); + + assertThat(ve.getMessage(), containsString("cannot use inference endpoint [%s] with task type".formatted(inferenceId))); + } + + public void testTextEmbeddingFunctionWithoutModel() { + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + ParsingException ve = expectThrows(ParsingException.class, () -> analyze(""" + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe")""", "mapping-books.json")); + + assertThat( + ve.getMessage(), + containsString(" error building [text_embedding]: function [text_embedding] expects exactly two arguments") + ); + } + + public void testKnnFunctionWithTextEmbedding() { + assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V5.isEnabled()); + assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + + String fieldName = randomFrom("float_vector", "byte_vector"); + + LogicalPlan plan = analyze(""" + from test | where KNN(%s, TEXT_EMBEDDING("italian food recipe", "%s")) + """.formatted(fieldName, TEXT_EMBEDDING_INFERENCE_ID), "mapping-dense_vector.json"); + + Limit limit = as(plan, Limit.class); + Filter filter = as(limit.child(), Filter.class); + Knn knn = as(filter.condition(), Knn.class); + assertThat(knn.field(), instanceOf(FieldAttribute.class)); + assertThat(((FieldAttribute) knn.field()).name(), equalTo(fieldName)); + + TextEmbedding textEmbedding = as(knn.query(), TextEmbedding.class); + assertThat(textEmbedding.inputText(), equalTo(string("italian food recipe"))); + assertThat(textEmbedding.inferenceId(), equalTo(string(TEXT_EMBEDDING_INFERENCE_ID))); + } + public void testResolveRerankInferenceId() { { LogicalPlan plan = analyze(""" 
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index d5d3507928e84..609da9db9bc49 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -41,6 +41,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsConstant; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.TEXT_EMBEDDING_INFERENCE_ID; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; @@ -2434,6 +2435,53 @@ public void testInvalidTBucketCalls() { } } + public void testTextEmbeddingFunctionInvalidQuery() { + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(null, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID), + equalTo("1:30: first argument of [TEXT_EMBEDDING(null, ?)] cannot be null, received [null]") + ); + + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(42, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID), + equalTo("1:30: first argument of [TEXT_EMBEDDING(42, ?)] must be [string], found value [42] type [integer]") + ); + + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(last_name, ?)", defaultAnalyzer, TEXT_EMBEDDING_INFERENCE_ID), + equalTo("1:30: first argument of [TEXT_EMBEDDING(last_name, ?)] must be a constant, received [last_name]") + ); + } + + public void testTextEmbeddingFunctionInvalidInferenceId() { + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(?, null)", defaultAnalyzer, "query text"), + equalTo("1:30: second argument of [TEXT_EMBEDDING(?, null)] cannot be null, received [null]") + ); + + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(?, 42)", defaultAnalyzer, "query text"), + equalTo("1:30: second argument of [TEXT_EMBEDDING(?, 42)] must be [string], found value [42] type [integer]") + ); + + assertThat( + error("from test | EVAL embedding = TEXT_EMBEDDING(?, last_name)", defaultAnalyzer, "query text"), + equalTo("1:30: second argument of [TEXT_EMBEDDING(?, last_name)] must be a constant, received [last_name]") + ); + } + + // public void testTextEmbeddingFunctionInvalidInferenceId() { + // assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); + // + // ParsingException ve = expectThrows(ParsingException.class, () -> analyze(""" + // FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", CONCAT("machin", title))""", + // "mapping-books.json")); + // + // assertThat( + // ve.getMessage(), + // containsString(" error building [text_embedding]: function [text_embedding] expects exactly two arguments") + // ); + // } + private void checkVectorFunctionsNullArgs(String functionInvocation) throws Exception { query("from test | eval similarity = " + functionInvocation, fullTextAnalyzer); } From 8056fde5376e67ca9de8d1bd0082855e61563d0b Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 11 Sep 2025 20:00:55 +0200 Subject: [PATCH 05/26] Fix test can fail if byte or bit vectors are not supported --- 
.../xpack/esql/analysis/AnalyzerTests.java | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 154f3639f4dcb..0362b34d05219 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -3807,20 +3807,18 @@ public void testTextEmbeddingFunctionWithoutModel() { } public void testKnnFunctionWithTextEmbedding() { - assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V5.isEnabled()); + assumeTrue("KNN function capability required", EsqlCapabilities.Cap.KNN_FUNCTION_V5.isEnabled()); assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); - String fieldName = randomFrom("float_vector", "byte_vector"); - LogicalPlan plan = analyze(""" - from test | where KNN(%s, TEXT_EMBEDDING("italian food recipe", "%s")) - """.formatted(fieldName, TEXT_EMBEDDING_INFERENCE_ID), "mapping-dense_vector.json"); + from test | where KNN(float_vector, TEXT_EMBEDDING("italian food recipe", "%s")) + """.formatted(TEXT_EMBEDDING_INFERENCE_ID), "mapping-dense_vector.json"); Limit limit = as(plan, Limit.class); Filter filter = as(limit.child(), Filter.class); Knn knn = as(filter.condition(), Knn.class); assertThat(knn.field(), instanceOf(FieldAttribute.class)); - assertThat(((FieldAttribute) knn.field()).name(), equalTo(fieldName)); + assertThat(((FieldAttribute) knn.field()).name(), equalTo("float_vector")); TextEmbedding textEmbedding = as(knn.query(), TextEmbedding.class); assertThat(textEmbedding.inputText(), equalTo(string("italian food recipe"))); From c4a08715edb7d248d3a9bad2864c9932d160cdeb Mon Sep 17 00:00:00 2001 From: afoucret Date: Sat, 13 Sep 2025 06:46:16 +0200 Subject: [PATCH 06/26] Add the text_embedding function to xpack usage tests --- .../resources/rest-api-spec/test/esql/60_usage.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index bd61947411df2..d71805e6f9d82 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -130,7 +130,7 @@ setup: - match: {esql.functions.coalesce: $functions_coalesce} - gt: {esql.functions.categorize: $functions_categorize} # Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation. 
- - length: {esql.functions: 176} # check the "sister" test below for a likely update to the same esql.functions length check + - length: {esql.functions: 177} # check the "sister" test below for a likely update to the same esql.functions length check --- "Basic ESQL usage output (telemetry) non-snapshot version": - requires: @@ -184,6 +184,7 @@ setup: - set: {esql.functions.to_long: functions_to_long} - set: {esql.functions.coalesce: functions_coalesce} - set: {esql.functions.categorize: functions_categorize} + - set: {esql.functions.text_embedding: functions_text_embedding} - do: esql.query: @@ -226,6 +227,7 @@ setup: - match: {esql.functions.coalesce: $functions_coalesce} - gt: {esql.functions.categorize: $functions_categorize} - length: {esql.functions: 150} # check the "sister" test above for a likely update to the same esql.functions length check + - match: {esql.functions.text_embedding: functions_text_embedding} --- took: From cee29352b4a951e6876640d1be840153837f3e53 Mon Sep 17 00:00:00 2001 From: afoucret Date: Sat, 13 Sep 2025 06:56:04 +0200 Subject: [PATCH 07/26] Fix error in xpack usage --- .../yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index a75e02f3f7ac3..b8d2cbca65799 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -131,7 +131,7 @@ setup: - gt: {esql.functions.to_long: $functions_to_long} - match: {esql.functions.coalesce: $functions_coalesce} - gt: {esql.functions.categorize: $functions_categorize} - - match: {esql.functions.text_embedding: functions_text_embedding} + - match: {esql.functions.text_embedding: $functions_text_embedding} # Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation. 
- length: {esql.functions: 180} # check the "sister" test below for a likely update to the same esql.functions length check --- From 9d7a23fcfd3d8612bfaed9753305d71eac0cf323 Mon Sep 17 00:00:00 2001 From: afoucret Date: Sat, 13 Sep 2025 07:58:01 +0200 Subject: [PATCH 08/26] Create the text embedding request iterator --- .../function/EsqlFunctionRegistry.java | 4 +- .../xpack/esql/inference/InputTextReader.java | 76 ++++++ .../CompletionOperatorRequestIterator.java | 62 +---- .../TextEmbeddingOperatorRequestIterator.java | 81 ++++++ .../esql/inference/InputTextReaderTests.java | 231 ++++++++++++++++++ ...EmbeddingOperatorRequestIteratorTests.java | 167 +++++++++++++ 6 files changed, 563 insertions(+), 58 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InputTextReader.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIterator.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InputTextReaderTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIteratorTests.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 408d2cd4554a1..fd11bf3e4c15e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -272,7 +272,7 @@ public class EsqlFunctionRegistry { } // Translation table for error messaging in the following function - private static final String[] NUM_NAMES = {"zero", "one", "two", "three", "four", "five", "six"}; + private static final String[] NUM_NAMES = { "zero", "one", "two", "three", "four", "five", "six" }; // list of functions grouped by type of functions (aggregate, statistics, math etc) and ordered alphabetically inside each group // a single function will have one entry for itself with its name associated to its instance and, also, one entry for each alias @@ -353,7 +353,7 @@ private static FunctionDefinition[][] functions() { def(Values.class, uni(Values::new), "values"), def(WeightedAvg.class, bi(WeightedAvg::new), "weighted_avg"), def(Present.class, uni(Present::new), "present"), - def(Absent.class, uni(Absent::new), "absent")}, + def(Absent.class, uni(Absent::new), "absent") }, // math new FunctionDefinition[] { def(Abs.class, Abs::new, "abs"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InputTextReader.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InputTextReader.java new file mode 100644 index 0000000000000..d7bf43cdfbdd5 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InputTextReader.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +/** + * Helper class that reads text strings from a {@link BytesRefBlock}. + * This class is used by inference operators to extract text content from block data. + */ +public class InputTextReader implements Releasable { + private final BytesRefBlock textBlock; + private final StringBuilder strBuilder = new StringBuilder(); + private BytesRef readBuffer = new BytesRef(); + + public InputTextReader(BytesRefBlock textBlock) { + this.textBlock = textBlock; + } + + /** + * Reads the text string at the given position. + * Multiple values at the position are concatenated with newlines. + * + * @param pos the position index in the block + * @return the text string at the position, or null if the position contains a null value + */ + public String readText(int pos) { + return readText(pos, Integer.MAX_VALUE); + } + + /** + * Reads the text string at the given position. + * + * @param pos the position index in the block + * @param limit the maximum number of value to read from the position + * @return the text string at the position, or null if the position contains a null value + */ + public String readText(int pos, int limit) { + if (textBlock.isNull(pos)) { + return null; + } + + strBuilder.setLength(0); + int maxPos = Math.min(limit, textBlock.getValueCount(pos)); + for (int valueIndex = 0; valueIndex < maxPos; valueIndex++) { + readBuffer = textBlock.getBytesRef(textBlock.getFirstValueIndex(pos) + valueIndex, readBuffer); + strBuilder.append(readBuffer.utf8ToString()); + if (valueIndex != maxPos - 1) { + strBuilder.append("\n"); + } + } + + return strBuilder.toString(); + } + + /** + * Returns the total number of positions (text entries) in the block. 
+ */ + public int estimatedSize() { + return textBlock.getPositionCount(); + } + + @Override + public void close() { + textBlock.allowPassingToDifferentDriver(); + Releasables.close(textBlock); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java index f526cd9edb077..509110e0ffe2b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java @@ -7,12 +7,11 @@ package org.elasticsearch.xpack.esql.inference.completion; -import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.InputTextReader; import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; import java.util.List; @@ -24,7 +23,7 @@ */ public class CompletionOperatorRequestIterator implements BulkInferenceRequestIterator { - private final PromptReader promptReader; + private final InputTextReader textReader; private final String inferenceId; private final int size; private int currentPos = 0; @@ -36,7 +35,7 @@ public class CompletionOperatorRequestIterator implements BulkInferenceRequestIt * @param inferenceId The ID of the inference model to invoke. */ public CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) { - this.promptReader = new PromptReader(promptBlock); + this.textReader = new InputTextReader(promptBlock); this.size = promptBlock.getPositionCount(); this.inferenceId = inferenceId; } @@ -52,7 +51,7 @@ public InferenceAction.Request next() { throw new NoSuchElementException(); } - return inferenceRequest(promptReader.readPrompt(currentPos++)); + return inferenceRequest(textReader.readText(currentPos++)); } /** @@ -68,60 +67,11 @@ private InferenceAction.Request inferenceRequest(String prompt) { @Override public int estimatedSize() { - return promptReader.estimatedSize(); + return textReader.estimatedSize(); } @Override public void close() { - Releasables.close(promptReader); - } - - /** - * Helper class that reads prompts from a {@link BytesRefBlock}. - */ - private static class PromptReader implements Releasable { - private final BytesRefBlock promptBlock; - private final StringBuilder strBuilder = new StringBuilder(); - private BytesRef readBuffer = new BytesRef(); - - private PromptReader(BytesRefBlock promptBlock) { - this.promptBlock = promptBlock; - } - - /** - * Reads the prompt string at the given position.. 
- * - * @param pos the position index in the block - */ - public String readPrompt(int pos) { - if (promptBlock.isNull(pos)) { - return null; - } - - strBuilder.setLength(0); - - for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) { - readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer); - strBuilder.append(readBuffer.utf8ToString()); - if (valueIndex != promptBlock.getValueCount(pos) - 1) { - strBuilder.append("\n"); - } - } - - return strBuilder.toString(); - } - - /** - * Returns the total number of positions (prompts) in the block. - */ - public int estimatedSize() { - return promptBlock.getPositionCount(); - } - - @Override - public void close() { - promptBlock.allowPassingToDifferentDriver(); - Releasables.close(promptBlock); - } + Releasables.close(textReader); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIterator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIterator.java new file mode 100644 index 0000000000000..4860f7fdcb6ec --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIterator.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.esql.inference.InputTextReader; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; + +import java.util.List; +import java.util.NoSuchElementException; + +/** + * This iterator reads text inputs from a {@link BytesRefBlock} and converts them into individual {@link InferenceAction.Request} instances + * of type {@link TaskType#TEXT_EMBEDDING}. + */ +public class TextEmbeddingOperatorRequestIterator implements BulkInferenceRequestIterator { + + private final InputTextReader textReader; + private final String inferenceId; + private final int size; + private int currentPos = 0; + + /** + * Constructs a new iterator from the given block of text inputs. + * + * @param textBlock The input block containing text to embed. + * @param inferenceId The ID of the inference model to invoke. + */ + public TextEmbeddingOperatorRequestIterator(BytesRefBlock textBlock, String inferenceId) { + this.textReader = new InputTextReader(textBlock); + this.size = textBlock.getPositionCount(); + this.inferenceId = inferenceId; + } + + @Override + public boolean hasNext() { + return currentPos < size; + } + + @Override + public InferenceAction.Request next() { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + + /* + * Keep only the first value in case of multi-valued fields. + * TODO: check if it is consistent with how the query vector builder is working. + */ + return inferenceRequest(textReader.readText(currentPos++, 1)); + } + + /** + * Wraps a single text string into an {@link InferenceAction.Request} for text embedding. 
+ */ + private InferenceAction.Request inferenceRequest(String text) { + if (text == null) { + return null; + } + + return InferenceAction.Request.builder(inferenceId, TaskType.TEXT_EMBEDDING).setInput(List.of(text)).build(); + } + + @Override + public int estimatedSize() { + return textReader.estimatedSize(); + } + + @Override + public void close() { + Releasables.close(textReader); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InputTextReaderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InputTextReaderTests.java new file mode 100644 index 0000000000000..64dd6a02928e8 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InputTextReaderTests.java @@ -0,0 +1,231 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.test.ComputeTestCase; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; + +public class InputTextReaderTests extends ComputeTestCase { + + public void testReadSingleValuePositions() throws Exception { + String[] texts = { "hello", "world", "test" }; + BytesRefBlock block = createSingleValueBlock(texts); + + try (InputTextReader reader = new InputTextReader(block)) { + assertThat(reader.estimatedSize(), equalTo(texts.length)); + + for (int i = 0; i < texts.length; i++) { + assertThat(reader.readText(i), equalTo(texts[i])); + } + } + + allBreakersEmpty(); + } + + public void testReadMultiValuePositions() throws Exception { + BytesRefBlock block = createMultiValueBlock(); + + try (InputTextReader reader = new InputTextReader(block)) { + assertThat(reader.estimatedSize(), equalTo(2)); + + // First position has multiple values that should be concatenated with newlines + assertThat(reader.readText(0), equalTo("first\nsecond\nthird")); + + // Second position has a single value + assertThat(reader.readText(1), equalTo("single")); + } + + allBreakersEmpty(); + } + + public void testReadMultiValuePositionsWithLimit() throws Exception { + BytesRefBlock block = createMultiValueBlock(); + + try (InputTextReader reader = new InputTextReader(block)) { + // Test limiting to first 2 values out of 3 + assertThat(reader.readText(0, 2), equalTo("first\nsecond")); + + // Test limiting to first 1 value out of 3 + assertThat(reader.readText(0, 1), equalTo("first")); + + // Test limit larger than available values + assertThat(reader.readText(0, 10), equalTo("first\nsecond\nthird")); + + // Test limit of 0 + assertThat(reader.readText(0, 0), equalTo("")); + + // Test single value position with limit + assertThat(reader.readText(1, 1), equalTo("single")); + } + + allBreakersEmpty(); + } + + public void testReadNullValues() throws Exception { + BytesRefBlock block = createBlockWithNulls(); + + try (InputTextReader reader = new InputTextReader(block)) { + assertThat(reader.estimatedSize(), equalTo(3)); + + assertThat(reader.readText(0), equalTo("before")); + assertThat(reader.readText(1), nullValue()); + assertThat(reader.readText(2), equalTo("after")); + } + + allBreakersEmpty(); + } + + public void testReadNullValuesWithLimit() throws Exception { + 
BytesRefBlock block = createBlockWithNulls(); + + try (InputTextReader reader = new InputTextReader(block)) { + // Null values should return null regardless of limit + assertThat(reader.readText(0, 1), equalTo("before")); + assertThat(reader.readText(1, 1), nullValue()); + assertThat(reader.readText(1, 10), nullValue()); + assertThat(reader.readText(2, 1), equalTo("after")); + } + + allBreakersEmpty(); + } + + public void testReadEmptyStrings() throws Exception { + String[] texts = { "", "non-empty", "" }; + BytesRefBlock block = createSingleValueBlock(texts); + + try (InputTextReader reader = new InputTextReader(block)) { + for (int i = 0; i < texts.length; i++) { + assertThat(reader.readText(i), equalTo(texts[i])); + assertThat(reader.readText(i, 1), equalTo(texts[i])); + } + } + + allBreakersEmpty(); + } + + public void testReadLargeInput() throws Exception { + int size = between(1000, 5000); + String[] texts = new String[size]; + for (int i = 0; i < size; i++) { + texts[i] = "text_" + i + "_" + randomAlphaOfLength(10); + } + + BytesRefBlock block = createSingleValueBlock(texts); + + try (InputTextReader reader = new InputTextReader(block)) { + assertThat(reader.estimatedSize(), equalTo(size)); + + for (int i = 0; i < size; i++) { + assertThat(reader.readText(i), equalTo(texts[i])); + assertThat(reader.readText(i, 1), equalTo(texts[i])); + } + } + + allBreakersEmpty(); + } + + public void testReadUnicodeText() throws Exception { + String[] texts = { "café", "naïve", "résumé", "🚀 rocket", "多语言支持" }; + BytesRefBlock block = createSingleValueBlock(texts); + + try (InputTextReader reader = new InputTextReader(block)) { + for (int i = 0; i < texts.length; i++) { + assertThat(reader.readText(i), equalTo(texts[i])); + assertThat(reader.readText(i, 1), equalTo(texts[i])); + } + } + + allBreakersEmpty(); + } + + public void testReadMultipleTimesFromSamePosition() throws Exception { + String[] texts = { "consistent" }; + BytesRefBlock block = createSingleValueBlock(texts); + + try (InputTextReader reader = new InputTextReader(block)) { + // Reading the same position multiple times should return the same result + assertThat(reader.readText(0), equalTo("consistent")); + assertThat(reader.readText(0), equalTo("consistent")); + assertThat(reader.readText(0, 1), equalTo("consistent")); + assertThat(reader.readText(0, 10), equalTo("consistent")); + } + + allBreakersEmpty(); + } + + public void testLimitBoundaryConditions() throws Exception { + BytesRefBlock block = createLargeMultiValueBlock(); + + try (InputTextReader reader = new InputTextReader(block)) { + // Test various limit values on a position with 5 values + assertThat(reader.readText(0, 0), equalTo("")); + assertThat(reader.readText(0, 1), equalTo("value0")); + assertThat(reader.readText(0, 2), equalTo("value0\nvalue1")); + assertThat(reader.readText(0, 3), equalTo("value0\nvalue1\nvalue2")); + assertThat(reader.readText(0, 4), equalTo("value0\nvalue1\nvalue2\nvalue3")); + assertThat(reader.readText(0, 5), equalTo("value0\nvalue1\nvalue2\nvalue3\nvalue4")); + + // Test limit beyond available values + assertThat(reader.readText(0, 10), equalTo("value0\nvalue1\nvalue2\nvalue3\nvalue4")); + assertThat(reader.readText(0, Integer.MAX_VALUE), equalTo("value0\nvalue1\nvalue2\nvalue3\nvalue4")); + } + + allBreakersEmpty(); + } + + private BytesRefBlock createSingleValueBlock(String[] texts) { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(texts.length)) { + for (String text : texts) { + builder.appendBytesRef(new 
BytesRef(text)); + } + return builder.build(); + } + } + + private BytesRefBlock createMultiValueBlock() { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(2)) { + // First position: multiple values + builder.beginPositionEntry(); + builder.appendBytesRef(new BytesRef("first")); + builder.appendBytesRef(new BytesRef("second")); + builder.appendBytesRef(new BytesRef("third")); + builder.endPositionEntry(); + + // Second position: single value + builder.appendBytesRef(new BytesRef("single")); + + return builder.build(); + } + } + + private BytesRefBlock createBlockWithNulls() { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(3)) { + builder.appendBytesRef(new BytesRef("before")); + builder.appendNull(); + builder.appendBytesRef(new BytesRef("after")); + return builder.build(); + } + } + + private BytesRefBlock createLargeMultiValueBlock() { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(1)) { + // Single position with 5 values for testing limits + builder.beginPositionEntry(); + for (int i = 0; i < 5; i++) { + builder.appendBytesRef(new BytesRef("value" + i)); + } + builder.endPositionEntry(); + + return builder.build(); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIteratorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIteratorTests.java new file mode 100644 index 0000000000000..a27ee95b941d7 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorRequestIteratorTests.java @@ -0,0 +1,167 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; + +public class TextEmbeddingOperatorRequestIteratorTests extends ComputeTestCase { + + public void testIterateSmallInput() throws Exception { + assertIterate(between(1, 100)); + } + + public void testIterateLargeInput() throws Exception { + assertIterate(between(10_000, 100_000)); + } + + public void testIterateWithNullValues() throws Exception { + final String inferenceId = randomIdentifier(); + final BytesRefBlock inputBlock = createBlockWithNulls(); + + try (TextEmbeddingOperatorRequestIterator requestIterator = new TextEmbeddingOperatorRequestIterator(inputBlock, inferenceId)) { + BytesRef scratch = new BytesRef(); + + // First position: "before" + InferenceAction.Request request1 = requestIterator.next(); + assertThat(request1.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request1.getTaskType(), equalTo(TaskType.TEXT_EMBEDDING)); + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(0), scratch); + assertThat(request1.getInput().get(0), equalTo(scratch.utf8ToString())); + + // Second position: null + InferenceAction.Request request2 = requestIterator.next(); + assertThat(request2, nullValue()); + + // Third position: "after" + InferenceAction.Request request3 = requestIterator.next(); + assertThat(request3.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request3.getTaskType(), equalTo(TaskType.TEXT_EMBEDDING)); + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(2), scratch); + assertThat(request3.getInput().get(0), equalTo(scratch.utf8ToString())); + } + + allBreakersEmpty(); + } + + public void testIterateWithMultiValuePositions() throws Exception { + final String inferenceId = randomIdentifier(); + final BytesRefBlock inputBlock = createMultiValueBlock(); + + try (TextEmbeddingOperatorRequestIterator requestIterator = new TextEmbeddingOperatorRequestIterator(inputBlock, inferenceId)) { + // First position: multi-value concatenated with newlines + InferenceAction.Request request1 = requestIterator.next(); + assertThat(request1.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request1.getTaskType(), equalTo(TaskType.TEXT_EMBEDDING)); + assertThat(request1.getInput().get(0), equalTo("first")); + + // Second position: single value + InferenceAction.Request request2 = requestIterator.next(); + assertThat(request2.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request2.getTaskType(), equalTo(TaskType.TEXT_EMBEDDING)); + assertThat(request2.getInput().get(0), equalTo("single")); + } + + allBreakersEmpty(); + } + + public void testEstimatedSize() throws Exception { + final String inferenceId = randomIdentifier(); + final int size = randomIntBetween(10, 1000); + final BytesRefBlock inputBlock = randomInputBlock(size); + + try (TextEmbeddingOperatorRequestIterator requestIterator = new TextEmbeddingOperatorRequestIterator(inputBlock, inferenceId)) { + assertThat(requestIterator.estimatedSize(), equalTo(size)); + } + + allBreakersEmpty(); + } + + public void testHasNextAndIteration() throws Exception { + final String inferenceId = randomIdentifier(); + final int size = randomIntBetween(5, 
50); + final BytesRefBlock inputBlock = randomInputBlock(size); + + try (TextEmbeddingOperatorRequestIterator requestIterator = new TextEmbeddingOperatorRequestIterator(inputBlock, inferenceId)) { + int count = 0; + while (requestIterator.hasNext()) { + requestIterator.next(); + count++; + } + assertThat(count, equalTo(size)); + + // Verify hasNext returns false after iteration is complete + assertThat(requestIterator.hasNext(), equalTo(false)); + } + + allBreakersEmpty(); + } + + private void assertIterate(int size) throws Exception { + final String inferenceId = randomIdentifier(); + final BytesRefBlock inputBlock = randomInputBlock(size); + + try (TextEmbeddingOperatorRequestIterator requestIterator = new TextEmbeddingOperatorRequestIterator(inputBlock, inferenceId)) { + BytesRef scratch = new BytesRef(); + + for (int currentPos = 0; requestIterator.hasNext(); currentPos++) { + InferenceAction.Request request = requestIterator.next(); + + assertThat(request.getInferenceEntityId(), equalTo(inferenceId)); + assertThat(request.getTaskType(), equalTo(TaskType.TEXT_EMBEDDING)); + + // Verify the input text matches what's in the block + scratch = inputBlock.getBytesRef(inputBlock.getFirstValueIndex(currentPos), scratch); + assertThat(request.getInput().get(0), equalTo(scratch.utf8ToString())); + } + } + + allBreakersEmpty(); + } + + private BytesRefBlock randomInputBlock(int size) { + try (BytesRefBlock.Builder blockBuilder = blockFactory().newBytesRefBlockBuilder(size)) { + for (int i = 0; i < size; i++) { + blockBuilder.appendBytesRef(new BytesRef(randomAlphaOfLength(10))); + } + + return blockBuilder.build(); + } + } + + private BytesRefBlock createBlockWithNulls() { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(3)) { + builder.appendBytesRef(new BytesRef("before")); + builder.appendNull(); + builder.appendBytesRef(new BytesRef("after")); + return builder.build(); + } + } + + private BytesRefBlock createMultiValueBlock() { + try (BytesRefBlock.Builder builder = blockFactory().newBytesRefBlockBuilder(2)) { + // First position: multiple values + builder.beginPositionEntry(); + builder.appendBytesRef(new BytesRef("first")); + builder.appendBytesRef(new BytesRef("second")); + builder.appendBytesRef(new BytesRef("third")); + builder.endPositionEntry(); + + // Second position: single value + builder.appendBytesRef(new BytesRef("single")); + + return builder.build(); + } + } +} From b8c5f111663080a2a892e7c7b8b07a1c9aaa4dba Mon Sep 17 00:00:00 2001 From: afoucret Date: Sat, 13 Sep 2025 10:17:05 +0200 Subject: [PATCH 09/26] Clean analyzer tests to avoid forbidden api usage. 
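Context for the cleanup below: calling formatted(...) on a string uses the JVM default locale, which the build's forbidden-APIs check appears to reject (hence the commit subject), so the tests switch to the locale-explicit form. A minimal illustration of the pattern, with a placeholder query and endpoint id:

    // Locale-explicit formatting (java.util.Locale), the form the cleanup below adopts;
    // the query fragment and endpoint id here are placeholders.
    String query = String.format(Locale.ROOT, "TEXT_EMBEDDING(\"italian food recipe\", \"%s\")", "test-embedding-endpoint");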
--- .../xpack/esql/analysis/AnalyzerTests.java | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index df3ced85efdf3..fb0f12267b2eb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -3767,10 +3767,8 @@ public void testTextEmbeddingResolveInferenceId() { assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); LogicalPlan plan = analyze( - """ - FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted( - TEXT_EMBEDDING_INFERENCE_ID - ), + String.format(Locale.ROOT, """ + FROM books METADATA _score | EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", TEXT_EMBEDDING_INFERENCE_ID), "mapping-books.json" ); @@ -3788,10 +3786,8 @@ public void testTextEmbeddingFunctionResolveType() { assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); LogicalPlan plan = analyze( - """ - FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted( - TEXT_EMBEDDING_INFERENCE_ID - ), + String.format(Locale.ROOT, """ + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", TEXT_EMBEDDING_INFERENCE_ID), "mapping-books.json" ); @@ -3812,10 +3808,8 @@ public void testTextEmbeddingFunctionMissingInferenceIdError() { VerificationException ve = expectThrows( VerificationException.class, () -> analyze( - """ - FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted( - "unknow-inference-id" - ), + String.format(Locale.ROOT, """ + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", "unknow-inference-id"), "mapping-books.json" ) ); @@ -3830,13 +3824,16 @@ public void testTextEmbeddingFunctionInvalidInferenceIdError() { VerificationException ve = expectThrows( VerificationException.class, () -> analyze( - """ - FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""".formatted(inferenceId), + String.format(Locale.ROOT, """ + FROM books METADATA _score| EVAL embedding = TEXT_EMBEDDING("italian food recipe", "%s")""", inferenceId), "mapping-books.json" ) ); - assertThat(ve.getMessage(), containsString("cannot use inference endpoint [%s] with task type".formatted(inferenceId))); + assertThat( + ve.getMessage(), + containsString(String.format(Locale.ROOT, "cannot use inference endpoint [%s] with task type", inferenceId)) + ); } public void testTextEmbeddingFunctionWithoutModel() { @@ -3855,9 +3852,11 @@ public void testKnnFunctionWithTextEmbedding() { assumeTrue("KNN function capability required", EsqlCapabilities.Cap.KNN_FUNCTION_V5.isEnabled()); assumeTrue("TEXT_EMBEDDING function required", EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION.isEnabled()); - LogicalPlan plan = analyze(""" - from test | where KNN(float_vector, TEXT_EMBEDDING("italian food recipe", "%s")) - """.formatted(TEXT_EMBEDDING_INFERENCE_ID), "mapping-dense_vector.json"); + LogicalPlan plan = analyze( + String.format(Locale.ROOT, """ + from test | where KNN(float_vector, TEXT_EMBEDDING("italian food recipe", "%s"))""", 
TEXT_EMBEDDING_INFERENCE_ID), + "mapping-dense_vector.json" + ); Limit limit = as(plan, Limit.class); Filter filter = as(limit.child(), Filter.class); From 678438bb90d3624004b999001ff256b437a12227 Mon Sep 17 00:00:00 2001 From: afoucret Date: Sat, 13 Sep 2025 21:59:07 +0200 Subject: [PATCH 10/26] Add text embedding output builder. --- .../results/TextEmbeddingByteResults.java | 2 +- .../TextEmbeddingOperatorOutputBuilder.java | 103 ++++++++ ...xtEmbeddingOperatorOutputBuilderTests.java | 247 ++++++++++++++++++ 3 files changed, 351 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java index 54f858cb20ae0..d8b3fb2b9c544 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java @@ -155,7 +155,7 @@ public String toString() { return Strings.toString(this); } - float[] toFloatArray() { + public float[] toFloatArray() { float[] floatArray = new float[values.length]; for (int i = 0; i < values.length; i++) { floatArray[i] = ((Byte) values[i]).floatValue(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java new file mode 100644 index 0000000000000..521bb508c30af --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; + +/** + * {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting + * {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings. 
+ */ +public class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder { + private final Page inputPage; + private final FloatBlock.Builder outputBlockBuilder; + + public TextEmbeddingOperatorOutputBuilder(FloatBlock.Builder outputBlockBuilder, Page inputPage) { + this.inputPage = inputPage; + this.outputBlockBuilder = outputBlockBuilder; + } + + @Override + public void close() { + Releasables.close(outputBlockBuilder); + } + + /** + * Adds an inference response to the output builder. + * + *
<p>
+ * If the response is null or the embedding list is empty, a null value is appended to the output block. + * If the results are not of type {@link TextEmbeddingResults}, an {@link IllegalStateException} is thrown. + * Otherwise, the embedding vector is added to the output block as a multi-value position. + *
</p>
+ * + *
<p>
+ * The responses must be added in the same order as the corresponding inference requests were generated. + * Failing to preserve order may lead to incorrect or misaligned output rows. + *
</p>
+ */ + @Override + public void addInferenceResponse(InferenceAction.Response inferenceResponse) { + if (inferenceResponse == null) { + outputBlockBuilder.appendNull(); + return; + } + + TextEmbeddingResults embeddingResults = inferenceResults(inferenceResponse); + + var embeddings = embeddingResults.embeddings(); + if (embeddings.isEmpty()) { + outputBlockBuilder.appendNull(); + return; + } + + float[] embeddingArray = getEmbeddingAsFloatArray(embeddingResults); + + outputBlockBuilder.beginPositionEntry(); + for (float component : embeddingArray) { + outputBlockBuilder.appendFloat(component); + } + outputBlockBuilder.endPositionEntry(); + } + + /** + * Builds the final output page by appending the embedding output block to the input page. + */ + @Override + public Page buildOutput() { + Block outputBlock = outputBlockBuilder.build(); + assert outputBlock.getPositionCount() == inputPage.getPositionCount(); + return inputPage.appendBlock(outputBlock); + } + + private TextEmbeddingResults inferenceResults(InferenceAction.Response inferenceResponse) { + return InferenceOperator.OutputBuilder.inferenceResults(inferenceResponse, TextEmbeddingResults.class); + } + + /** + * Extracts the embedding as a float array from the embedding result. + */ + private float[] getEmbeddingAsFloatArray(TextEmbeddingResults embedding) { + return switch (embedding.embeddings().get(0)) { + case TextEmbeddingFloatResults.Embedding floatEmbedding -> floatEmbedding.values(); + case TextEmbeddingByteResults.Embedding byteEmbedding -> byteEmbedding.toFloatArray(); + default -> throw new IllegalArgumentException( + "Unsupported embedding type: " + + embedding.embeddings().get(0).getClass().getName() + + ". Expected TextEmbeddingFloatResults.Embedding or TextEmbeddingByteResults.Embedding." + ); + }; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java new file mode 100644 index 0000000000000..ea77c6bed3c38 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilderTests.java @@ -0,0 +1,247 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.compute.test.RandomBlock; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; + +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class TextEmbeddingOperatorOutputBuilderTests extends ComputeTestCase { + + public void testBuildSmallOutputWithFloatEmbeddings() throws Exception { + assertBuildOutputWithFloatEmbeddings(between(1, 100)); + } + + public void testBuildLargeOutputWithFloatEmbeddings() throws Exception { + assertBuildOutputWithFloatEmbeddings(between(1_000, 10_000)); + } + + public void testBuildSmallOutputWithByteEmbeddings() throws Exception { + assertBuildOutputWithByteEmbeddings(between(1, 100)); + } + + public void testBuildLargeOutputWithByteEmbeddings() throws Exception { + assertBuildOutputWithByteEmbeddings(between(1_000, 10_000)); + } + + public void testHandleNullResponses() throws Exception { + final int size = between(10, 100); + final Page inputPage = randomInputPage(size, between(1, 5)); + + try ( + TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder( + blockFactory().newFloatBlockBuilder(size), + inputPage + ) + ) { + // Add some null responses + for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) { + if (randomBoolean()) { + outputBuilder.addInferenceResponse(null); + } else { + float[] embedding = randomFloatEmbedding(randomIntBetween(50, 200)); + outputBuilder.addInferenceResponse(createFloatEmbeddingResponse(embedding)); + } + } + + final Page outputPage = outputBuilder.buildOutput(); + assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1)); + + FloatBlock outputBlock = (FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1); + assertThat(outputBlock.getPositionCount(), equalTo(size)); + + outputPage.releaseBlocks(); + } + + allBreakersEmpty(); + } + + public void testHandleEmptyEmbeddings() throws Exception { + final int size = between(5, 50); + final Page inputPage = randomInputPage(size, between(1, 3)); + + try ( + TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder( + blockFactory().newFloatBlockBuilder(size), + inputPage + ) + ) { + // Add responses with empty embeddings + for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) { + outputBuilder.addInferenceResponse(createEmptyFloatEmbeddingResponse()); + } + + final Page outputPage = outputBuilder.buildOutput(); + FloatBlock outputBlock = (FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1); + + // All positions should be null due to empty embeddings + for (int pos = 0; pos < outputBlock.getPositionCount(); pos++) { + assertThat(outputBlock.isNull(pos), equalTo(true)); + } + + outputPage.releaseBlocks(); + } + + allBreakersEmpty(); + } + + private void assertBuildOutputWithFloatEmbeddings(int size) throws Exception { + final Page inputPage = 
randomInputPage(size, between(1, 10)); + final int embeddingDim = randomIntBetween(50, 1536); // Common embedding dimensions + final float[][] expectedEmbeddings = new float[size][]; + + try ( + TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder( + blockFactory().newFloatBlockBuilder(size), + inputPage + ) + ) { + for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) { + float[] embedding = randomFloatEmbedding(embeddingDim); + expectedEmbeddings[currentPos] = embedding; + outputBuilder.addInferenceResponse(createFloatEmbeddingResponse(embedding)); + } + + final Page outputPage = outputBuilder.buildOutput(); + assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1)); + + assertFloatEmbeddingContent((FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1), expectedEmbeddings); + + outputPage.releaseBlocks(); + } + + allBreakersEmpty(); + } + + private void assertBuildOutputWithByteEmbeddings(int size) throws Exception { + final Page inputPage = randomInputPage(size, between(1, 10)); + final int embeddingDim = randomIntBetween(50, 1536); + final byte[][] expectedByteEmbeddings = new byte[size][]; + + try ( + TextEmbeddingOperatorOutputBuilder outputBuilder = new TextEmbeddingOperatorOutputBuilder( + blockFactory().newFloatBlockBuilder(size), + inputPage + ) + ) { + for (int currentPos = 0; currentPos < inputPage.getPositionCount(); currentPos++) { + byte[] embedding = randomByteEmbedding(embeddingDim); + expectedByteEmbeddings[currentPos] = embedding; + outputBuilder.addInferenceResponse(createByteEmbeddingResponse(embedding)); + } + + final Page outputPage = outputBuilder.buildOutput(); + assertThat(outputPage.getPositionCount(), equalTo(inputPage.getPositionCount())); + assertThat(outputPage.getBlockCount(), equalTo(inputPage.getBlockCount() + 1)); + + assertByteEmbeddingContent((FloatBlock) outputPage.getBlock(outputPage.getBlockCount() - 1), expectedByteEmbeddings); + + outputPage.releaseBlocks(); + } + + allBreakersEmpty(); + } + + private void assertFloatEmbeddingContent(FloatBlock block, float[][] expectedEmbeddings) { + for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) { + assertThat(block.isNull(currentPos), equalTo(false)); + assertThat(block.getValueCount(currentPos), equalTo(expectedEmbeddings[currentPos].length)); + + int firstValueIndex = block.getFirstValueIndex(currentPos); + for (int i = 0; i < expectedEmbeddings[currentPos].length; i++) { + float actualValue = block.getFloat(firstValueIndex + i); + float expectedValue = expectedEmbeddings[currentPos][i]; + assertThat(actualValue, equalTo(expectedValue)); + } + } + } + + private void assertByteEmbeddingContent(FloatBlock block, byte[][] expectedByteEmbeddings) { + for (int currentPos = 0; currentPos < block.getPositionCount(); currentPos++) { + assertThat(block.isNull(currentPos), equalTo(false)); + assertThat(block.getValueCount(currentPos), equalTo(expectedByteEmbeddings[currentPos].length)); + + int firstValueIndex = block.getFirstValueIndex(currentPos); + for (int i = 0; i < expectedByteEmbeddings[currentPos].length; i++) { + float actualValue = block.getFloat(firstValueIndex + i); + // Convert byte to float the same way as TextEmbeddingByteResults.Embedding.toFloatArray() + float expectedValue = expectedByteEmbeddings[currentPos][i]; + assertThat(actualValue, equalTo(expectedValue)); + } + } + } + + private float[] 
randomFloatEmbedding(int dimension) { + float[] embedding = new float[dimension]; + for (int i = 0; i < dimension; i++) { + embedding[i] = randomFloat(); + } + return embedding; + } + + private byte[] randomByteEmbedding(int dimension) { + byte[] embedding = new byte[dimension]; + for (int i = 0; i < dimension; i++) { + embedding[i] = randomByte(); + } + return embedding; + } + + private static InferenceAction.Response createFloatEmbeddingResponse(float[] embedding) { + var embeddingResult = new TextEmbeddingFloatResults.Embedding(embedding); + var textEmbeddingResults = new TextEmbeddingFloatResults(List.of(embeddingResult)); + return new InferenceAction.Response(textEmbeddingResults); + } + + private static InferenceAction.Response createByteEmbeddingResponse(byte[] embedding) { + var embeddingResult = new TextEmbeddingByteResults.Embedding(embedding); + var textEmbeddingResults = new TextEmbeddingByteResults(List.of(embeddingResult)); + return new InferenceAction.Response(textEmbeddingResults); + } + + private static InferenceAction.Response createEmptyFloatEmbeddingResponse() { + var textEmbeddingResults = new TextEmbeddingFloatResults(List.of()); + return new InferenceAction.Response(textEmbeddingResults); + } + + private Page randomInputPage(int positionCount, int columnCount) { + final Block[] blocks = new Block[columnCount]; + try { + for (int i = 0; i < columnCount; i++) { + blocks[i] = RandomBlock.randomBlock( + blockFactory(), + RandomBlock.randomElementExcluding(List.of(ElementType.AGGREGATE_METRIC_DOUBLE)), + positionCount, + randomBoolean(), + 0, + 0, + randomInt(10), + randomInt(10) + ).block(); + } + + return new Page(blocks); + } catch (Exception e) { + Releasables.close(blocks); + throw (e); + } + } +} From 2068e91a87f78f88816f6e7de727e24ee6402cae Mon Sep 17 00:00:00 2001 From: afoucret Date: Sat, 13 Sep 2025 22:09:48 +0200 Subject: [PATCH 11/26] Text embedding inference operator. --- .../textembedding/TextEmbeddingOperator.java | 96 ++++++++++++ .../TextEmbeddingOperatorTests.java | 137 ++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperator.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperator.java new file mode 100644 index 0000000000000..7817612007614 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperator.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.inference.InferenceOperator; +import org.elasticsearch.xpack.esql.inference.InferenceService; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner; +import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig; + +/** + * {@link TextEmbeddingOperator} is an {@link InferenceOperator} that performs text embedding inference. + * It evaluates a text expression for each input row, constructs text embedding inference requests, + * and emits the dense vector embeddings as output. + */ +public class TextEmbeddingOperator extends InferenceOperator { + + private final ExpressionEvaluator textEvaluator; + + public TextEmbeddingOperator( + DriverContext driverContext, + BulkInferenceRunner bulkInferenceRunner, + String inferenceId, + ExpressionEvaluator textEvaluator, + int maxOutstandingPages + ) { + super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages); + this.textEvaluator = textEvaluator; + } + + @Override + protected void doClose() { + Releasables.close(textEvaluator); + } + + @Override + public String toString() { + return "TextEmbeddingOperator[inference_id=[" + inferenceId() + "]]"; + } + + /** + * Constructs the text embedding inference requests iterator for the given input page by evaluating the text expression. + * + * @param inputPage The input data page. + */ + @Override + protected BulkInferenceRequestIterator requests(Page inputPage) { + return new TextEmbeddingOperatorRequestIterator((BytesRefBlock) textEvaluator.eval(inputPage), inferenceId()); + } + + /** + * Creates a new {@link TextEmbeddingOperatorOutputBuilder} to collect and emit the text embedding results. + * + * @param input The input page for which results will be constructed. + */ + @Override + protected TextEmbeddingOperatorOutputBuilder outputBuilder(Page input) { + FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(input.getPositionCount()); + return new TextEmbeddingOperatorOutputBuilder(outputBlockBuilder, input); + } + + /** + * Factory for creating {@link TextEmbeddingOperator} instances. 
+ */ + public record Factory(InferenceService inferenceService, String inferenceId, ExpressionEvaluator.Factory textEvaluatorFactory) + implements + OperatorFactory { + @Override + public String describe() { + return "TextEmbeddingOperator[inference_id=[" + inferenceId + "]]"; + } + + @Override + public Operator get(DriverContext driverContext) { + return new TextEmbeddingOperator( + driverContext, + inferenceService.bulkInferenceRunner(), + inferenceId, + textEvaluatorFactory.get(driverContext), + BulkInferenceRunnerConfig.DEFAULT.maxOutstandingBulkRequests() + ); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java new file mode 100644 index 0000000000000..6ff9a90b70b16 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorTests.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference.textembedding; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.esql.inference.InferenceOperatorTestCase; +import org.hamcrest.Matcher; +import org.junit.Before; + +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class TextEmbeddingOperatorTests extends InferenceOperatorTestCase { + private static final String SIMPLE_INFERENCE_ID = "test_text_embedding"; + private static final int EMBEDDING_DIMENSION = 384; // Common embedding dimension + + private int inputChannel; + + @Before + public void initTextEmbeddingChannels() { + inputChannel = between(0, inputsCount - 1); + } + + @Override + protected Operator.OperatorFactory simple(SimpleOptions options) { + return new TextEmbeddingOperator.Factory(mockedInferenceService(), SIMPLE_INFERENCE_ID, evaluatorFactory(inputChannel)); + } + + @Override + protected void assertSimpleOutput(List input, List results) { + assertThat(results, hasSize(input.size())); + + for (int curPage = 0; curPage < input.size(); curPage++) { + Page inputPage = input.get(curPage); + Page resultPage = results.get(curPage); + + assertEquals(inputPage.getPositionCount(), resultPage.getPositionCount()); + assertEquals(inputPage.getBlockCount() + 1, resultPage.getBlockCount()); + + for (int channel = 0; channel < inputPage.getBlockCount(); channel++) { + Block inputBlock = inputPage.getBlock(channel); + Block resultBlock = resultPage.getBlock(channel); + assertBlockContentEquals(inputBlock, resultBlock); + } + + assertTextEmbeddingResults(inputPage, resultPage); + } + } + + private void assertTextEmbeddingResults(Page inputPage, Page resultPage) { + BytesRefBlock inputBlock = resultPage.getBlock(inputChannel); + FloatBlock resultBlock = (FloatBlock) 
resultPage.getBlock(inputPage.getBlockCount()); + + BlockStringReader blockReader = new InferenceOperatorTestCase.BlockStringReader(); + + for (int curPos = 0; curPos < inputPage.getPositionCount(); curPos++) { + if (inputBlock.isNull(curPos)) { + assertThat(resultBlock.isNull(curPos), equalTo(true)); + } else { + // Verify that we have an embedding vector at this position + assertThat(resultBlock.isNull(curPos), equalTo(false)); + assertThat(resultBlock.getValueCount(curPos), equalTo(EMBEDDING_DIMENSION)); + + // Get the input text to verify our mock embedding generation + String inputText = blockReader.readString(inputBlock, curPos); + + // Verify the embedding values match our mock generation pattern + int firstValueIndex = resultBlock.getFirstValueIndex(curPos); + for (int i = 0; i < EMBEDDING_DIMENSION; i++) { + float expectedValue = generateMockEmbeddingValue(inputText, i); + float actualValue = resultBlock.getFloat(firstValueIndex + i); + assertThat(actualValue, equalTo(expectedValue)); + } + } + } + } + + @Override + protected TextEmbeddingFloatResults mockInferenceResult(InferenceAction.Request request) { + // For text embedding, we expect one input text per request + String inputText = request.getInput().get(0); + + // Generate a deterministic mock embedding based on the input text + float[] mockEmbedding = generateMockEmbedding(inputText, EMBEDDING_DIMENSION); + + var embeddingResult = new TextEmbeddingFloatResults.Embedding(mockEmbedding); + return new TextEmbeddingFloatResults(List.of(embeddingResult)); + } + + @Override + protected Matcher expectedDescriptionOfSimple() { + return expectedToStringOfSimple(); + } + + @Override + protected Matcher expectedToStringOfSimple() { + return equalTo("TextEmbeddingOperator[inference_id=[" + SIMPLE_INFERENCE_ID + "]]"); + } + + /** + * Generates a deterministic mock embedding vector based on the input text. + * This ensures our tests are repeatable and verifiable. + */ + private float[] generateMockEmbedding(String inputText, int dimension) { + float[] embedding = new float[dimension]; + int textHash = inputText.hashCode(); + + for (int i = 0; i < dimension; i++) { + embedding[i] = generateMockEmbeddingValue(inputText, i); + } + + return embedding; + } + + /** + * Generates a single embedding value for a specific dimension based on input text. + * Uses a deterministic function so tests are repeatable. + */ + private float generateMockEmbeddingValue(String inputText, int dimension) { + // Create a deterministic value based on input text and dimension + int hash = (inputText.hashCode() + dimension * 31) % 10000; + return hash / 10000.0f; // Normalize to [0, 1) range + } +} From 92984506d1c3ac91268c77fd1b66052e55356b62 Mon Sep 17 00:00:00 2001 From: afoucret Date: Tue, 16 Sep 2025 08:21:31 +0200 Subject: [PATCH 12/26] More flexible output builder. 
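The change below widens OutputBuilder to a generic OutputBuilder&lt;T&gt; so an implementation can emit a bare Block instead of a full Page. A hypothetical sketch of what the generic interface permits (this class is illustrative and not part of the change):

    // Illustrative only: a builder that emits a Block of embeddings rather than appending to a Page.
    class EmbeddingBlockOutputBuilder implements InferenceOperator.OutputBuilder<Block> {
        private final FloatBlock.Builder blockBuilder;

        EmbeddingBlockOutputBuilder(FloatBlock.Builder blockBuilder) {
            this.blockBuilder = blockBuilder;
        }

        @Override
        public void addInferenceResponse(InferenceAction.Response response) {
            if (response == null) {
                blockBuilder.appendNull(); // keep output aligned with input rows
                return;
            }
            // append the embedding values for this response (elided in this sketch)
        }

        @Override
        public Block buildOutput() {
            return blockBuilder.build();
        }

        @Override
        public void close() {
            blockBuilder.close();
        }
    }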
--- .../xpack/esql/inference/InferenceOperator.java | 15 ++++++++------- .../CompletionOperatorOutputBuilder.java | 2 +- .../rerank/RerankOperatorOutputBuilder.java | 2 +- .../TextEmbeddingOperatorOutputBuilder.java | 2 +- .../function/inference/TextEmbeddingTests.java | 2 +- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java index 93085969415a6..b5edb68488089 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.inference; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AsyncOperator; @@ -102,7 +103,7 @@ public Page getOutput() { return null; } - try (OutputBuilder outputBuilder = outputBuilder(ongoingInferenceResult.inputPage)) { + try (OutputBuilder outputBuilder = outputBuilder(ongoingInferenceResult.inputPage)) { for (InferenceAction.Response response : ongoingInferenceResult.responses) { outputBuilder.addInferenceResponse(response); } @@ -125,12 +126,12 @@ public Page getOutput() { * * @param input The corresponding input page used to generate the inference requests. */ - protected abstract OutputBuilder outputBuilder(Page input); + protected abstract OutputBuilder outputBuilder(Page input); /** - * An interface for accumulating inference responses and constructing a result {@link Page}. + * An interface for accumulating inference responses and constructing a result (can be a {@link Page} or a {@link Block}). */ - public interface OutputBuilder extends Releasable { + public interface OutputBuilder extends Releasable { /** * Adds an inference response to the output. @@ -144,11 +145,11 @@ public interface OutputBuilder extends Releasable { void addInferenceResponse(InferenceAction.Response inferenceResponse); /** - * Builds the final output page from accumulated inference responses. + * Builds the final output from accumulated inference responses. * - * @return The constructed output page. + * @return The constructed output block. */ - Page buildOutput(); + T buildOutput(); static IR inferenceResults(InferenceAction.Response inferenceResponse, Class clazz) { InferenceServiceResults results = inferenceResponse.getResults(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java index 3e9106f9a1cf6..d135ae75fa56d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java @@ -20,7 +20,7 @@ * {@link CompletionOperatorOutputBuilder} builds the output page for {@link CompletionOperator} by converting {@link ChatCompletionResults} * into a {@link BytesRefBlock}. 
*/ -public class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder { +public class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder { private final Page inputPage; private final BytesRefBlock.Builder outputBlockBuilder; private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java index bff95cf54bae9..605997563dec7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java @@ -24,7 +24,7 @@ * * reranked relevance scores into the specified score channel of the input page. */ -public class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder { +public class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder { private final Page inputPage; private final DoubleBlock.Builder scoreBlockBuilder; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java index 521bb508c30af..54ef6f8c1c227 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java @@ -21,7 +21,7 @@ * {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting * {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings. */ -public class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder { +public class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder { private final Page inputPage; private final FloatBlock.Builder outputBlockBuilder; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java index ef5a570278c3b..b6fdc7addf984 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbeddingTests.java @@ -50,7 +50,7 @@ public static Iterable parameters() { List.of(inputTextDataType, inferenceIdDataType), () -> new TestCaseSupplier.TestCase( List.of( - new TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), inputTextDataType, "inputText"), + new TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), inputTextDataType, "text"), new TestCaseSupplier.TypedData(randomBytesReference(10).toBytesRef(), inferenceIdDataType, "inference_id") ), Matchers.blankOrNullString(), From 0ae1e0e48bbdea5361a1fdeb9913515f63c481af Mon Sep 17 00:00:00 2001 From: afoucret Date: Tue, 16 Sep 2025 15:27:35 +0200 Subject: [PATCH 13/26] Init inference function evaluator. 
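Rough sketch of how a pre-optimizer rule is expected to drive the new evaluator (hypothetical rule body;
the concrete rule and its wiring land in a later commit):

    // Fold a constant inference call into a Literal and substitute it back into the logical plan.
    evaluator.fold(inferenceFunction, ActionListener.wrap(
        literal -> listener.onResponse(
            plan.transformExpressionsUp(InferenceFunction.class, f -> f == inferenceFunction ? literal : f)
        ),
        listener::onFailure
    ));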
--- .../xpack/esql/execution/PlanExecutor.java | 2 +- .../inference/InferenceFunctionEvaluator.java | 57 +++++++++++++++++++ .../optimizer/LogicalPlanPreOptimizer.java | 30 +++++++++- .../optimizer/LogicalPreOptimizerContext.java | 25 +++----- .../elasticsearch/xpack/esql/CsvTests.java | 2 +- .../LogicalPlanPreOptimizerTests.java | 2 +- 6 files changed, 97 insertions(+), 21 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java index 2043176f24a29..974b73718ff0b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java @@ -88,7 +88,7 @@ public void esql( indexResolver, enrichPolicyResolver, preAnalyzer, - new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext)), + new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext, services.inferenceService())), functionRegistry, new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)), mapper, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java new file mode 100644 index 0000000000000..9e090699b151a --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.inference; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.evaluator.EvalMapper; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; +import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingOperator; + +public class InferenceFunctionEvaluator { + + private final FoldContext foldContext; + private final InferenceService inferenceService; + + public InferenceFunctionEvaluator(FoldContext foldContext, InferenceService inferenceService) { + this.foldContext = foldContext; + this.inferenceService = inferenceService; + } + + public void fold(InferenceFunction f, ActionListener listener) { + assert f.foldable() : "Inference function must be foldable"; + + + } + + private Operator.OperatorFactory createInferenceOperatorFactory(InferenceFunction f) { + return switch (f) { + case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory( + inferenceService, + inferenceId(f), + expressionEvaluatorFactory(textEmbedding.inputText()) + ); + default -> throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName()); + }; + } + + private String inferenceId(InferenceFunction f) { + return BytesRefs.toString(f.inferenceId().fold(foldContext)); + } + + private ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e) { + assert e.foldable() : "Input expression must be foldable"; + return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java index fdd8e1318f636..3ec6b41f73e8a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java @@ -8,8 +8,12 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import java.util.List; + /** * The class is responsible for invoking any steps that need to be applied to the logical plan, * before this is being optimized. @@ -25,6 +29,8 @@ public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) { this.preOptimizerContext = preOptimizerContext; } + private static final List RULES = List.of(); + /** * Pre-optimize a logical plan. 
* @@ -44,7 +50,27 @@ public void preOptimize(LogicalPlan plan, ActionListener listener) } private void doPreOptimize(LogicalPlan plan, ActionListener listener) { - // this is where we will be executing async tasks - listener.onResponse(plan); + SubscribableListener ruleChainListener = SubscribableListener.newSucceeded(plan); + for (Rule rule : RULES) { + ruleChainListener = ruleChainListener.andThen((l, p) -> rule.apply(p, l)); + } + ruleChainListener.addListener(listener); + } + + public interface Rule { + void apply(LogicalPlan plan, ActionListener listener); + } + + private static class FoldInferenceFunction implements Rule { + private final InferenceFunctionEvaluator inferenceEvaluator; + + private FoldInferenceFunction(LogicalPreOptimizerContext preOptimizerContext) { + this.inferenceEvaluator = new InferenceFunctionEvaluator(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService()); + } + + @Override + public void apply(LogicalPlan plan, ActionListener listener) { + + } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java index d082bd56fc46d..233bf9c2945be 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPreOptimizerContext.java @@ -8,36 +8,29 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.xpack.esql.core.expression.FoldContext; - -import java.util.Objects; +import org.elasticsearch.xpack.esql.inference.InferenceService; public class LogicalPreOptimizerContext { private final FoldContext foldCtx; - public LogicalPreOptimizerContext(FoldContext foldCtx) { + private final InferenceService inferenceService; + + public LogicalPreOptimizerContext(FoldContext foldCtx, InferenceService inferenceService) { this.foldCtx = foldCtx; + this.inferenceService = inferenceService; } public FoldContext foldCtx() { return foldCtx; } - @Override - public boolean equals(Object obj) { - if (obj == this) return true; - if (obj == null || obj.getClass() != this.getClass()) return false; - var that = (LogicalPreOptimizerContext) obj; - return this.foldCtx.equals(that.foldCtx); - } - - @Override - public int hashCode() { - return Objects.hash(foldCtx); - } - @Override public String toString() { return "LogicalPreOptimizerContext[foldCtx=" + foldCtx + ']'; } + + public InferenceService inferenceService() { + return inferenceService; + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index c635a9c9d5c00..7351349c3b750 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -588,7 +588,7 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception { null, null, null, - new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldCtx)), + new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldCtx, null)), functionRegistry, new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldCtx)), mapper, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java index 8e573dd1cf3c9..fd859cc26fc1e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java @@ -72,7 +72,7 @@ public LogicalPlan preOptimizedPlan(LogicalPlan plan) throws Exception { } private LogicalPlanPreOptimizer preOptimizer() { - LogicalPreOptimizerContext preOptimizerContext = new LogicalPreOptimizerContext(FoldContext.small()); + LogicalPreOptimizerContext preOptimizerContext = new LogicalPreOptimizerContext(FoldContext.small(), null); return new LogicalPlanPreOptimizer(preOptimizerContext); } From 6d064f07fb2bdbd8ea61b3be56fa0cb0518ee819 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 10:50:50 +0200 Subject: [PATCH 14/26] Implementing inference evaluation in the pre-optimizer. --- .../xpack/esql/qa/rest/EsqlSpecTestCase.java | 4 +- .../main/resources/text-embedding.csv-spec | 11 ++ .../inference/InferenceFunctionEvaluator.java | 176 +++++++++++++++++- .../optimizer/LogicalPlanPreOptimizer.java | 41 ++-- .../preoptimizer/FoldInferenceFunctions.java | 127 +++++++++++++ .../preoptimizer/PreOptimizerRule.java | 25 +++ 6 files changed, 352 insertions(+), 32 deletions(-) create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java index 33ff75519348c..15b79183cac4a 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java @@ -77,6 +77,7 @@ import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.RERANK; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SEMANTIC_TEXT_FIELD_CAPS; import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.SOURCE_FIELD_MAPPING; +import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.TEXT_EMBEDDING_FUNCTION; import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.assertNotPartial; import static org.elasticsearch.xpack.esql.qa.rest.RestEsqlTestCase.hasCapabilities; @@ -205,7 +206,8 @@ protected boolean requiresInferenceEndpoint() { SEMANTIC_TEXT_FIELD_CAPS.capabilityName(), RERANK.capabilityName(), COMPLETION.capabilityName(), - KNN_FUNCTION_V5.capabilityName() + KNN_FUNCTION_V5.capabilityName(), + TEXT_EMBEDDING_FUNCTION.capabilityName() ).anyMatch(testCase.requiredCapabilities::contains); } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec new file mode 100644 index 0000000000000..8067eb5219b5a --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec @@ -0,0 +1,11 @@ +text_embedding using a ROW source operator +required_capability: 
text_embedding_function +required_capability: dense_vector_field_type + +ROW input="Who is Victor Hugo?" +| EVAL embedding = TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference") +; + +input:keyword | embedding:dense_vector +Who is Victor Hugo? | [56.0, 50.0, 48.0] +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java index 9e090699b151a..c978ff7ccb7f0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java @@ -8,9 +8,18 @@ package org.elasticsearch.xpack.esql.inference; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.lucene.BytesRefs; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.indices.breaker.AllCircuitBreakerStats; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.indices.breaker.CircuitBreakerStats; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -19,24 +28,146 @@ import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingOperator; +/** + * Evaluator for inference functions that performs constant folding by executing inference operations + * at optimization time and replacing them with their computed results. + *

+ * This class is responsible for:
+ * <ul>
+ *   <li>Setting up the necessary execution context (DriverContext, CircuitBreaker, etc.)</li>
+ *   <li>Creating and configuring appropriate inference operators for different function types</li>
+ *   <li>Executing inference operations asynchronously</li>
+ *   <li>Converting operator results back to ESQL expressions</li>
+ * </ul>
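+ * <p>
+ * A minimal usage sketch (illustrative; surrounding wiring omitted):
+ * <pre>{@code
+ * InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(foldContext, inferenceService);
+ * evaluator.fold(textEmbeddingFunction, ActionListener.wrap(
+ *     literal -> consumeFoldedResult(literal), // a Literal holding the inference output, e.g. a dense vector
+ *     failure -> handleFailure(failure)        // hypothetical handlers, not part of this class
+ * ));
+ * }</pre>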
+ */ public class InferenceFunctionEvaluator { private final FoldContext foldContext; private final InferenceService inferenceService; + private final InferenceOperatorProvider inferenceOperatorProvider; + /** + * Creates a new inference function evaluator with the default operator provider. + * + * @param foldContext the fold context containing circuit breakers and evaluation settings + * @param inferenceService the inference service for executing inference operations + */ public InferenceFunctionEvaluator(FoldContext foldContext, InferenceService inferenceService) { this.foldContext = foldContext; this.inferenceService = inferenceService; + this.inferenceOperatorProvider = this::createInferenceOperator; + } + + /** + * Creates a new inference function evaluator with a custom operator provider. + * This constructor is primarily used for testing to inject mock operator providers. + * + * @param foldContext the fold context containing circuit breakers and evaluation settings + * @param inferenceService the inference service for executing inference operations + * @param inferenceOperatorProvider custom provider for creating inference operators + */ + InferenceFunctionEvaluator( + FoldContext foldContext, + InferenceService inferenceService, + InferenceOperatorProvider inferenceOperatorProvider + ) { + this.foldContext = foldContext; + this.inferenceService = inferenceService; + this.inferenceOperatorProvider = inferenceOperatorProvider; } - public void fold(InferenceFunction f, ActionListener listener) { - assert f.foldable() : "Inference function must be foldable"; + /** + * Folds an inference function by executing it and replacing it with its computed result. + *

+ * This method performs the following steps:
+ * <ol>
+ *   <li>Validates that the function is foldable (has constant parameters)</li>
+ *   <li>Sets up a minimal execution context with appropriate circuit breakers</li>
+ *   <li>Creates and configures the appropriate inference operator</li>
+ *   <li>Executes the inference operation asynchronously</li>
+ *   <li>Converts the result to a {@link Literal} expression</li>
+ * </ol>
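+ * <p>
+ * For example (values are illustrative), folding {@code TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference")}
+ * completes the listener with a {@code Literal} of type {@code DENSE_VECTOR}, e.g. {@code [56.0, 50.0, 48.0]}.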
+ * + * @param f the inference function to fold - must be foldable (have constant parameters) + * @param listener the listener to notify when folding completes successfully or fails + * @throws IllegalArgumentException if the function is not foldable + */ + public void fold(InferenceFunction f, ActionListener listener) { + if (f.foldable() == false) { + listener.onFailure(new IllegalArgumentException("Inference function must be foldable")); + return; + } + + // Set up a DriverContext for executing the inference operator. + // This follows the same pattern as EvaluatorMapper but in a simplified context + // suitable for constant folding during optimization. + CircuitBreaker breaker = foldContext.circuitBreakerView(f.source()); + BigArrays bigArrays = new BigArrays(null, new CircuitBreakerService() { + @Override + public CircuitBreaker getBreaker(String name) { + if (name.equals(CircuitBreaker.REQUEST) == false) { + throw new UnsupportedOperationException("Only REQUEST circuit breaker is supported"); + } + return breaker; + } + + @Override + public AllCircuitBreakerStats stats() { + throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context"); + } + + @Override + public CircuitBreakerStats stats(String name) { + throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context"); + } + }, CircuitBreaker.REQUEST).withCircuitBreaking(); + + DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); + // Create the inference operator for the specific function type using the provider + Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext); + // Execute the inference operation asynchronously and handle the result + // The operator will perform the actual inference call and return a page with the result + driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> { + Page output = inferenceOperator.getOutput(); + + if (output == null) { + l.onFailure(new IllegalStateException("Expected output page from inference operator")); + return; + } + + if (output.getPositionCount() != 1 || output.getBlockCount() != 1) { + l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator")); + return; + } + + // Convert the operator result back to an ESQL expression (Literal) + l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0))); + })); + + // Feed the operator with a single page to trigger execution + // The actual input data is already bound in the operator through expression evaluators + inferenceOperator.addInput(new Page(1)); + + driverContext.finish(); } - private Operator.OperatorFactory createInferenceOperatorFactory(InferenceFunction f) { - return switch (f) { + /** + * Creates an inference operator for the given function type and driver context. + *

+ * This method uses pattern matching to determine the correct operator factory based on + * the inference function type, creates the factory, and then instantiates the operator + * with the provided driver context. Each supported inference function type has its own + * specialized operator implementation. + * + * @param f the inference function to create an operator for + * @param driverContext the driver context to use for operator creation + * @return an operator instance configured for the given function type + * @throws IllegalArgumentException if the function type is not supported + */ + private Operator createInferenceOperator(InferenceFunction f, DriverContext driverContext) { + Operator.OperatorFactory factory = switch (f) { case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory( inferenceService, inferenceId(f), @@ -44,14 +175,51 @@ private Operator.OperatorFactory createInferenceOperatorFactory(InferenceFunctio ); default -> throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName()); }; + + return factory.get(driverContext); } + /** + * Extracts the inference endpoint ID from an inference function. + * + * @param f the inference function containing the inference ID + * @return the inference endpoint ID as a string + */ private String inferenceId(InferenceFunction f) { return BytesRefs.toString(f.inferenceId().fold(foldContext)); } + /** + * Creates an expression evaluator factory for a foldable expression. + *

+ * This method converts a foldable expression into an evaluator factory that can be used by inference + * operators. The expression is first folded to its constant value and then wrapped in a literal. + * + * @param e the foldable expression to create an evaluator factory for + * @return an expression evaluator factory for the given expression + * @throws AssertionError if the expression is not foldable (in debug builds) + */ private ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e) { assert e.foldable() : "Input expression must be foldable"; return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null); } + + /** + * Functional interface for providing inference operators. + *

+ * This interface abstracts the creation of inference operators for different function types, + * allowing for easier testing and potential future extensibility. The provider is responsible + * for creating an appropriate operator instance given an inference function and driver context. + */ + @FunctionalInterface + interface InferenceOperatorProvider { + /** + * Creates an inference operator for the given function and driver context. + * + * @param f the inference function to create an operator for + * @param driverContext the driver context to use for operator creation + * @return an operator instance configured for the given function + */ + Operator getOperator(InferenceFunction f, DriverContext driverContext); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java index 3ec6b41f73e8a..dbb1fb45334da 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java @@ -9,7 +9,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; -import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.FoldInferenceFunctions; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.PreOptimizerRule; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.List; @@ -22,15 +23,12 @@ *

*/ public class LogicalPlanPreOptimizer { - - private final LogicalPreOptimizerContext preOptimizerContext; + private final List preOptimizerRules; public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) { - this.preOptimizerContext = preOptimizerContext; + preOptimizerRules = List.of(new FoldInferenceFunctions(preOptimizerContext)); } - private static final List RULES = List.of(); - /** * Pre-optimize a logical plan. * @@ -49,28 +47,17 @@ public void preOptimize(LogicalPlan plan, ActionListener listener) })); } + /** + * Loop over the rules and apply them sequentially to the logical plan. + * + * @param plan the analyzed logical plan to pre-optimize + * @param listener the listener returning the pre-optimized plan when pre-optimization is complete + */ private void doPreOptimize(LogicalPlan plan, ActionListener listener) { - SubscribableListener ruleChainListener = SubscribableListener.newSucceeded(plan); - for (Rule rule : RULES) { - ruleChainListener = ruleChainListener.andThen((l, p) -> rule.apply(p, l)); - } - ruleChainListener.addListener(listener); - } - - public interface Rule { - void apply(LogicalPlan plan, ActionListener listener); - } - - private static class FoldInferenceFunction implements Rule { - private final InferenceFunctionEvaluator inferenceEvaluator; - - private FoldInferenceFunction(LogicalPreOptimizerContext preOptimizerContext) { - this.inferenceEvaluator = new InferenceFunctionEvaluator(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService()); - } - - @Override - public void apply(LogicalPlan plan, ActionListener listener) { - + SubscribableListener rulesListener = SubscribableListener.newSucceeded(plan); + for (PreOptimizerRule preOptimizerRule : preOptimizerRules) { + rulesListener = rulesListener.andThen((l, p) -> preOptimizerRule.apply(p, l)); } + rulesListener.addListener(listener); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java new file mode 100644 index 0000000000000..7b80d3d69039d --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; +import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; +import org.elasticsearch.xpack.esql.optimizer.LogicalPreOptimizerContext; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Pre-optimizer rule that performs constant folding for inference functions in logical plans. + *

+ * This rule identifies inference functions with constant parameters and evaluates them at optimization time, + * replacing them with their computed results. + *

+ * The folding process is recursive and handles nested inference functions by processing them in multiple + * passes until no more foldable functions remain. + *

+ * Example transformation: + * {@code TEXT_EMBEDDING("hello world", "model1")} → {@code [0.1, 0.2, 0.3, ...]} + */ +public class FoldInferenceFunctions implements PreOptimizerRule { + + private final InferenceFunctionEvaluator inferenceFunctionEvaluator; + + public FoldInferenceFunctions(LogicalPreOptimizerContext preOptimizerContext) { + inferenceFunctionEvaluator = new InferenceFunctionEvaluator(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService()); + } + + @Override + public void apply(LogicalPlan plan, ActionListener listener) { + foldInferenceFunctions(plan, listener); + } + + /** + * Recursively folds inference functions in the logical plan. + *

+ * This method collects all foldable inference functions, evaluates them in parallel,
+ * and then replaces them with their computed results. If new foldable inference functions remain
+ * after the first round of folding (because nested functions have now been resolved), it recursively
+ * processes them until no more foldable functions are left.

+ * + * @param plan the logical plan to fold inference functions in + * @param listener the listener to notify when the folding is complete + */ + private void foldInferenceFunctions(LogicalPlan plan, ActionListener listener) { + // Collect all foldable inference functions from the current plan + List> inferenceFunctions = collectFoldableInferenceFunctions(plan); + + if (inferenceFunctions.isEmpty()) { + // No foldable inference functions were found - return the original plan unchanged + listener.onResponse(plan); + return; + } + + // Map to store the computed results for each inference function + Map, Expression> inferenceFunctionsToResults = new HashMap<>(); + + // Create a countdown listener that will be triggered when all inference functions complete + // Once all are done, replace the functions in the plan with their results and recursively + // process any remaining foldable inference functions + CountDownActionListener completionListener = new CountDownActionListener( + inferenceFunctions.size(), + listener.delegateFailureIgnoreResponseAndWrap(l -> { + // Transform the plan by replacing inference functions with their computed results + LogicalPlan transformedPlan = plan.transformExpressionsUp( + InferenceFunction.class, + f -> inferenceFunctionsToResults.getOrDefault(f, f) + ); + + // Recursively process the transformed plan to handle any remaining inference functions + foldInferenceFunctions(transformedPlan, l); + }) + ); + + // Evaluate each inference function asynchronously + for (InferenceFunction inferenceFunction : inferenceFunctions) { + inferenceFunctionEvaluator.fold(inferenceFunction, completionListener.delegateFailureAndWrap((l, result) -> { + inferenceFunctionsToResults.put(inferenceFunction, result); + l.onResponse(null); + })); + } + } + + /** + * Collects all foldable inference functions from the logical plan. + *

+ * A function is considered foldable if it meets all of the following criteria:
+ * <ol>
+ *   <li>It's an instance of {@link InferenceFunction}</li>
+ *   <li>It's marked as foldable (all parameters are constants)</li>
+ *   <li>It doesn't contain nested inference functions (to avoid dependency issues)</li>
+ * </ol>
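+ * <p>
+ * For example (illustrative), {@code TEXT_EMBEDDING("Who is Victor Hugo?", "test_dense_inference")} is foldable
+ * because both arguments are literals, while a call whose text argument is a field reference is not.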

+ * Functions with nested inference functions are excluded to ensure proper evaluation order. + * They will be considered for folding in subsequent recursive passes after their nested + * functions have been resolved. + * + * @param plan the logical plan to collect inference functions from + * @return a list of foldable inference functions, may be empty if none are found + */ + private List> collectFoldableInferenceFunctions(LogicalPlan plan) { + List> inferenceFunctions = new ArrayList<>(); + + plan.forEachExpressionUp(InferenceFunction.class, f -> { + if (f.foldable() && f.hasNestedInferenceFunction() == false) { + inferenceFunctions.add(f); + } + }); + + return inferenceFunctions; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java new file mode 100644 index 0000000000000..e29216a020afb --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +/** + * A rule that can be applied to an analyzed logical plan before it is optimized. + */ +public interface PreOptimizerRule { + + /** + * Apply the rule to the logical plan. 
+ * + * @param plan the analyzed logical plan to pre-optimize + * @param listener the listener returning the pre-optimized plan when pre-optimization rule is applied + */ + void apply(LogicalPlan plan, ActionListener listener); +} From 6fc5f999884769878807cfe7c2834b8b068fe320 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 18 Sep 2025 08:57:57 +0000 Subject: [PATCH 15/26] [CI] Auto commit changes from spotless --- .../esql/expression/function/EsqlFunctionRegistry.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index f934dc71de792..f7da7380015bb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -502,7 +502,7 @@ private static FunctionDefinition[][] functions() { def(Match.class, tri(Match::new), "match"), def(MultiMatch.class, MultiMatch::new, "multi_match"), def(QueryString.class, bi(QueryString::new), "qstr"), - def(MatchPhrase.class, tri(MatchPhrase::new), "match_phrase")}, + def(MatchPhrase.class, tri(MatchPhrase::new), "match_phrase") }, // time-series functions new FunctionDefinition[] { def(Rate.class, uni(Rate::new), "rate"), @@ -517,13 +517,13 @@ private static FunctionDefinition[][] functions() { def(AbsentOverTime.class, uni(AbsentOverTime::new), "absent_over_time"), def(AvgOverTime.class, uni(AvgOverTime::new), "avg_over_time"), def(LastOverTime.class, uni(LastOverTime::new), "last_over_time"), - def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time")}}; + def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time") } }; } private static FunctionDefinition[][] snapshotFunctions() { - return new FunctionDefinition[][]{ - new FunctionDefinition[]{ + return new FunctionDefinition[][] { + new FunctionDefinition[] { // The delay() function is for debug/snapshot environments only and should never be enabled in a non-snapshot build. // This is an experimental function and can be removed without notice. 
def(Delay.class, Delay::new, "delay"), From 5ee238284fdb93e1f3d45160a740948a91917c11 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 11:33:21 +0200 Subject: [PATCH 16/26] Remove overengineered type param on the InferenceOperator.OutputBuilder --- .../xpack/esql/inference/InferenceOperator.java | 11 +++++------ .../completion/CompletionOperatorOutputBuilder.java | 2 +- .../inference/rerank/RerankOperatorOutputBuilder.java | 2 +- .../TextEmbeddingOperatorOutputBuilder.java | 2 +- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java index b5edb68488089..a95b72d64f9da 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.esql.inference; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AsyncOperator; @@ -103,7 +102,7 @@ public Page getOutput() { return null; } - try (OutputBuilder outputBuilder = outputBuilder(ongoingInferenceResult.inputPage)) { + try (OutputBuilder outputBuilder = outputBuilder(ongoingInferenceResult.inputPage)) { for (InferenceAction.Response response : ongoingInferenceResult.responses) { outputBuilder.addInferenceResponse(response); } @@ -126,12 +125,12 @@ public Page getOutput() { * * @param input The corresponding input page used to generate the inference requests. */ - protected abstract OutputBuilder outputBuilder(Page input); + protected abstract OutputBuilder outputBuilder(Page input); /** - * An interface for accumulating inference responses and constructing a result (can be a {@link Page} or a {@link Block}). + * An interface for accumulating inference responses and constructing the result page.. */ - public interface OutputBuilder extends Releasable { + public interface OutputBuilder extends Releasable { /** * Adds an inference response to the output. @@ -149,7 +148,7 @@ public interface OutputBuilder extends Releasable { * * @return The constructed output block. */ - T buildOutput(); + Page buildOutput(); static IR inferenceResults(InferenceAction.Response inferenceResponse, Class clazz) { InferenceServiceResults results = inferenceResponse.getResults(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java index d135ae75fa56d..3e9106f9a1cf6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java @@ -20,7 +20,7 @@ * {@link CompletionOperatorOutputBuilder} builds the output page for {@link CompletionOperator} by converting {@link ChatCompletionResults} * into a {@link BytesRefBlock}. 
*/ -public class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder { +public class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder { private final Page inputPage; private final BytesRefBlock.Builder outputBlockBuilder; private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java index 605997563dec7..bff95cf54bae9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java @@ -24,7 +24,7 @@ * * reranked relevance scores into the specified score channel of the input page. */ -public class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder { +public class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder { private final Page inputPage; private final DoubleBlock.Builder scoreBlockBuilder; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java index 54ef6f8c1c227..521bb508c30af 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/textembedding/TextEmbeddingOperatorOutputBuilder.java @@ -21,7 +21,7 @@ * {@link TextEmbeddingOperatorOutputBuilder} builds the output page for text embedding by converting * {@link TextEmbeddingResults} into a {@link FloatBlock} containing dense vector embeddings. 
*/ -public class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder { +public class TextEmbeddingOperatorOutputBuilder implements InferenceOperator.OutputBuilder { private final Page inputPage; private final FloatBlock.Builder outputBlockBuilder; From 3e9d3d596de0c08c3e4e051a74f4a8bbd5afa798 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 12:55:26 +0200 Subject: [PATCH 17/26] Unit tests for InferenceFunctionEvaluator --- .../inference/InferenceFunctionEvaluator.java | 28 +- .../InferenceFunctionEvaluatorTests.java | 264 ++++++++++++++++++ 2 files changed, 281 insertions(+), 11 deletions(-) create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java index c978ff7ccb7f0..9edbc1b159d2e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java @@ -132,18 +132,24 @@ public CircuitBreakerStats stats(String name) { driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> { Page output = inferenceOperator.getOutput(); - if (output == null) { - l.onFailure(new IllegalStateException("Expected output page from inference operator")); - return; - } + try { + if (output == null) { + l.onFailure(new IllegalStateException("Expected output page from inference operator")); + return; + } - if (output.getPositionCount() != 1 || output.getBlockCount() != 1) { - l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator")); - return; - } + if (output.getPositionCount() != 1 || output.getBlockCount() != 1) { + l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator")); + return; + } - // Convert the operator result back to an ESQL expression (Literal) - l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0))); + // Convert the operator result back to an ESQL expression (Literal) + l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0))); + } finally { + if (output != null) { + output.releaseBlocks(); + } + } })); // Feed the operator with a single page to trigger execution @@ -193,7 +199,7 @@ private String inferenceId(InferenceFunction f) { * Creates an expression evaluator factory for a foldable expression. *

* This method converts a foldable expression into an evaluator factory that can be used by inference - * operators. The expression is first folded to its constant value and then wrapped in a literal. + * operators. The expressionis first folded to its constant value and then wrapped in a literal. * * @param e the foldable expression to create an evaluator factory for * @return an expression evaluator factory for the given expression diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java new file mode 100644 index 0000000000000..a86cb722aeabc --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java @@ -0,0 +1,264 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.inference; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.test.ComputeTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.function.Function; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; +import org.junit.After; +import org.junit.Before; + +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class InferenceFunctionEvaluatorTests extends ComputeTestCase { + + private ThreadPool threadPool; + + @Before + public void setupThreadPool() { + this.threadPool = createThreadPool(); + } + + @After + public void tearDownThreadPool() { + terminate(threadPool); + } + + public void testFoldTextEmbeddingFunction() throws Exception { + // Create a mock TextEmbedding function + TextEmbedding textEmbeddingFunction = new TextEmbedding( + Source.EMPTY, + Literal.keyword(Source.EMPTY, "test-model"), + Literal.keyword(Source.EMPTY, "test input") + ); + + // Create a mock operator that returns a result + Operator operator = mock(Operator.class); + + Float[] embedding = randomArray(1, 100, Float[]::new, ESTestCase::randomFloat); + + when(operator.getOutput()).thenAnswer(i -> { + FloatBlock.Builder outputBlockBuilder = blockFactory().newFloatBlockBuilder(1).beginPositionEntry(); + + for (int j = 0; j < embedding.length; j++) { + 
outputBlockBuilder.appendFloat(embedding[j]); + } + + outputBlockBuilder.endPositionEntry(); + + return new Page(outputBlockBuilder.build()); + }); + + InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator; + + // Execute the fold operation + InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( + FoldContext.small(), + mock(InferenceService.class), + inferenceOperatorProvider + ); + + AtomicReference resultExpression = new AtomicReference<>(); + evaluator.fold(textEmbeddingFunction, ActionListener.wrap(resultExpression::set, ESTestCase::fail)); + + assertBusy(() -> { + assertNotNull(resultExpression.get()); + Literal result = as(resultExpression.get(), Literal.class); + assertThat(result.dataType(), equalTo(DataType.DENSE_VECTOR)); + assertThat(as(result.value(), List.class).toArray(), equalTo(embedding)); + }); + + // Check all breakers are empty after the operation is executed + allBreakersEmpty(); + } + + public void testFoldWithNonFoldableFunction() { + // A function with a non-literal argument is not foldable. + TextEmbedding textEmbeddingFunction = new TextEmbedding( + Source.EMPTY, + mock(Attribute.class), + Literal.keyword(Source.EMPTY, "test input") + ); + + InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( + FoldContext.small(), + mock(InferenceService.class), + (f, driverContext) -> mock(Operator.class) + ); + + AtomicReference error = new AtomicReference<>(); + evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set)); + + assertNotNull(error.get()); + assertThat(error.get(), instanceOf(IllegalArgumentException.class)); + assertThat(error.get().getMessage(), equalTo("Inference function must be foldable")); + } + + public void testFoldWithAsyncFailure() throws Exception { + TextEmbedding textEmbeddingFunction = new TextEmbedding( + Source.EMPTY, + Literal.keyword(Source.EMPTY, "test-model"), + Literal.keyword(Source.EMPTY, "test input") + ); + + // Mock an operator that will trigger an async failure + Operator operator = mock(Operator.class); + doAnswer(invocation -> { + // Simulate the operator finishing and then immediately calling the failure listener + // This happens inside the `DriverContext` logic that the evaluator uses. + // We can't directly access the listener, so we'll have the operator throw an exception + // which should be caught and propagated to the listener. 
+ throw new RuntimeException("async failure"); + }).when(operator).addInput(new Page(1)); + + InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator; + InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( + FoldContext.small(), + mock(InferenceService.class), + inferenceOperatorProvider + ); + + AtomicReference error = new AtomicReference<>(); + evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set)); + + assertBusy(() -> assertNotNull(error.get())); + assertThat(error.get(), instanceOf(RuntimeException.class)); + assertThat(error.get().getMessage(), equalTo("async failure")); + + allBreakersEmpty(); + } + + public void testFoldWithNullOutputPage() throws Exception { + TextEmbedding textEmbeddingFunction = new TextEmbedding( + Source.EMPTY, + Literal.keyword(Source.EMPTY, "test-model"), + Literal.keyword(Source.EMPTY, "test input") + ); + + Operator operator = mock(Operator.class); + when(operator.getOutput()).thenReturn(null); + + InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator; + InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( + FoldContext.small(), + mock(InferenceService.class), + inferenceOperatorProvider + ); + + AtomicReference error = new AtomicReference<>(); + evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set)); + + assertBusy(() -> assertNotNull(error.get())); + assertThat(error.get(), instanceOf(IllegalStateException.class)); + assertThat(error.get().getMessage(), equalTo("Expected output page from inference operator")); + + allBreakersEmpty(); + } + + public void testFoldWithMultiPositionOutputPage() throws Exception { + TextEmbedding textEmbeddingFunction = new TextEmbedding( + Source.EMPTY, + Literal.keyword(Source.EMPTY, "test-model"), + Literal.keyword(Source.EMPTY, "test input") + ); + + Operator operator = mock(Operator.class); + // Output page should have exactly one position for constant folding + when(operator.getOutput()).thenReturn(new Page(blockFactory().newFloatBlockBuilder(2).build())); + + InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator; + InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( + FoldContext.small(), + mock(InferenceService.class), + inferenceOperatorProvider + ); + + AtomicReference error = new AtomicReference<>(); + evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set)); + + assertBusy(() -> assertNotNull(error.get())); + assertThat(error.get(), instanceOf(IllegalStateException.class)); + assertThat(error.get().getMessage(), equalTo("Expected a single block with a single value from inference operator")); + + allBreakersEmpty(); + } + + public void testFoldWithMultiBlockOutputPage() throws Exception { + TextEmbedding textEmbeddingFunction = new TextEmbedding( + Source.EMPTY, + Literal.keyword(Source.EMPTY, "test-model"), + Literal.keyword(Source.EMPTY, "test input") + ); + + Operator operator = mock(Operator.class); + // Output page should have exactly one block for constant folding + when(operator.getOutput()).thenReturn( + new Page(blockFactory().newFloatBlockBuilder(1).build(), blockFactory().newFloatBlockBuilder(1).build()) + ); + + InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator; + 
InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( + FoldContext.small(), + mock(InferenceService.class), + inferenceOperatorProvider + ); + + AtomicReference error = new AtomicReference<>(); + evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set)); + + assertBusy(() -> assertNotNull(error.get())); + assertThat(error.get(), instanceOf(IllegalStateException.class)); + assertThat(error.get().getMessage(), equalTo("Expected a single block with a single value from inference operator")); + + allBreakersEmpty(); + } + + public void testFoldWithUnsupportedFunction() throws Exception { + Function unsupported = mock(Function.class); + when(unsupported.foldable()).thenReturn(true); + + InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( + FoldContext.small(), + mock(InferenceService.class), + (f, driverContext) -> { + throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName()); + } + ); + + AtomicReference error = new AtomicReference<>(); + evaluator.fold((InferenceFunction) unsupported, ActionListener.wrap(r -> fail("should have failed"), error::set)); + + assertNotNull(error.get()); + assertThat(error.get(), instanceOf(IllegalArgumentException.class)); + assertThat(error.get().getMessage(), containsString("Unknown inference function")); + + allBreakersEmpty(); + } +} From 56040552fd19fe8d9dbd051b46a1ef5ef6d7ba69 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 14:49:33 +0200 Subject: [PATCH 18/26] Unit tests for InferenceFunctionEvaluator --- .../inference/InferenceFunctionEvaluator.java | 205 +++++++++--------- .../preoptimizer/FoldInferenceFunctions.java | 3 +- .../InferenceFunctionEvaluatorTests.java | 100 +-------- 3 files changed, 114 insertions(+), 194 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java index 9edbc1b159d2e..4caf91a8bf917 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java @@ -15,7 +15,7 @@ import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.indices.breaker.AllCircuitBreakerStats; import org.elasticsearch.indices.breaker.CircuitBreakerService; @@ -42,37 +42,24 @@ */ public class InferenceFunctionEvaluator { - private final FoldContext foldContext; - private final InferenceService inferenceService; - private final InferenceOperatorProvider inferenceOperatorProvider; + private static final Factory FACTORY = new Factory(); - /** - * Creates a new inference function evaluator with the default operator provider. 
- * - * @param foldContext the fold context containing circuit breakers and evaluation settings - * @param inferenceService the inference service for executing inference operations - */ - public InferenceFunctionEvaluator(FoldContext foldContext, InferenceService inferenceService) { - this.foldContext = foldContext; - this.inferenceService = inferenceService; - this.inferenceOperatorProvider = this::createInferenceOperator; + public static InferenceFunctionEvaluator.Factory factory() { + return FACTORY; } + private final FoldContext foldContext; + private final InferenceOperatorProvider inferenceOperatorProvider; + /** * Creates a new inference function evaluator with a custom operator provider. * This constructor is primarily used for testing to inject mock operator providers. * * @param foldContext the fold context containing circuit breakers and evaluation settings - * @param inferenceService the inference service for executing inference operations * @param inferenceOperatorProvider custom provider for creating inference operators */ - InferenceFunctionEvaluator( - FoldContext foldContext, - InferenceService inferenceService, - InferenceOperatorProvider inferenceOperatorProvider - ) { + InferenceFunctionEvaluator(FoldContext foldContext, InferenceOperatorProvider inferenceOperatorProvider) { this.foldContext = foldContext; - this.inferenceService = inferenceService; this.inferenceOperatorProvider = inferenceOperatorProvider; } @@ -90,7 +77,6 @@ public InferenceFunctionEvaluator(FoldContext foldContext, InferenceService infe * * @param f the inference function to fold - must be foldable (have constant parameters) * @param listener the listener to notify when folding completes successfully or fails - * @throws IllegalArgumentException if the function is not foldable */ public void fold(InferenceFunction f, ActionListener listener) { if (f.foldable() == false) { @@ -125,89 +111,41 @@ public CircuitBreakerStats stats(String name) { DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); // Create the inference operator for the specific function type using the provider - Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext); - - // Execute the inference operation asynchronously and handle the result - // The operator will perform the actual inference call and return a page with the result - driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> { - Page output = inferenceOperator.getOutput(); - - try { - if (output == null) { - l.onFailure(new IllegalStateException("Expected output page from inference operator")); - return; - } - if (output.getPositionCount() != 1 || output.getBlockCount() != 1) { - l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator")); - return; + try (Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext)) { + // Execute the inference operation asynchronously and handle the result + // The operator will perform the actual inference call and return a page with the result + driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> { + Page output = inferenceOperator.getOutput(); + + try { + if (output == null) { + l.onFailure(new IllegalStateException("Expected output page from inference operator")); + return; + } + + if (output.getPositionCount() != 1 || output.getBlockCount() != 1) { + l.onFailure(new IllegalStateException("Expected a single block with a single 
value from inference operator")); + return; + } + + // Convert the operator result back to an ESQL expression (Literal) + l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0))); + } finally { + if (output != null) { + output.releaseBlocks(); + } } - - // Convert the operator result back to an ESQL expression (Literal) - l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0))); - } finally { - if (output != null) { - output.releaseBlocks(); - } - } - })); - - // Feed the operator with a single page to trigger execution - // The actual input data is already bound in the operator through expression evaluators - inferenceOperator.addInput(new Page(1)); - - driverContext.finish(); - } - - /** - * Creates an inference operator for the given function type and driver context. - *

- * This method uses pattern matching to determine the correct operator factory based on - * the inference function type, creates the factory, and then instantiates the operator - * with the provided driver context. Each supported inference function type has its own - * specialized operator implementation. - * - * @param f the inference function to create an operator for - * @param driverContext the driver context to use for operator creation - * @return an operator instance configured for the given function type - * @throws IllegalArgumentException if the function type is not supported - */ - private Operator createInferenceOperator(InferenceFunction f, DriverContext driverContext) { - Operator.OperatorFactory factory = switch (f) { - case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory( - inferenceService, - inferenceId(f), - expressionEvaluatorFactory(textEmbedding.inputText()) - ); - default -> throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName()); - }; - - return factory.get(driverContext); - } - - /** - * Extracts the inference endpoint ID from an inference function. - * - * @param f the inference function containing the inference ID - * @return the inference endpoint ID as a string - */ - private String inferenceId(InferenceFunction f) { - return BytesRefs.toString(f.inferenceId().fold(foldContext)); - } - - /** - * Creates an expression evaluator factory for a foldable expression. - *

- * This method converts a foldable expression into an evaluator factory that can be used by inference - * operators. The expressionis first folded to its constant value and then wrapped in a literal. - * - * @param e the foldable expression to create an evaluator factory for - * @return an expression evaluator factory for the given expression - * @throws AssertionError if the expression is not foldable (in debug builds) - */ - private ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e) { - assert e.foldable() : "Input expression must be foldable"; - return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null); + })); + + // Feed the operator with a single page to trigger execution + // The actual input data is already bound in the operator through expression evaluators + inferenceOperator.addInput(new Page(1)); + } catch (Exception e) { + listener.onFailure(e); + } finally { + driverContext.finish(); + } } /** @@ -217,7 +155,6 @@ private ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e) { * allowing for easier testing and potential future extensibility. The provider is responsible * for creating an appropriate operator instance given an inference function and driver context. */ - @FunctionalInterface interface InferenceOperatorProvider { /** * Creates an inference operator for the given function and driver context. @@ -228,4 +165,64 @@ interface InferenceOperatorProvider { */ Operator getOperator(InferenceFunction f, DriverContext driverContext); } + + /** + * Factory for creating {@link InferenceFunctionEvaluator} instances. + */ + public static class Factory { + private Factory() {} + + /** + * Creates a new inference function evaluator. + * + * @param foldContext the fold context + * @param inferenceService the inference service + * @return a new instance of {@link InferenceFunctionEvaluator} + */ + public InferenceFunctionEvaluator create(FoldContext foldContext, InferenceService inferenceService) { + return new InferenceFunctionEvaluator(foldContext, createInferenceOperatorProvider(foldContext, inferenceService)); + } + + /** + * Creates an {@link InferenceOperatorProvider} that can produce operators for all supported inference functions. + */ + private InferenceOperatorProvider createInferenceOperatorProvider(FoldContext foldContext, InferenceService inferenceService) { + return (inferenceFunction, driverContext) -> { + Operator.OperatorFactory factory = switch (inferenceFunction) { + case TextEmbedding textEmbedding -> new TextEmbeddingOperator.Factory( + inferenceService, + inferenceId(inferenceFunction, foldContext), + expressionEvaluatorFactory(textEmbedding.inputText(), foldContext) + ); + default -> throw new IllegalArgumentException("Unknown inference function: " + inferenceFunction.getClass().getName()); + }; + + return factory.get(driverContext); + }; + } + + /** + * Extracts the inference endpoint ID from an inference function. + * + * @param f the inference function containing the inference ID + * @return the inference endpoint ID as a string + */ + private String inferenceId(InferenceFunction f, FoldContext foldContext) { + return BytesRefs.toString(f.inferenceId().fold(foldContext)); + } + + /** + * Creates an expression evaluator factory for a foldable expression. + *

+ * This method converts a foldable expression into an evaluator factory that can be used by inference + * operators. The expression is first folded to its constant value and then wrapped in a literal. + * + * @param e the foldable expression to create an evaluator factory for + * @return an expression evaluator factory for the given expression + */ + private EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e, FoldContext foldContext) { + assert e.foldable() : "Input expression must be foldable"; + return EvalMapper.toEvaluator(foldContext, Literal.of(foldContext, e), null); + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java index 7b80d3d69039d..63b6d4c34879d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java @@ -37,7 +37,8 @@ public class FoldInferenceFunctions implements PreOptimizerRule { private final InferenceFunctionEvaluator inferenceFunctionEvaluator; public FoldInferenceFunctions(LogicalPreOptimizerContext preOptimizerContext) { - inferenceFunctionEvaluator = new InferenceFunctionEvaluator(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService()); + inferenceFunctionEvaluator = InferenceFunctionEvaluator.factory() + .create(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService()); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java index a86cb722aeabc..b4c63e6553b95 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluatorTests.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; @@ -79,11 +78,7 @@ public void testFoldTextEmbeddingFunction() throws Exception { InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator; // Execute the fold operation - InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( - FoldContext.small(), - mock(InferenceService.class), - inferenceOperatorProvider - ); + InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(FoldContext.small(), inferenceOperatorProvider); AtomicReference resultExpression = new AtomicReference<>(); evaluator.fold(textEmbeddingFunction, ActionListener.wrap(resultExpression::set, ESTestCase::fail)); @@ -109,7 +104,6 @@ public void testFoldWithNonFoldableFunction() { InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( FoldContext.small(), - 
mock(InferenceService.class), (f, driverContext) -> mock(Operator.class) ); @@ -131,19 +125,13 @@ public void testFoldWithAsyncFailure() throws Exception { // Mock an operator that will trigger an async failure Operator operator = mock(Operator.class); doAnswer(invocation -> { - // Simulate the operator finishing and then immediately calling the failure listener - // This happens inside the `DriverContext` logic that the evaluator uses. - // We can't directly access the listener, so we'll have the operator throw an exception - // which should be caught and propagated to the listener. + // Simulate the operator finishing and then immediately calling the failure listener. + // In that case getOutput() will replay the failure when called allowing us to catch the error. throw new RuntimeException("async failure"); - }).when(operator).addInput(new Page(1)); + }).when(operator).getOutput(); InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator; - InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( - FoldContext.small(), - mock(InferenceService.class), - inferenceOperatorProvider - ); + InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(FoldContext.small(), inferenceOperatorProvider); AtomicReference error = new AtomicReference<>(); evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set)); @@ -166,11 +154,7 @@ public void testFoldWithNullOutputPage() throws Exception { when(operator.getOutput()).thenReturn(null); InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator; - InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( - FoldContext.small(), - mock(InferenceService.class), - inferenceOperatorProvider - ); + InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(FoldContext.small(), inferenceOperatorProvider); AtomicReference error = new AtomicReference<>(); evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set)); @@ -182,78 +166,16 @@ public void testFoldWithNullOutputPage() throws Exception { allBreakersEmpty(); } - public void testFoldWithMultiPositionOutputPage() throws Exception { - TextEmbedding textEmbeddingFunction = new TextEmbedding( - Source.EMPTY, - Literal.keyword(Source.EMPTY, "test-model"), - Literal.keyword(Source.EMPTY, "test input") - ); - - Operator operator = mock(Operator.class); - // Output page should have exactly one position for constant folding - when(operator.getOutput()).thenReturn(new Page(blockFactory().newFloatBlockBuilder(2).build())); - - InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator; - InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( - FoldContext.small(), - mock(InferenceService.class), - inferenceOperatorProvider - ); - - AtomicReference error = new AtomicReference<>(); - evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set)); - - assertBusy(() -> assertNotNull(error.get())); - assertThat(error.get(), instanceOf(IllegalStateException.class)); - assertThat(error.get().getMessage(), equalTo("Expected a single block with a single value from inference operator")); - - allBreakersEmpty(); - } - - public void testFoldWithMultiBlockOutputPage() throws Exception { - TextEmbedding textEmbeddingFunction = new TextEmbedding( - Source.EMPTY, - 
Literal.keyword(Source.EMPTY, "test-model"), - Literal.keyword(Source.EMPTY, "test input") - ); - - Operator operator = mock(Operator.class); - // Output page should have exactly one block for constant folding - when(operator.getOutput()).thenReturn( - new Page(blockFactory().newFloatBlockBuilder(1).build(), blockFactory().newFloatBlockBuilder(1).build()) - ); - - InferenceFunctionEvaluator.InferenceOperatorProvider inferenceOperatorProvider = (f, driverContext) -> operator; - InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( - FoldContext.small(), - mock(InferenceService.class), - inferenceOperatorProvider - ); - - AtomicReference error = new AtomicReference<>(); - evaluator.fold(textEmbeddingFunction, ActionListener.wrap(r -> fail("should have failed"), error::set)); - - assertBusy(() -> assertNotNull(error.get())); - assertThat(error.get(), instanceOf(IllegalStateException.class)); - assertThat(error.get().getMessage(), equalTo("Expected a single block with a single value from inference operator")); - - allBreakersEmpty(); - } - public void testFoldWithUnsupportedFunction() throws Exception { - Function unsupported = mock(Function.class); + InferenceFunction unsupported = mock(InferenceFunction.class); when(unsupported.foldable()).thenReturn(true); - InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator( - FoldContext.small(), - mock(InferenceService.class), - (f, driverContext) -> { - throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName()); - } - ); + InferenceFunctionEvaluator evaluator = new InferenceFunctionEvaluator(FoldContext.small(), (f, driverContext) -> { + throw new IllegalArgumentException("Unknown inference function: " + f.getClass().getName()); + }); AtomicReference error = new AtomicReference<>(); - evaluator.fold((InferenceFunction) unsupported, ActionListener.wrap(r -> fail("should have failed"), error::set)); + evaluator.fold(unsupported, ActionListener.wrap(r -> fail("should have failed"), error::set)); assertNotNull(error.get()); assertThat(error.get(), instanceOf(IllegalArgumentException.class)); From 238c0c2dd70bd7acc6edf5d4fc23e2743d59b85b Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 16:30:11 +0200 Subject: [PATCH 19/26] LogicalPlanPreOptimizerTests rule chain tests. 
--- .../optimizer/LogicalPlanPreOptimizer.java | 6 +- .../LogicalPlanPreOptimizerTests.java | 113 +++++++++++++++++- 2 files changed, 117 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java index dbb1fb45334da..f5727bcc107c2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java @@ -26,7 +26,11 @@ public class LogicalPlanPreOptimizer { private final List preOptimizerRules; public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) { - preOptimizerRules = List.of(new FoldInferenceFunctions(preOptimizerContext)); + this(List.of(new FoldInferenceFunctions(preOptimizerContext))); + } + + LogicalPlanPreOptimizer(List preOptimizerRules) { + this.preOptimizerRules = preOptimizerRules; } /** diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java index fd859cc26fc1e..0905624b445f9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -17,13 +18,18 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.PreOptimizerRule; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Limit; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; +import org.junit.After; +import org.junit.Before; import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; import static org.elasticsearch.xpack.esql.EsqlTestUtils.fieldAttribute; @@ -33,9 +39,21 @@ public class LogicalPlanPreOptimizerTests extends ESTestCase { + private ThreadPool threadPool; + + @Before + public void setUpThreadPool() { + threadPool = createThreadPool(); + } + + @After + public void tearDownThreadPool() { + terminate(threadPool); + } + public void testPlanIsMarkedAsPreOptimized() throws Exception { for (int round = 0; round < 100; round++) { - // We want to make sure that the pre-optimizer woks for a wide range of plans + // We want to make sure that the pre-optimizer works for a wide range of plans preOptimizedPlan(randomPlan()); } } @@ -52,6 +70,75 @@ public void testPreOptimizeFailsIfPlanIsNotAnalyzed() throws Exception { }); } + public void 
testPreOptimizerRulesAreAppliedInOrder() throws Exception { + LogicalPlan plan = EsqlTestUtils.relation(); + plan.setPreOptimized(); + + StringBuilder executionOrder = new StringBuilder(); + + // Create mock rules that track execution order + PreOptimizerRule rule1 = createOrderTrackingRule("A", executionOrder); + PreOptimizerRule rule2 = createOrderTrackingRule("B", executionOrder); + PreOptimizerRule rule3 = createOrderTrackingRule("C", executionOrder); + + LogicalPlanPreOptimizer preOptimizer = new LogicalPlanPreOptimizer(List.of(rule1, rule2, rule3)); + + SetOnce resultHolder = new SetOnce<>(); + + preOptimizer.preOptimize(plan, ActionListener.wrap(resultHolder::set, ESTestCase::fail)); + + assertBusy(() -> { + assertThat(resultHolder.get(), notNullValue()); + // Rules should be applied in the order they were provided + assertThat(executionOrder.toString(), equalTo("ABC")); + assertThat(resultHolder.get().preOptimized(), equalTo(true)); + }); + } + + public void testPreOptimizerWithEmptyRulesList() throws Exception { + LogicalPlan plan = EsqlTestUtils.relation(); + plan.setPreOptimized(); + + LogicalPlanPreOptimizer preOptimizer = new LogicalPlanPreOptimizer(List.of()); + + SetOnce resultHolder = new SetOnce<>(); + + preOptimizer.preOptimize(plan, ActionListener.wrap(resultHolder::set, ESTestCase::fail)); + + assertBusy(() -> { + assertThat(resultHolder.get(), notNullValue()); + assertThat(resultHolder.get().preOptimized(), equalTo(true)); + // The plan should be the same as the original (no modifications) + assertThat(resultHolder.get(), equalTo(plan)); + }); + } + + public void testPreOptimizerRuleFailurePropagatesError() throws Exception { + LogicalPlan plan = EsqlTestUtils.relation(); + plan.setPreOptimized(); + + RuntimeException expectedError = new RuntimeException("Mock rule failure"); + + AtomicInteger ruleACounter = new AtomicInteger(); + PreOptimizerRule ruleA = createMockRule(ruleACounter); + PreOptimizerRule ruleB = createFailingRule(expectedError); + AtomicInteger ruleCCounter = new AtomicInteger(); + PreOptimizerRule ruleC = createMockRule(ruleCCounter); + + LogicalPlanPreOptimizer preOptimizer = new LogicalPlanPreOptimizer(List.of(ruleA, ruleB, ruleC)); + + SetOnce exceptionHolder = new SetOnce<>(); + + preOptimizer.preOptimize(plan, ActionListener.wrap(r -> fail("Should have failed"), exceptionHolder::set)); + + assertBusy(() -> { + assertThat(exceptionHolder.get(), notNullValue()); + assertThat(exceptionHolder.get(), equalTo(expectedError)); + assertThat(ruleACounter.get(), equalTo(1)); + assertThat(ruleCCounter.get(), equalTo(0)); + }); + } + public LogicalPlan preOptimizedPlan(LogicalPlan plan) throws Exception { // set plan as analyzed plan.setPreOptimized(); @@ -107,4 +194,28 @@ private Expression randomCondition() { return EsqlTestUtils.greaterThanOf(randomExpression(), randomExpression()); } + + // Helper methods for creating mock rules + + private PreOptimizerRule createMockRule(AtomicInteger executionCounter) { + return (plan, listener) -> { + threadPool.schedule(() -> { + executionCounter.incrementAndGet(); + listener.onResponse(plan); // Return the plan unchanged + }, randomTimeValue(1, 100, TimeUnit.MILLISECONDS), threadPool.executor(ThreadPool.Names.GENERIC)); + }; + } + + private PreOptimizerRule createOrderTrackingRule(String ruleId, StringBuilder executionOrder) { + return (plan, listener) -> { + threadPool.schedule(() -> { + executionOrder.append(ruleId); + listener.onResponse(plan); // Return the plan unchanged + }, randomTimeValue(1, 100, 
TimeUnit.MILLISECONDS), threadPool.executor(ThreadPool.Names.GENERIC)); + }; + } + + private PreOptimizerRule createFailingRule(Exception error) { + return (plan, listener) -> listener.onFailure(error); + } } From 816c410a0e8d46d026951bad47ea12046f4a7d80 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 17:42:33 +0200 Subject: [PATCH 20/26] More unit tests. --- .../optimizer/LogicalPlanPreOptimizer.java | 8 +- .../preoptimizer/FoldInferenceFunctions.java | 9 +- ....java => LogicalPlanPreOptimizerRule.java} | 2 +- .../LogicalPlanPreOptimizerTests.java | 20 +-- .../FoldInferenceFunctionsTests.java | 156 ++++++++++++++++++ 5 files changed, 177 insertions(+), 18 deletions(-) rename x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/{PreOptimizerRule.java => LogicalPlanPreOptimizerRule.java} (94%) create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctionsTests.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java index f5727bcc107c2..11c57b38c2331 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java @@ -10,7 +10,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.FoldInferenceFunctions; -import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.PreOptimizerRule; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.LogicalPlanPreOptimizerRule; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.List; @@ -23,13 +23,13 @@ *

*/ public class LogicalPlanPreOptimizer { - private final List preOptimizerRules; + private final List preOptimizerRules; public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) { this(List.of(new FoldInferenceFunctions(preOptimizerContext))); } - LogicalPlanPreOptimizer(List preOptimizerRules) { + LogicalPlanPreOptimizer(List preOptimizerRules) { this.preOptimizerRules = preOptimizerRules; } @@ -59,7 +59,7 @@ public void preOptimize(LogicalPlan plan, ActionListener listener) */ private void doPreOptimize(LogicalPlan plan, ActionListener listener) { SubscribableListener rulesListener = SubscribableListener.newSucceeded(plan); - for (PreOptimizerRule preOptimizerRule : preOptimizerRules) { + for (LogicalPlanPreOptimizerRule preOptimizerRule : preOptimizerRules) { rulesListener = rulesListener.andThen((l, p) -> preOptimizerRule.apply(p, l)); } rulesListener.addListener(listener); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java index 63b6d4c34879d..4a3fc87c43795 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctions.java @@ -32,13 +32,16 @@ * Example transformation: * {@code TEXT_EMBEDDING("hello world", "model1")} → {@code [0.1, 0.2, 0.3, ...]} */ -public class FoldInferenceFunctions implements PreOptimizerRule { +public class FoldInferenceFunctions implements LogicalPlanPreOptimizerRule { private final InferenceFunctionEvaluator inferenceFunctionEvaluator; public FoldInferenceFunctions(LogicalPreOptimizerContext preOptimizerContext) { - inferenceFunctionEvaluator = InferenceFunctionEvaluator.factory() - .create(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService()); + this(InferenceFunctionEvaluator.factory().create(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService())); + } + + protected FoldInferenceFunctions(InferenceFunctionEvaluator inferenceFunctionEvaluator) { + this.inferenceFunctionEvaluator = inferenceFunctionEvaluator; } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/LogicalPlanPreOptimizerRule.java similarity index 94% rename from x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/LogicalPlanPreOptimizerRule.java index e29216a020afb..29d9be564f1bf 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/LogicalPlanPreOptimizerRule.java @@ -13,7 +13,7 @@ /** * A rule that can be applied to an analyzed logical plan before it is optimized. */ -public interface PreOptimizerRule { +public interface LogicalPlanPreOptimizerRule { /** * Apply the rule to the logical plan. 
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java index 0905624b445f9..97792e4387b22 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizerTests.java @@ -18,7 +18,7 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; -import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.PreOptimizerRule; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.LogicalPlanPreOptimizerRule; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Limit; @@ -77,9 +77,9 @@ public void testPreOptimizerRulesAreAppliedInOrder() throws Exception { StringBuilder executionOrder = new StringBuilder(); // Create mock rules that track execution order - PreOptimizerRule rule1 = createOrderTrackingRule("A", executionOrder); - PreOptimizerRule rule2 = createOrderTrackingRule("B", executionOrder); - PreOptimizerRule rule3 = createOrderTrackingRule("C", executionOrder); + LogicalPlanPreOptimizerRule rule1 = createOrderTrackingRule("A", executionOrder); + LogicalPlanPreOptimizerRule rule2 = createOrderTrackingRule("B", executionOrder); + LogicalPlanPreOptimizerRule rule3 = createOrderTrackingRule("C", executionOrder); LogicalPlanPreOptimizer preOptimizer = new LogicalPlanPreOptimizer(List.of(rule1, rule2, rule3)); @@ -120,10 +120,10 @@ public void testPreOptimizerRuleFailurePropagatesError() throws Exception { RuntimeException expectedError = new RuntimeException("Mock rule failure"); AtomicInteger ruleACounter = new AtomicInteger(); - PreOptimizerRule ruleA = createMockRule(ruleACounter); - PreOptimizerRule ruleB = createFailingRule(expectedError); + LogicalPlanPreOptimizerRule ruleA = createMockRule(ruleACounter); + LogicalPlanPreOptimizerRule ruleB = createFailingRule(expectedError); AtomicInteger ruleCCounter = new AtomicInteger(); - PreOptimizerRule ruleC = createMockRule(ruleCCounter); + LogicalPlanPreOptimizerRule ruleC = createMockRule(ruleCCounter); LogicalPlanPreOptimizer preOptimizer = new LogicalPlanPreOptimizer(List.of(ruleA, ruleB, ruleC)); @@ -197,7 +197,7 @@ private Expression randomCondition() { // Helper methods for creating mock rules - private PreOptimizerRule createMockRule(AtomicInteger executionCounter) { + private LogicalPlanPreOptimizerRule createMockRule(AtomicInteger executionCounter) { return (plan, listener) -> { threadPool.schedule(() -> { executionCounter.incrementAndGet(); @@ -206,7 +206,7 @@ private PreOptimizerRule createMockRule(AtomicInteger executionCounter) { }; } - private PreOptimizerRule createOrderTrackingRule(String ruleId, StringBuilder executionOrder) { + private LogicalPlanPreOptimizerRule createOrderTrackingRule(String ruleId, StringBuilder executionOrder) { return (plan, listener) -> { threadPool.schedule(() -> { executionOrder.append(ruleId); @@ -215,7 +215,7 @@ private PreOptimizerRule createOrderTrackingRule(String ruleId, StringBuilder ex }; } - private PreOptimizerRule createFailingRule(Exception error) { + private 
LogicalPlanPreOptimizerRule createFailingRule(Exception error) { return (plan, listener) -> listener.onFailure(error); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctionsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctionsTests.java new file mode 100644 index 0000000000000..46f4a1ccf0cc7 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctionsTests.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction; +import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator; +import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.util.List; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizerTests.relation; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +public class FoldInferenceFunctionsTests extends ESTestCase { + /** + * Tests that the rule correctly evaluates TEXT_EMBEDDING functions in Eval nodes. 
+ */ + public void testEvalFunctionEmbedding() throws Exception { + for (int round = 0; round < 100; round++) { + // Setup: Create a plan with an Eval node containing a TEXT_EMBEDDING function + String inferenceId = randomUUID(); + String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); + int dimensions = randomIntBetween(1, 2048); + String fieldName = randomIdentifier(); + + logger.info("query sent: {}", query); + + EsRelation relation = relation(); + Eval eval = new Eval( + Source.EMPTY, + relation, + List.of(new Alias(Source.EMPTY, fieldName, new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)))) + ); + eval.setAnalyzed(); + + SetOnce resultHolder = new SetOnce<>(); + createFoldInferenceFunctionRule(dimensions).apply(eval, ActionListener.wrap(resultHolder::set, ESTestCase::fail)); + + assertBusy(() -> { + Eval preOptimizedEval = as(resultHolder.get(), Eval.class); + assertThat(preOptimizedEval.fields(), hasSize(1)); + assertThat(preOptimizedEval.fields().get(0).name(), equalTo(fieldName)); + Literal preOptimizedQuery = as(preOptimizedEval.fields().get(0).child(), Literal.class); + assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); + assertThat(preOptimizedQuery.value(), equalTo(generateTestEmbedding(query, dimensions))); + }); + + + } + } + + /** + * Tests that the rule correctly evaluates TEXT_EMBEDDING functions as KNN query. + */ + public void testKnnFunctionEmbedding() throws Exception { + for (int round = 0; round < 100; round++) { + // Setup: Create a plan with a Filter node containing a KNN predicate with a TEXT_EMBEDDING function + String inferenceId = randomUUID(); + String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10)))); + int dimensions = randomIntBetween(1, 2048); + + EsRelation relation = relation(); + Filter filter = new Filter( + Source.EMPTY, + relation, + new Knn(Source.EMPTY, getFieldAttribute("a"), new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)), null) + ); + Knn knn = as(filter.condition(), Knn.class); + + SetOnce resultHolder = new SetOnce<>(); + createFoldInferenceFunctionRule(dimensions).apply(filter, ActionListener.wrap(resultHolder::set, ESTestCase::fail)); + + assertBusy(() -> { + Filter preOptimizedFilter = as(resultHolder.get(), Filter.class); + + Knn preOptimizedKnn = as(preOptimizedFilter.condition(), Knn.class); + assertThat(preOptimizedKnn.field(), equalTo(knn.field())); + assertThat(preOptimizedKnn.k(), equalTo(knn.k())); + assertThat(preOptimizedKnn.options(), equalTo(knn.options())); + + Literal preOptimizedQuery = as(preOptimizedKnn.query(), Literal.class); + assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR)); + assertThat(preOptimizedQuery.value(), equalTo(generateTestEmbedding(query, dimensions))); + }); + } + } + + @SuppressWarnings("unchecked") + private LogicalPlanPreOptimizerRule createFoldInferenceFunctionRule(int dimensions) { + InferenceFunctionEvaluator functionEvaluator = mock(InferenceFunctionEvaluator.class); + + doAnswer((i) -> { + ActionListener l = i.getArgument(1, ActionListener.class); + InferenceFunction function = i.getArgument(0, InferenceFunction.class); + if (function instanceof TextEmbedding textEmbedding) { + BytesRef bytesRef = (BytesRef) as(textEmbedding.inputText(), Literal.class).value(); + float[] embedding = generateTestEmbedding(bytesRef.utf8ToString(), dimensions); + l.onResponse(Literal.of(function, embedding)); + } else { + fail("Unexpected function 
type: " + function.getClass().getSimpleName()); + } + return null; + }).when(functionEvaluator).fold(any(InferenceFunction.class), any(ActionListener.class)); + return new FoldInferenceFunctions(functionEvaluator); + } + + /** + * Generates a deterministic mock embedding vector based on the input text. + * This ensures our tests are repeatable and verifiable. + */ + private float[] generateTestEmbedding(String inputText, int dimensions) { + float[] embedding = new float[dimensions]; + + for (int i = 0; i < dimensions; i++) { + embedding[i] = generateMockFloatEmbeddingValue(inputText, i); + } + + return embedding; + } + + /** + * Generates a single embedding value for a specific dimension based on input text. + * Uses a deterministic function so tests are repeatable. + */ + private float generateMockFloatEmbeddingValue(String inputText, int dimension) { + // Create a deterministic value based on input text and dimension + int hash = (inputText.hashCode() + dimension * 31) % 10000; + return hash / 10000.0f; // Normalize to [0, 1) range + } +} From ea3de8ba15c87e5a61d964bec944a962b3df79c7 Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 18:17:39 +0200 Subject: [PATCH 21/26] More CSV tests. --- .../main/resources/text-embedding.csv-spec | 13 +++ .../inference/InferenceFunctionEvaluator.java | 66 +++++++++------- .../inference/bulk/BulkInferenceRunner.java | 79 ++++++++++--------- 3 files changed, 91 insertions(+), 67 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec index 8067eb5219b5a..eeef71f453147 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec @@ -9,3 +9,16 @@ ROW input="Who is Victor Hugo?" input:keyword | embedding:dense_vector Who is Victor Hugo? | [56.0, 50.0, 48.0] ; + + +text_embedding using a ROW source operator with query build using CONCAT +required_capability: text_embedding_function +required_capability: dense_vector_field_type + +ROW input="Who is Victor Hugo?" +| EVAL embedding = TEXT_EMBEDDING(CONCAT("Who is ", "Victor Hugo?"), "test_dense_inference") +; + +input:keyword | embedding:dense_vector +Who is Victor Hugo? 
| [56.0, 50.0, 48.0] +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java index 4caf91a8bf917..eba1e11cb296f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java @@ -17,6 +17,7 @@ import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Releasables; import org.elasticsearch.indices.breaker.AllCircuitBreakerStats; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.CircuitBreakerStats; @@ -111,36 +112,43 @@ public CircuitBreakerStats stats(String name) { DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); // Create the inference operator for the specific function type using the provider - - try (Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext)) { - // Execute the inference operation asynchronously and handle the result - // The operator will perform the actual inference call and return a page with the result - driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> { - Page output = inferenceOperator.getOutput(); - - try { - if (output == null) { - l.onFailure(new IllegalStateException("Expected output page from inference operator")); - return; + try { + Operator inferenceOperator = inferenceOperatorProvider.getOperator(f, driverContext); + + try { + // Feed the operator with a single page to trigger execution + // The actual input data is already bound in the operator through expression evaluators + inferenceOperator.addInput(new Page(1)); + + // Execute the inference operation asynchronously and handle the result + // The operator will perform the actual inference call and return a page with the result + driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> { + Page output = inferenceOperator.getOutput(); + + try { + if (output == null) { + l.onFailure(new IllegalStateException("Expected output page from inference operator")); + return; + } + + if (output.getPositionCount() != 1 || output.getBlockCount() != 1) { + l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator")); + return; + } + + // Convert the operator result back to an ESQL expression (Literal) + l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0))); + } finally { + Releasables.close(inferenceOperator); + if (output != null) { + output.releaseBlocks(); + } } - - if (output.getPositionCount() != 1 || output.getBlockCount() != 1) { - l.onFailure(new IllegalStateException("Expected a single block with a single value from inference operator")); - return; - } - - // Convert the operator result back to an ESQL expression (Literal) - l.onResponse(Literal.of(f, BlockUtils.toJavaObject(output.getBlock(0), 0))); - } finally { - if (output != null) { - output.releaseBlocks(); - } - } - })); - - // Feed the operator with a single page to trigger execution - // The actual input data is already bound in the operator through expression evaluators - inferenceOperator.addInput(new Page(1)); + })); + } 
catch (Exception e) { + Releasables.close(inferenceOperator); + listener.onFailure(e); + } } catch (Exception e) { listener.onFailure(e); } finally { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java index 203a3031bcad4..f220f1842953c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRunner.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.inference.bulk; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.threadpool.ThreadPool; @@ -25,7 +26,6 @@ import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME; /** * Implementation of bulk inference execution with throttling and concurrency control. @@ -88,7 +88,7 @@ public BulkInferenceRequest poll() { public BulkInferenceRunner(Client client, int maxRunningTasks) { this.permits = new Semaphore(maxRunningTasks); this.client = client; - this.executor = client.threadPool().executor(ESQL_WORKER_THREAD_POOL_NAME); + this.executor = client.threadPool().executor(ThreadPool.Names.SEARCH); } /** @@ -253,48 +253,51 @@ private void executePendingRequests(int recursionDepth) { executionState.finish(); } - final ActionListener inferenceResponseListener = ActionListener.runAfter( - ActionListener.wrap( - r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r), - e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e) - ), - () -> { - // Release the permit we used - permits.release(); - - try { - synchronized (executionState) { - persistPendingResponses(); - } + final ActionListener inferenceResponseListener = new ThreadedActionListener<>( + executor, + ActionListener.runAfter( + ActionListener.wrap( + r -> executionState.onInferenceResponse(bulkRequestItem.seqNo(), r), + e -> executionState.onInferenceException(bulkRequestItem.seqNo(), e) + ), + () -> { + // Release the permit we used + permits.release(); + + try { + synchronized (executionState) { + persistPendingResponses(); + } - if (executionState.finished() && responseSent.compareAndSet(false, true)) { - onBulkCompletion(); - } + if (executionState.finished() && responseSent.compareAndSet(false, true)) { + onBulkCompletion(); + } - if (responseSent.get()) { - // Response has already been sent - // No need to continue processing this bulk. - // Check if another bulk request is pending for execution. - BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll(); - if (nexBulkRequest != null) { - executor.execute(nexBulkRequest::executePendingRequests); + if (responseSent.get()) { + // Response has already been sent + // No need to continue processing this bulk. + // Check if another bulk request is pending for execution. 
+ BulkInferenceRequest nexBulkRequest = pendingBulkRequests.poll(); + if (nexBulkRequest != null) { + executor.execute(nexBulkRequest::executePendingRequests); + } + return; } - return; - } - if (executionState.finished() == false) { - // Execute any pending requests if any - if (recursionDepth > 100) { - executor.execute(this::executePendingRequests); - } else { - this.executePendingRequests(recursionDepth + 1); + if (executionState.finished() == false) { + // Execute any pending requests if any + if (recursionDepth > 100) { + executor.execute(this::executePendingRequests); + } else { + this.executePendingRequests(recursionDepth + 1); + } + } + } catch (Exception e) { + if (responseSent.compareAndSet(false, true)) { + completionListener.onFailure(e); } - } - } catch (Exception e) { - if (responseSent.compareAndSet(false, true)) { - completionListener.onFailure(e); } } - } + ) ); // Handle null requests (edge case in some iterators) From 42bc936c362e3aa49dcc8902c4fb8dd5f81527cf Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 19:15:23 +0200 Subject: [PATCH 22/26] Lint --- .../rules/logical/preoptimizer/FoldInferenceFunctionsTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctionsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctionsTests.java index 46f4a1ccf0cc7..50fcdab626501 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctionsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/FoldInferenceFunctionsTests.java @@ -71,7 +71,6 @@ public void testEvalFunctionEmbedding() throws Exception { assertThat(preOptimizedQuery.value(), equalTo(generateTestEmbedding(query, dimensions))); }); - } } From 1d1cc3a3c262e46ace0fc61ad15a04463ba4fd7f Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 19:16:26 +0200 Subject: [PATCH 23/26] Remove useless changes. --- .../elasticsearch/xpack/esql/inference/InferenceOperator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java index a95b72d64f9da..f93c024eeab73 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java @@ -128,7 +128,7 @@ public Page getOutput() { protected abstract OutputBuilder outputBuilder(Page input); /** - * An interface for accumulating inference responses and constructing the result page.. + * An interface for accumulating inference responses and constructing a result {@link Page}. 
*/ public interface OutputBuilder extends Releasable { From 75ed98856f0670d2a05d068a5f082951859a631b Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 19:16:42 +0200 Subject: [PATCH 24/26] More CSV tests :tada: --- .../main/resources/text-embedding.csv-spec | 45 ++++++++++++++++++- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec index eeef71f453147..17f9e71ff185a 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec @@ -1,4 +1,4 @@ -text_embedding using a ROW source operator +text_embedding using a row source operator required_capability: text_embedding_function required_capability: dense_vector_field_type @@ -11,7 +11,7 @@ Who is Victor Hugo? | [56.0, 50.0, 48.0] ; -text_embedding using a ROW source operator with query build using CONCAT +text_embedding using a row source operator with query build using CONCAT required_capability: text_embedding_function required_capability: dense_vector_field_type @@ -22,3 +22,44 @@ ROW input="Who is Victor Hugo?" input:keyword | embedding:dense_vector Who is Victor Hugo? | [56.0, 50.0, 48.0] ; + + +text_embedding with knn on semantic_text_dense_field +required_capability: text_embedding_function +required_capability: dense_vector_field_type +required_capability: knn_function_v5 +required_capability: semantic_text_field_caps + +FROM semantic_text METADATA _score +| EVAL query_embedding = TEXT_EMBEDDING("be excellent to each other", "test_dense_inference") +| WHERE KNN(semantic_text_dense_field, query_embedding) +| KEEP semantic_text_field, query_embedding, _score +| EVAL _score = ROUND(_score, 4) +| SORT _score DESC +| LIMIT 10 +; + +semantic_text_field:text | query_embedding:dense_vector | _score:double +be excellent to each other | [45.0, 55.0, 54.0] | 1.0 +live long and prosper | [45.0, 55.0, 54.0] | 0.0295 +all we have to decide is what to do with the time that is given to us | [45.0, 55.0, 54.0] | 0.0214 + +text_embedding with knn (inline) on semantic_text_dense_field +required_capability: text_embedding_function +required_capability: dense_vector_field_type +required_capability: knn_function_v5 +required_capability: semantic_text_field_caps + +FROM semantic_text METADATA _score +| WHERE KNN(semantic_text_dense_field, TEXT_EMBEDDING("be excellent to each other", "test_dense_inference")) +| KEEP semantic_text_field, _score +| EVAL _score = ROUND(_score, 4) +| SORT _score DESC +| LIMIT 10 +; + +semantic_text_field:text | _score:double +be excellent to each other | 1.0 +live long and prosper | 0.0295 +all we have to decide is what to do with the time that is given to us | 0.0214 +; From 2d4749b6a1ab11d2ad790ecefe6a70a0725803bc Mon Sep 17 00:00:00 2001 From: afoucret Date: Thu, 18 Sep 2025 19:19:32 +0200 Subject: [PATCH 25/26] Fix a typo error --- .../qa/testFixtures/src/main/resources/text-embedding.csv-spec | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec index 17f9e71ff185a..4be9bacab399d 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/text-embedding.csv-spec @@ -43,6 +43,7 @@ 
semantic_text_field:text | query_embedding:dense_vector | _score:double
be excellent to each other | [45.0, 55.0, 54.0] | 1.0
live long and prosper | [45.0, 55.0, 54.0] | 0.0295
all we have to decide is what to do with the time that is given to us | [45.0, 55.0, 54.0] | 0.0214
+;

text_embedding with knn (inline) on semantic_text_dense_field
required_capability: text_embedding_function
required_capability: dense_vector_field_type

From e789dc2a4b66a1dfe02600bff119652bfb9c783e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien=20FOUCRET?=
Date: Thu, 18 Sep 2025 19:22:14 +0200
Subject: [PATCH 26/26] Update docs/changelog/134573.yaml

---
 docs/changelog/134573.yaml | 5 +++++
 1 file changed, 5 insertions(+)
 create mode 100644 docs/changelog/134573.yaml

diff --git a/docs/changelog/134573.yaml b/docs/changelog/134573.yaml
new file mode 100644
index 0000000000000..64e79561678a2
--- /dev/null
+++ b/docs/changelog/134573.yaml
@@ -0,0 +1,5 @@
+pr: 134573
+summary: Esql text embedding function
+area: ES|QL
+type: feature
+issues: []
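
A minimal usage sketch of the wiring this series adds, assuming the packages shown in the diffs above; the import path for InferenceService and the caller-side placeholders (foldContext, inferenceService, function) are assumptions for illustration, not code from the patch:

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
import org.elasticsearch.xpack.esql.inference.InferenceService;

final class InferenceFoldingSketch {

    // Folds a constant inference call (for example TEXT_EMBEDDING("some text", "my-endpoint"))
    // into a literal, mirroring what the FoldInferenceFunctions rule does at pre-optimization time.
    static void foldToLiteral(FoldContext foldContext, InferenceService inferenceService, InferenceFunction function) {
        InferenceFunctionEvaluator evaluator = InferenceFunctionEvaluator.factory().create(foldContext, inferenceService);

        evaluator.fold(function, ActionListener.wrap(folded -> {
            // On success `folded` is a DENSE_VECTOR literal; the rule substitutes it for the
            // original function in the plan, so consumers such as KNN(...) see a constant vector.
        }, failure -> {
            // Non-foldable functions, inference errors and unexpected operator output all
            // surface here instead of being thrown.
        }));
    }
}

Keeping operator creation behind InferenceFunctionEvaluator.factory() and the package-private InferenceOperatorProvider hook is what lets InferenceFunctionEvaluatorTests inject mock operators without standing up a real InferenceService.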