diff --git a/docs/changelog/132765.yaml b/docs/changelog/132765.yaml new file mode 100644 index 0000000000000..1b019e224c0ae --- /dev/null +++ b/docs/changelog/132765.yaml @@ -0,0 +1,5 @@ +pr: 132765 +summary: Implement `v_magnitude` function +area: ES|QL +type: feature +issues: [132768] diff --git a/docs/reference/query-languages/esql/_snippets/functions/description/v_magnitude.md b/docs/reference/query-languages/esql/_snippets/functions/description/v_magnitude.md new file mode 100644 index 0000000000000..5b66acddf19c5 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/v_magnitude.md @@ -0,0 +1,6 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Description** + +Calculates the magnitude of a dense_vector. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/v_magnitude.md b/docs/reference/query-languages/esql/_snippets/functions/examples/v_magnitude.md new file mode 100644 index 0000000000000..c9bed2cbc864e --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/v_magnitude.md @@ -0,0 +1,24 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Example** + +```esql + from colors + | eval magnitude = v_magnitude(rgb_vector) + | sort magnitude desc, color asc +``` + +| color:text | magnitude:double | +| --- | --- | +| white | 441.6729431152344 | +| snow | 435.9185791015625 | +| azure | 433.1858825683594 | +| ivory | 433.1858825683594 | +| mint cream | 433.0704345703125 | +| sea shell | 426.25579833984375 | +| honeydew | 424.5291442871094 | +| old lace | 420.6352233886719 | +| corn silk | 418.2451477050781 | +| linen | 415.93267822265625 | + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/v_magnitude.md b/docs/reference/query-languages/esql/_snippets/functions/layout/v_magnitude.md new file mode 100644 index 0000000000000..2d2e8ae1fc0e0 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/v_magnitude.md @@ -0,0 +1,27 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +## `V_MAGNITUDE` [esql-v_magnitude] +```{applies_to} +stack: development +serverless: preview +``` + +**Syntax** + +:::{image} ../../../images/functions/v_magnitude.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/v_magnitude.md +::: + +:::{include} ../description/v_magnitude.md +::: + +:::{include} ../types/v_magnitude.md +::: + +:::{include} ../examples/v_magnitude.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/v_magnitude.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/v_magnitude.md new file mode 100644 index 0000000000000..5a7cf14ed7137 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/v_magnitude.md @@ -0,0 +1,7 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Parameters** + +`input` +: dense_vector for which to compute the magnitude + diff --git a/docs/reference/query-languages/esql/images/functions/v_magnitude.svg b/docs/reference/query-languages/esql/images/functions/v_magnitude.svg new file mode 100644 index 0000000000000..7b32eee3f3d65 --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/v_magnitude.svg @@ -0,0 +1 @@ +V_MAGNITUDE(input) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/v_magnitude.json b/docs/reference/query-languages/esql/kibana/definition/functions/v_magnitude.json new file mode 100644 index 0000000000000..2835d403e656e --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/v_magnitude.json @@ -0,0 +1,12 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "v_magnitude", + "description" : "Calculates the magnitude of a dense_vector.", + "signatures" : [ ], + "examples" : [ + " from colors\n | eval magnitude = v_magnitude(rgb_vector)\n | sort magnitude desc, color asc" + ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/v_magnitude.md b/docs/reference/query-languages/esql/kibana/docs/functions/v_magnitude.md new file mode 100644 index 0000000000000..236f4880eda49 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/v_magnitude.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### V MAGNITUDE +Calculates the magnitude of a dense_vector. + +```esql + from colors + | eval magnitude = v_magnitude(rgb_vector) + | sort magnitude desc, color asc +``` diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-magnitude.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-magnitude.csv-spec new file mode 100644 index 0000000000000..c670cb9ec678e --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-magnitude.csv-spec @@ -0,0 +1,87 @@ + # Tests for v_magnitude scalar function + + magnitudeWithVectorField + required_capability: magnitude_scalar_vector_function + +// tag::vector-magnitude[] + from colors + | eval magnitude = v_magnitude(rgb_vector) + | sort magnitude desc, color asc +// end::vector-magnitude[] + | limit 10 + | keep color, magnitude + ; + +// tag::vector-magnitude-result[] +color:text | magnitude:double +white | 441.6729431152344 +snow | 435.9185791015625 +azure | 433.1858825683594 +ivory | 433.1858825683594 +mint cream | 433.0704345703125 +sea shell | 426.25579833984375 +honeydew | 424.5291442871094 +old lace | 420.6352233886719 +corn silk | 418.2451477050781 +linen | 415.93267822265625 +// end::vector-magnitude-result[] +; + + magnitudeAsPartOfExpression + required_capability: magnitude_scalar_vector_function + + from colors + | eval score = round((1 + v_magnitude(rgb_vector) / 2), 3) + | sort score desc, color asc + | limit 10 + | keep color, score + ; + +color:text | score:double +white | 221.836 +snow | 218.959 +azure | 217.593 +ivory | 217.593 +mint cream | 217.535 +sea shell | 214.128 +honeydew | 213.265 +old lace | 211.318 +corn silk | 210.123 +linen | 208.966 +; + +magnitudeWithLiteralVectors +required_capability: magnitude_scalar_vector_function + +row a = 1 +| eval magnitude = round(v_magnitude([1, 2, 3]), 3) +| keep magnitude +; + +magnitude:double +3.742 +; + + magnitudeWithStats + required_capability: magnitude_scalar_vector_function + + from colors + | eval magnitude = round(v_magnitude(rgb_vector), 3) + | stats avg = round(avg(magnitude), 3), min = min(magnitude), max = max(magnitude) + ; + +avg:double | min:double | max:double +313.692 | 0.0 | 441.673 +; + +magnitudeWithNull +required_capability: magnitude_scalar_vector_function + +row a = 1 +| eval magnitude = v_magnitude(null) +| keep magnitude +; + +magnitude:double +null +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 14a79f54646ba..a397a84343e43 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1356,6 +1356,11 @@ public enum Cap { */ CORRECT_SKIPPED_SHARDS_COUNT, + /* + * Support for calculating the scalar vector magnitude. + */ + MAGNITUDE_SCALAR_VECTOR_FUNCTION(Build.current().isSnapshot()), + /** * Byte elements dense vector field type support. */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 0eca67f625121..f52c51cf8d3f9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -188,6 +188,7 @@ import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm; import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm; +import org.elasticsearch.xpack.esql.expression.function.vector.Magnitude; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.session.Configuration; @@ -503,7 +504,8 @@ private static FunctionDefinition[][] snapshotFunctions() { def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine"), def(DotProduct.class, DotProduct::new, "v_dot_product"), def(L1Norm.class, L1Norm::new, "v_l1_norm"), - def(L2Norm.class, L2Norm::new, "v_l2_norm") } }; + def(L2Norm.class, L2Norm::new, "v_l2_norm"), + def(Magnitude.class, Magnitude::new, "v_magnitude") } }; } public EsqlFunctionRegistry snapshotRegistry() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Magnitude.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Magnitude.java new file mode 100644 index 0000000000000..56d1cc0d31b8d --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Magnitude.java @@ -0,0 +1,180 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; + +import java.io.IOException; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; + +public class Magnitude extends UnaryScalarFunction implements EvaluatorMapper, VectorFunction { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "Magnitude", + Magnitude::new + ); + static final ScalarEvaluatorFunction SCALAR_FUNCTION = Magnitude::calculateScalar; + + @FunctionInfo( + returnType = "double", + preview = true, + description = "Calculates the magnitude of a dense_vector.", + examples = { @Example(file = "vector-magnitude", tag = "vector-magnitude") }, + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) } + ) + public Magnitude( + Source source, + @Param(name = "input", type = { "dense_vector" }, description = "dense_vector for which to compute the magnitude") Expression input + ) { + super(source, input); + } + + private Magnitude(StreamInput in) throws IOException { + super(in); + } + + @Override + protected UnaryScalarFunction replaceChild(Expression newChild) { + return new Magnitude(source(), newChild); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, Magnitude::new, field()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + public static float calculateScalar(float[] scratch) { + return (float) Math.sqrt(VectorUtil.dotProduct(scratch, scratch)); + } + + @Override + public DataType dataType() { + return DataType.DOUBLE; + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + return isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.FIRST, "dense_vector"); + } + + /** + * Functional interface for evaluating the scalar value of the underlying float array. + */ + @FunctionalInterface + public interface ScalarEvaluatorFunction { + float calculateScalar(float[] scratch); + } + + @Override + public Object fold(FoldContext ctx) { + return EvaluatorMapper.super.fold(source(), ctx); + } + + @Override + public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { + return new ScalarEvaluatorFactory(toEvaluator.apply(field()), SCALAR_FUNCTION, getClass().getSimpleName() + "Evaluator"); + } + + private record ScalarEvaluatorFactory( + EvalOperator.ExpressionEvaluator.Factory child, + ScalarEvaluatorFunction scalarFunction, + String evaluatorName + ) implements EvalOperator.ExpressionEvaluator.Factory { + + @Override + public EvalOperator.ExpressionEvaluator get(DriverContext context) { + // TODO check whether to use this custom evaluator or reuse / define an existing one + return new EvalOperator.ExpressionEvaluator() { + @Override + public Block eval(Page page) { + try (FloatBlock block = (FloatBlock) child.get(context).eval(page);) { + int positionCount = page.getPositionCount(); + int dimensions = 0; + // Get the first non-empty vector to calculate the dimension + for (int p = 0; p < positionCount; p++) { + if (block.getValueCount(p) != 0) { + dimensions = block.getValueCount(p); + break; + } + } + if (dimensions == 0) { + return context.blockFactory().newConstantFloatBlockWith(0F, 0); + } + + float[] scratch = new float[dimensions]; + try (var builder = context.blockFactory().newDoubleBlockBuilder(positionCount * dimensions)) { + for (int p = 0; p < positionCount; p++) { + int dims = block.getValueCount(p); + if (dims == 0) { + // A null value for the vector, by default append null as result. + builder.appendNull(); + continue; + } + readFloatArray(block, block.getFirstValueIndex(p), dimensions, scratch); + float result = scalarFunction.calculateScalar(scratch); + builder.appendDouble(result); + } + return builder.build(); + } + } + } + + @Override + public String toString() { + return evaluatorName() + "[child=" + child + "]"; + } + + @Override + public void close() {} + }; + } + + private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) { + for (int i = 0; i < dimensions; i++) { + scratch[i] = block.getFloat(position + i); + } + } + + @Override + public String toString() { + return evaluatorName() + "[child=" + child + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java index 4a1a2ec9386ae..a0897792482d8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java @@ -42,6 +42,9 @@ public static List getNamedWritables() { if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { entries.add(L2Norm.ENTRY); } + if (EsqlCapabilities.Cap.MAGNITUDE_SCALAR_VECTOR_FUNCTION.isEnabled()) { + entries.add(Magnitude.ENTRY); + } return Collections.unmodifiableList(entries); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index e040067458408..7307285ec37a7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -59,6 +59,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.expression.function.vector.Magnitude; import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; @@ -2432,6 +2433,30 @@ private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction ); } + public void testMagnitudePlanWithDenseVectorImplicitCasting() { + var plan = analyze(String.format(Locale.ROOT, """ + from test | eval scalar = v_magnitude([1, 2, 3]) + """), "mapping-dense_vector.json"); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertEquals("scalar", alias.name()); + var scalar = as(alias.child(), Magnitude.class); + var child = as(scalar.field(), Literal.class); + assertThat(child.dataType(), is(DENSE_VECTOR)); + assertThat(child.value(), equalTo(List.of(1.0f, 2.0f, 3.0f))); + } + + public void testNoDenseVectorFailsForMagnitude() { + var query = String.format(Locale.ROOT, "row a = 1 | eval scalar = v_magnitude(0.342)"); + VerificationException error = expectThrows(VerificationException.class, () -> analyze(query)); + assertThat( + error.getMessage(), + containsString("first argument of [v_magnitude(0.342)] must be [dense_vector], found value [0.342] type [double]") + ); + } + public void testRateRequiresCounterTypes() { assumeTrue("rate requires snapshot builds", Build.current().isSnapshot()); Analyzer analyzer = analyzer(tsdbIndexResolution()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 37d6719ddccfc..3ba70a1caf486 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -2484,24 +2484,27 @@ private void checkFullTextFunctionsInStats(String functionInvocation) { public void testVectorSimilarityFunctionsNullArgs() throws Exception { if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)"); - checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)"); + checkVectorFunctionsNullArgs("v_cosine(null, vector)"); + checkVectorFunctionsNullArgs("v_cosine(vector, null)"); } if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkVectorSimilarityFunctionsNullArgs("v_dot_product(null, vector)"); - checkVectorSimilarityFunctionsNullArgs("v_dot_product(vector, null)"); + checkVectorFunctionsNullArgs("v_dot_product(null, vector)"); + checkVectorFunctionsNullArgs("v_dot_product(vector, null)"); } if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkVectorSimilarityFunctionsNullArgs("v_l1_norm(null, vector)"); - checkVectorSimilarityFunctionsNullArgs("v_l1_norm(vector, null)"); + checkVectorFunctionsNullArgs("v_l1_norm(null, vector)"); + checkVectorFunctionsNullArgs("v_l1_norm(vector, null)"); } if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkVectorSimilarityFunctionsNullArgs("v_l2_norm(null, vector)"); - checkVectorSimilarityFunctionsNullArgs("v_l2_norm(vector, null)"); + checkVectorFunctionsNullArgs("v_l2_norm(null, vector)"); + checkVectorFunctionsNullArgs("v_l2_norm(vector, null)"); + } + if (EsqlCapabilities.Cap.MAGNITUDE_SCALAR_VECTOR_FUNCTION.isEnabled()) { + checkVectorFunctionsNullArgs("v_magnitude(null)"); } } - private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation) throws Exception { + private void checkVectorFunctionsNullArgs(String functionInvocation) throws Exception { query("from test | eval similarity = " + functionInvocation, fullTextAnalyzer); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java index 791152df5acb0..6b0faaaf6d53e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -10,7 +10,6 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; -import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.junit.Before; @@ -22,7 +21,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.hamcrest.Matchers.equalTo; -public abstract class AbstractVectorSimilarityFunctionTestCase extends AbstractScalarFunctionTestCase { +public abstract class AbstractVectorSimilarityFunctionTestCase extends AbstractVectorTestCase { protected AbstractVectorSimilarityFunctionTestCase(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); @@ -68,28 +67,4 @@ protected static Iterable similarityParameters( return parameterSuppliersFromTypedData(suppliers); } - - private static float[] listToFloatArray(List floatList) { - float[] floatArray = new float[floatList.size()]; - for (int i = 0; i < floatList.size(); i++) { - floatArray[i] = floatList.get(i); - } - return floatArray; - } - - protected double calculateSimilarity(List left, List right) { - return 0; - } - - /** - * @return A random dense vector for testing - * @param dimensions - */ - private static List randomDenseVector(int dimensions) { - List vector = new ArrayList<>(); - for (int i = 0; i < dimensions; i++) { - vector.add(randomFloat()); - } - return vector; - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorTestCase.java new file mode 100644 index 0000000000000..ddddcec21ea30 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorTestCase.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; + +import java.util.ArrayList; +import java.util.List; + +public abstract class AbstractVectorTestCase extends AbstractScalarFunctionTestCase { + + protected static float[] listToFloatArray(List floatList) { + float[] floatArray = new float[floatList.size()]; + for (int i = 0; i < floatList.size(); i++) { + floatArray[i] = floatList.get(i); + } + return floatArray; + } + + /** + * @return A random dense vector for testing + * @param dimensions + */ + protected static List randomDenseVector(int dimensions) { + List vector = new ArrayList<>(); + for (int i = 0; i < dimensions; i++) { + vector.add(randomFloat()); + } + return vector; + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/MagnitudeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/MagnitudeTests.java new file mode 100644 index 0000000000000..651130a2c1be1 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/MagnitudeTests.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.FunctionName; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; +import static org.hamcrest.Matchers.equalTo; + +@FunctionName("v_magnitude") +public class MagnitudeTests extends AbstractVectorTestCase { + + public MagnitudeTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + return scalarParameters(Magnitude.class.getSimpleName(), Magnitude.SCALAR_FUNCTION); + } + + protected EsqlCapabilities.Cap capability() { + return EsqlCapabilities.Cap.MAGNITUDE_SCALAR_VECTOR_FUNCTION; + } + + @Override + protected Expression build(Source source, List args) { + return new Magnitude(source, args.get(0)); + } + + @Before + public void checkCapability() { + assumeTrue("Scalar function is not enabled", capability().isEnabled()); + } + + protected static Iterable scalarParameters(String className, Magnitude.ScalarEvaluatorFunction scalarFunction) { + + final String evaluatorName = className + "Evaluator" + "[child=Attribute[channel=0]]"; + + List suppliers = new ArrayList<>(); + + // Basic test with a dense vector. + suppliers.add(new TestCaseSupplier(List.of(DENSE_VECTOR), () -> { + int dimensions = between(64, 128); + List input = randomDenseVector(dimensions); + float[] array = listToFloatArray(input); + double expected = scalarFunction.calculateScalar(array); + return new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(input, DENSE_VECTOR, "vector")), + evaluatorName, + DOUBLE, + equalTo(expected) + ); + })); + + return parameterSuppliersFromTypedData(suppliers); + } +}