diff --git a/docs/changelog/141060.yaml b/docs/changelog/141060.yaml new file mode 100644 index 0000000000000..fb9869aeeda7d --- /dev/null +++ b/docs/changelog/141060.yaml @@ -0,0 +1,6 @@ +area: ES|QL +issues: + - 140538 +pr: 141060 +summary: "Support arithmetic operations for dense_vectors: scalar version" +type: enhancement diff --git a/docs/reference/query-languages/esql/_snippets/operators/types/add.md b/docs/reference/query-languages/esql/_snippets/operators/types/add.md index 176302760ebdf..a20a81b878abe 100644 --- a/docs/reference/query-languages/esql/_snippets/operators/types/add.md +++ b/docs/reference/query-languages/esql/_snippets/operators/types/add.md @@ -12,12 +12,18 @@ | date_period | date_nanos | date_nanos | | date_period | date_period | date_period | | dense_vector | dense_vector | dense_vector | +| dense_vector | double | dense_vector | +| dense_vector | integer | dense_vector | +| dense_vector | long | dense_vector | +| double | dense_vector | dense_vector | | double | double | double | | double | integer | double | | double | long | double | +| integer | dense_vector | dense_vector | | integer | double | double | | integer | integer | integer | | integer | long | long | +| long | dense_vector | dense_vector | | long | double | double | | long | integer | long | | long | long | long | diff --git a/docs/reference/query-languages/esql/_snippets/operators/types/div.md b/docs/reference/query-languages/esql/_snippets/operators/types/div.md index e0ac5c939e8a1..ce247e3645c2d 100644 --- a/docs/reference/query-languages/esql/_snippets/operators/types/div.md +++ b/docs/reference/query-languages/esql/_snippets/operators/types/div.md @@ -5,12 +5,18 @@ | lhs | rhs | result | | --- | --- | --- | | dense_vector | dense_vector | dense_vector | +| dense_vector | double | dense_vector | +| dense_vector | integer | dense_vector | +| dense_vector | long | dense_vector | +| double | dense_vector | dense_vector | | double | double | double | | double | integer | double | | double | long | double | +| integer | dense_vector | dense_vector | | integer | double | double | | integer | integer | integer | | integer | long | long | +| long | dense_vector | dense_vector | | long | double | double | | long | integer | long | | long | long | long | diff --git a/docs/reference/query-languages/esql/_snippets/operators/types/mul.md b/docs/reference/query-languages/esql/_snippets/operators/types/mul.md index e0ac5c939e8a1..ce247e3645c2d 100644 --- a/docs/reference/query-languages/esql/_snippets/operators/types/mul.md +++ b/docs/reference/query-languages/esql/_snippets/operators/types/mul.md @@ -5,12 +5,18 @@ | lhs | rhs | result | | --- | --- | --- | | dense_vector | dense_vector | dense_vector | +| dense_vector | double | dense_vector | +| dense_vector | integer | dense_vector | +| dense_vector | long | dense_vector | +| double | dense_vector | dense_vector | | double | double | double | | double | integer | double | | double | long | double | +| integer | dense_vector | dense_vector | | integer | double | double | | integer | integer | integer | | integer | long | long | +| long | dense_vector | dense_vector | | long | double | double | | long | integer | long | | long | long | long | diff --git a/docs/reference/query-languages/esql/_snippets/operators/types/sub.md b/docs/reference/query-languages/esql/_snippets/operators/types/sub.md index bfb103204341f..755587b0e97bd 100644 --- a/docs/reference/query-languages/esql/_snippets/operators/types/sub.md +++ b/docs/reference/query-languages/esql/_snippets/operators/types/sub.md @@ -11,12 +11,18 @@ | date_period | date_nanos | date_nanos | | date_period | date_period | date_period | | dense_vector | dense_vector | dense_vector | +| dense_vector | double | dense_vector | +| dense_vector | integer | dense_vector | +| dense_vector | long | dense_vector | +| double | dense_vector | dense_vector | | double | double | double | | double | integer | double | | double | long | double | +| integer | dense_vector | dense_vector | | integer | double | double | | integer | integer | integer | | integer | long | long | +| long | dense_vector | dense_vector | | long | double | double | | long | integer | long | | long | long | long | diff --git a/docs/reference/query-languages/esql/kibana/definition/operators/add.json b/docs/reference/query-languages/esql/kibana/definition/operators/add.json index ef833a0cbf08b..bb83211c66457 100644 --- a/docs/reference/query-languages/esql/kibana/definition/operators/add.json +++ b/docs/reference/query-languages/esql/kibana/definition/operators/add.json @@ -149,6 +149,78 @@ "variadic" : false, "returnType" : "dense_vector" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "double", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "integer", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "long", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "double", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { @@ -203,6 +275,24 @@ "variadic" : false, "returnType" : "double" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "integer", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { @@ -257,6 +347,24 @@ "variadic" : false, "returnType" : "long" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "long", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { diff --git a/docs/reference/query-languages/esql/kibana/definition/operators/div.json b/docs/reference/query-languages/esql/kibana/definition/operators/div.json index 0aaeecab665ad..a3e4e6b58a634 100644 --- a/docs/reference/query-languages/esql/kibana/definition/operators/div.json +++ b/docs/reference/query-languages/esql/kibana/definition/operators/div.json @@ -24,6 +24,78 @@ "variadic" : false, "returnType" : "dense_vector" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value." + }, + { + "name" : "rhs", + "type" : "double", + "optional" : false, + "description" : "A numeric value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value." + }, + { + "name" : "rhs", + "type" : "integer", + "optional" : false, + "description" : "A numeric value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value." + }, + { + "name" : "rhs", + "type" : "long", + "optional" : false, + "description" : "A numeric value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "double", + "optional" : false, + "description" : "A numeric value." + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { @@ -78,6 +150,24 @@ "variadic" : false, "returnType" : "double" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "integer", + "optional" : false, + "description" : "A numeric value." + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { @@ -132,6 +222,24 @@ "variadic" : false, "returnType" : "long" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "long", + "optional" : false, + "description" : "A numeric value." + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { diff --git a/docs/reference/query-languages/esql/kibana/definition/operators/mul.json b/docs/reference/query-languages/esql/kibana/definition/operators/mul.json index a603b163a9d8d..c815f005045ca 100644 --- a/docs/reference/query-languages/esql/kibana/definition/operators/mul.json +++ b/docs/reference/query-languages/esql/kibana/definition/operators/mul.json @@ -24,6 +24,78 @@ "variadic" : false, "returnType" : "dense_vector" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value or dense_vector" + }, + { + "name" : "rhs", + "type" : "double", + "optional" : false, + "description" : "A numeric value or dense_vector" + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value or dense_vector" + }, + { + "name" : "rhs", + "type" : "integer", + "optional" : false, + "description" : "A numeric value or dense_vector" + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value or dense_vector" + }, + { + "name" : "rhs", + "type" : "long", + "optional" : false, + "description" : "A numeric value or dense_vector" + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "double", + "optional" : false, + "description" : "A numeric value or dense_vector" + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value or dense_vector" + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { @@ -78,6 +150,24 @@ "variadic" : false, "returnType" : "double" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "integer", + "optional" : false, + "description" : "A numeric value or dense_vector" + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value or dense_vector" + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { @@ -132,6 +222,24 @@ "variadic" : false, "returnType" : "long" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "long", + "optional" : false, + "description" : "A numeric value or dense_vector" + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value or dense_vector" + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { diff --git a/docs/reference/query-languages/esql/kibana/definition/operators/sub.json b/docs/reference/query-languages/esql/kibana/definition/operators/sub.json index 23acbad68d855..e85f0a162f176 100644 --- a/docs/reference/query-languages/esql/kibana/definition/operators/sub.json +++ b/docs/reference/query-languages/esql/kibana/definition/operators/sub.json @@ -132,6 +132,78 @@ "variadic" : false, "returnType" : "dense_vector" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "double", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "integer", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "long", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "double", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { @@ -186,6 +258,24 @@ "variadic" : false, "returnType" : "double" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "integer", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { @@ -240,6 +330,24 @@ "variadic" : false, "returnType" : "long" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "long", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + }, + { + "name" : "rhs", + "type" : "dense_vector", + "optional" : false, + "description" : "A numeric value, dense_vector or a date time value." + } + ], + "variadic" : false, + "returnType" : "dense_vector" + }, { "params" : [ { diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java index 1dcb608fbe418..ffd5684b66e81 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/NumericUtils.java @@ -166,7 +166,7 @@ public static double asFiniteNumber(double dbl) { */ public static float asFiniteNumber(float flt) { if (Double.isNaN(flt) || Double.isInfinite(flt)) { - throw new ArithmeticException("not a finite double number: " + flt); + throw new ArithmeticException("not a finite float number: " + flt); } return flt; } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-arithmetic.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-arithmetic.csv-spec index edad4dcf1773e..7dad63f13f095 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-arithmetic.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-arithmetic.csv-spec @@ -86,6 +86,48 @@ null null ; +addDenseVectorAndScalarDouble +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 + 1.0 +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[2.0, 3.0, 4.0] +[5.0, 6.0, 7.0] +[10.0, 9.0, 8.0] +[1.054, 1.032, 1.012] +; + +addDenseVectorAndScalarInt +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 + 1 +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[2.0, 3.0, 4.0] +[5.0, 6.0, 7.0] +[10.0, 9.0, 8.0] +[1.054, 1.032, 1.012] +; + +addDenseVectorAndScalarLong +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 + 1::long +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[2.0, 3.0, 4.0] +[5.0, 6.0, 7.0] +[10.0, 9.0, 8.0] +[1.054, 1.032, 1.012] +; + // tests for sub operation subDenseVectors @@ -174,6 +216,48 @@ null null ; +subDenseVectorAndScalarDouble +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 - 1.0 +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[0.0, 1.0, 2.0] +[3.0, 4.0, 5.0] +[8.0, 7.0, 6.0] +[-0.946, -0.968, -0.988] +; + +subDenseVectorAndScalarInt +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 - 1 +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[0.0, 1.0, 2.0] +[3.0, 4.0, 5.0] +[8.0, 7.0, 6.0] +[-0.946, -0.968, -0.988] +; + +subDenseVectorAndScalarLong +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 - 1::long +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[0.0, 1.0, 2.0] +[3.0, 4.0, 5.0] +[8.0, 7.0, 6.0] +[-0.946, -0.968, -0.988] +; + // tests for mul operation mulDenseVectors @@ -262,6 +346,48 @@ null null ; +mulDenseVectorAndScalarDouble +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 * 2.0 +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[2.0, 4.0, 6.0] +[8.0, 10.0, 12.0] +[18.0, 16.0, 14.0] +[0.108, 0.064, 0.024] +; + +mulDenseVectorAndScalarInt +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 * 2 +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[2.0, 4.0, 6.0] +[8.0, 10.0, 12.0] +[18.0, 16.0, 14.0] +[0.108, 0.064, 0.024] +; + +mulDenseVectorAndScalarLong +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 * 2::long +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[2.0, 4.0, 6.0] +[8.0, 10.0, 12.0] +[18.0, 16.0, 14.0] +[0.108, 0.064, 0.024] +; + // tests for div operation divDenseVectors @@ -349,3 +475,45 @@ null null null ; + +divDenseVectorAndScalarDouble +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 / 2.0 +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[0.5, 1.0, 1.5] +[2.0, 2.5, 3.0] +[4.5, 4.0, 3.5] +[0.027, 0.016, 0.006] +; + +divDenseVectorAndScalarInt +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 / 2 +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[0.5, 1.0, 1.5] +[2.0, 2.5, 3.0] +[4.5, 4.0, 3.5] +[0.027, 0.016, 0.006] +; + +divDenseVectorAndScalarLong +required_capability: dense_vector_scalar_arithmetic +FROM dense_vector_arithmetic +| eval result_vector = vector_field_1 / 2::long +| SORT id +| KEEP result_vector; + +result_vector:dense_vector +[0.5, 1.0, 1.5] +[2.0, 2.5, 3.0] +[4.5, 4.0, 3.5] +[0.027, 0.016, 0.006] +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 176166616ff36..eab5691d3f01b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -2058,6 +2058,11 @@ public enum Cap { */ DENSE_VECTOR_ARITHMETIC, + /** + * Support for dense_vector arithmetic operations (+, -, *, /) + */ + DENSE_VECTOR_SCALAR_ARITHMETIC, + /** * Dense_vector aggregation functions */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Add.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Add.java index da73ae9b15804..a18be56beb232 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Add.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Add.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.time.DateUtils; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; +import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -36,6 +37,7 @@ public class Add extends DateTimeArithmeticOperation implements BinaryComparisonInversible { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Add", Add::new); + public static final String OP_NAME = "Add"; private final Configuration configuration; @@ -97,7 +99,7 @@ public Add( AddLongsEvaluator.Factory::new, AddUnsignedLongsEvaluator.Factory::new, AddDoublesEvaluator.Factory::new, - DenseVectorsEvaluator.AddFactory::new, + ADD_DENSE_VECTOR_EVALUATOR, AddDatetimesEvaluator.Factory::new, AddDateNanosEvaluator.Factory::new ); @@ -112,7 +114,7 @@ private Add(StreamInput in) throws IOException { AddLongsEvaluator.Factory::new, AddUnsignedLongsEvaluator.Factory::new, AddDoublesEvaluator.Factory::new, - DenseVectorsEvaluator.AddFactory::new, + ADD_DENSE_VECTOR_EVALUATOR, AddDatetimesEvaluator.Factory::new, AddDateNanosEvaluator.Factory::new ); @@ -208,4 +210,37 @@ public Configuration configuration() { public Add withConfiguration(Configuration configuration) { return new Add(source(), left(), right(), configuration); } + + private static float addDenseVectorElements(float lhs, float rhs) { + return NumericUtils.asFiniteNumber(lhs + rhs); + } + + private static final DenseVectorBinaryEvaluator ADD_DENSE_VECTOR_EVALUATOR = new DenseVectorBinaryEvaluator() { + @Override + public EvalOperator.ExpressionEvaluator.Factory vectorsOperation( + Source source, + EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs + ) { + return new DenseVectorsEvaluator.Factory(source, lhs, rhs, Add::addDenseVectorElements, OP_NAME); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory scalarVectorOperation( + Source source, + float lhs, + EvalOperator.ExpressionEvaluator.Factory rhs + ) { + return new DenseVectorScalarEvaluator.Factory(source, lhs, rhs, Add::addDenseVectorElements, OP_NAME); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory vectorScalarOperation( + Source source, + EvalOperator.ExpressionEvaluator.Factory lhs, + float rhs + ) { + return new DenseVectorScalarEvaluator.Factory(source, lhs, rhs, Add::addDenseVectorElements, OP_NAME); + } + }; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java index 60020810f9ede..03db8e69ecc5e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java @@ -58,7 +58,7 @@ ExpressionEvaluator.Factory apply( BinaryEvaluator longs, BinaryEvaluator ulongs, BinaryEvaluator doubles, - BinaryEvaluator denseVectors, + DenseVectorBinaryEvaluator denseVectors, DatetimeArithmeticEvaluator millisEvaluator, DatetimeArithmeticEvaluator nanosEvaluator ) { @@ -74,7 +74,7 @@ ExpressionEvaluator.Factory apply( BinaryEvaluator longs, BinaryEvaluator ulongs, BinaryEvaluator doubles, - BinaryEvaluator denseVectors, + DenseVectorBinaryEvaluator denseVectors, DatetimeArithmeticEvaluator millisEvaluator, DatetimeArithmeticEvaluator nanosEvaluator ) throws IOException { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorArithmeticOperation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorArithmeticOperation.java index 19f8a115d2f9e..4ab737974f25c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorArithmeticOperation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorArithmeticOperation.java @@ -8,22 +8,53 @@ package org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.logging.LoggerMessageFormat; import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; import java.io.IOException; import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; import static org.elasticsearch.xpack.esql.core.type.DataType.NULL; +import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG; +import static org.elasticsearch.xpack.esql.core.type.DataType.isNullOrNumeric; /** * Adds support for dense_vector data types. Specifically provides the logic when either left or right type is a dense_vector. */ public abstract class DenseVectorArithmeticOperation extends EsqlArithmeticOperation { - private final BinaryEvaluator denseVectors; + private static final String ERROR_MSG = "[{}] should evaluate to a dense_vector or scalar constant"; + + private final DenseVectorBinaryEvaluator denseVectors; + + /** Set of arithmetic (quad) functions for dense_vectors. */ + public interface DenseVectorBinaryEvaluator { + // when both arguments are dense_vectors + EvalOperator.ExpressionEvaluator.Factory vectorsOperation( + Source source, + EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs + ); + + // when lhs is a scalar and rhs is a dense_vector + EvalOperator.ExpressionEvaluator.Factory scalarVectorOperation( + Source source, + float lhs, + EvalOperator.ExpressionEvaluator.Factory rhs + ); + + // when lhs is a dense_vector and rhs is a scalar + EvalOperator.ExpressionEvaluator.Factory vectorScalarOperation( + Source source, + EvalOperator.ExpressionEvaluator.Factory lhs, + float rhs + ); + } protected DenseVectorArithmeticOperation( Source source, @@ -34,7 +65,7 @@ protected DenseVectorArithmeticOperation( BinaryEvaluator longs, BinaryEvaluator ulongs, BinaryEvaluator doubles, - BinaryEvaluator denseVectors + DenseVectorBinaryEvaluator denseVectors ) { super(source, left, right, op, ints, longs, ulongs, doubles); this.denseVectors = denseVectors; @@ -47,7 +78,7 @@ protected DenseVectorArithmeticOperation( BinaryEvaluator longs, BinaryEvaluator ulongs, BinaryEvaluator doubles, - BinaryEvaluator denseVectors + DenseVectorBinaryEvaluator denseVectors ) throws IOException { super(in, op, ints, longs, ulongs, doubles); this.denseVectors = denseVectors; @@ -67,23 +98,58 @@ protected TypeResolution resolveInputType(Expression e, TypeResolutions.ParamOrd @Override protected TypeResolution checkCompatibility() { - // dense_vectors arithmetic only supported when both arguments are dense_vectors or one argument is null + // dense_vectors arithmetic only supported when both arguments are dense_vectors or one argument is numeric or null DataType leftType = left().dataType(); DataType rightType = right().dataType(); if (leftType == DENSE_VECTOR || rightType == DENSE_VECTOR) { - if ((leftType == DENSE_VECTOR || leftType == NULL) && (rightType == DENSE_VECTOR || rightType == NULL)) { + if (leftType == NULL || rightType == NULL) { return TypeResolution.TYPE_RESOLVED; } - return new TypeResolution(formatIncompatibleTypesMessage(symbol(), leftType, rightType)); + if (leftType != DENSE_VECTOR) { + if (false == isSupportedScalar(leftType)) { + return new TypeResolution(formatIncompatibleTypesMessage(symbol(), leftType, rightType)); + } + if (false == left().foldable()) { + return new TypeResolution( + LoggerMessageFormat.format(null, "[{}] should evaluate to a dense_vector or scalar constant", left().sourceText()) + ); + } + } + if (rightType != DENSE_VECTOR) { + if (false == isSupportedScalar(rightType)) { + return new TypeResolution(formatIncompatibleTypesMessage(symbol(), leftType, rightType)); + } + if (false == right().foldable()) { + return new TypeResolution( + LoggerMessageFormat.format(null, "[{}] should evaluate to a dense_vector or scalar constant", right().sourceText()) + ); + } + } + return TypeResolution.TYPE_RESOLVED; } return super.checkCompatibility(); } @Override public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { - if (dataType() == DENSE_VECTOR) { - return this.denseVectors.apply(source(), toEvaluator.apply(left()), toEvaluator.apply(right())); + var commonType = dataType(); + if (commonType == DENSE_VECTOR) { + if (left().dataType() == DENSE_VECTOR && right().dataType() == DENSE_VECTOR) { + return denseVectors.vectorsOperation(source(), toEvaluator.apply(left()), toEvaluator.apply(right())); + } + if (left().dataType() != DENSE_VECTOR) { + float lhs = (Float) DataTypeConverter.convert(left().fold(toEvaluator.foldCtx()), FLOAT); + return denseVectors.scalarVectorOperation(source(), lhs, toEvaluator.apply(right())); + } else { + float rhs = (Float) DataTypeConverter.convert(right().fold(toEvaluator.foldCtx()), FLOAT); + return denseVectors.vectorScalarOperation(source(), toEvaluator.apply(left()), rhs); + } } return super.toEvaluator(toEvaluator); } + + private static boolean isSupportedScalar(DataType dataType) { + return isNullOrNumeric(dataType) && dataType != UNSIGNED_LONG; + } + } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorScalarEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorScalarEvaluator.java new file mode 100644 index 0000000000000..130f3e593b3e0 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorScalarEvaluator.java @@ -0,0 +1,154 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic; + +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.function.BiFunction; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for performing arithmetic operations when + * lhs is a dense_vector and rhs a scalar. + * + */ +class DenseVectorScalarEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(DenseVectorScalarEvaluator.class); + + private final BiFunction op; + private final String name; + private final Source source; + private final EvalOperator.ExpressionEvaluator lhs; + private final Float rhs; + private final DriverContext driverContext; + private Warnings warnings; + + DenseVectorScalarEvaluator( + BiFunction op, + String name, + Source source, + EvalOperator.ExpressionEvaluator lhs, + Float rhs, + DriverContext driverContext + ) { + this.op = op; + this.name = name; + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + assert rhs != null : "Operand for dense vector arithmetic operation cannot be null"; + try (var lhsBlock = (FloatBlock) lhs.eval(page)) { + int positionCount = page.getPositionCount(); + try (var resultBlock = driverContext.blockFactory().newFloatBlockBuilder(positionCount)) { + float[] buffer = new float[0]; + for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + resultBlock.appendNull(); + continue; + } + + int lhsValueCount = lhsBlock.getValueCount(p); + if (buffer.length < lhsValueCount) { + buffer = new float[lhsValueCount]; + } + int lhsStart = lhsBlock.getFirstValueIndex(p); + try { + for (int i = 0; i < lhsValueCount; i++) { + float l = lhsBlock.getFloat(lhsStart + i); + // Always assume the scalar operand is the rhs in the processing. + // We need to flip the order of arguments for non-commutative operations in the Factory when the scalar is + // the lhs, to ensure the correct order of arguments is applied here. + buffer[i] = op.apply(l, rhs); + } + resultBlock.beginPositionEntry(); + for (int i = 0; i < lhsValueCount; i++) { + resultBlock.appendFloat(buffer[i]); + } + resultBlock.endPositionEntry(); + } catch (ArithmeticException e) { + warnings().registerException(e); + resultBlock.appendNull(); + } + } + return resultBlock.build(); + } + } + } + + @Override + public long baseRamBytesUsed() { + return BASE_RAM_BYTES_USED + lhs.baseRamBytesUsed() + RamUsageEstimator.shallowSizeOfInstance(Float.class); + } + + @Override + public String toString() { + return "DenseVectorScalarEvaluator[" + "lhs=" + lhs + ", rhs=scalar_constant" + ", opName=" + name + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings(driverContext.warningsMode(), source); + } + return warnings; + } + + static final class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + private final EvalOperator.ExpressionEvaluator.Factory vector; + private final Float scalar; + private final BiFunction op; + private final String opName; + + // Factory when lhs is a dense_vector and rhs a scalar + Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, Float rhs, BiFunction op, String opName) { + this.source = source; + this.vector = lhs; + this.scalar = rhs; + this.op = op; + this.opName = opName; + } + + // Factory when lhs is a scalar and rhs a dense_vector. + Factory(Source source, Float lhs, EvalOperator.ExpressionEvaluator.Factory rhs, BiFunction op, String opName) { + this.source = source; + this.scalar = lhs; + this.vector = rhs; + // flip the order of arguments for scalar-vector operations, as we assume the scalar is always the rhs in the processing + this.op = (a, b) -> op.apply(b, a); + this.opName = opName; + } + + @Override + public DenseVectorScalarEvaluator get(DriverContext context) { + return new DenseVectorScalarEvaluator(op, opName, source, vector.get(context), scalar, context); + } + + @Override + public String toString() { + return "DenseVectorScalarEvaluator[" + "lhs=" + vector + ", rhs=scalar_constant, opName=" + opName + "]"; + } + } + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorsEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorsEvaluator.java index 98df0e7cc5ad3..51a2dbdd743c2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorsEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorsEvaluator.java @@ -16,20 +16,17 @@ import org.elasticsearch.compute.operator.Warnings; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.core.util.NumericUtils; import java.util.function.BiFunction; +import static org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; + /** * {@link EvalOperator.ExpressionEvaluator} implementation for performing arithmetic operations on two dense_vector arguments. * */ class DenseVectorsEvaluator implements EvalOperator.ExpressionEvaluator { private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(DenseVectorsEvaluator.class); - private static final String ADD_DENSE_VECTOR_EVALUATOR = "AddDenseVectorsEvaluator"; - private static final String SUB_DENSE_VECTOR_EVALUATOR = "SubDenseVectorsEvaluator"; - private static final String MUL_DENSE_VECTOR_EVALUATOR = "MulDenseVectorsEvaluator"; - private static final String DIV_DENSE_VECTOR_EVALUATOR = "DivDenseVectorsEvaluator"; private final BiFunction op; private final String name; @@ -83,7 +80,6 @@ public Block eval(Page page) { if (buffer.length < lhsValueCount) { buffer = new float[lhsValueCount]; } - boolean success = true; try { for (int i = 0; i < lhsValueCount; i++) { float l = lhsBlock.getFloat(lhsStart + i); @@ -98,7 +94,6 @@ public Block eval(Page page) { } catch (ArithmeticException e) { warnings().registerException(e); resultBlock.appendNull(); - success = false; } } return resultBlock.build(); @@ -113,7 +108,7 @@ public long baseRamBytesUsed() { @Override public String toString() { - return name + "[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + return "DenseVectorsEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + ", opName=" + name + "]"; } @Override @@ -128,139 +123,35 @@ private Warnings warnings() { return warnings; } - private static float processAdd(float lhs, float rhs) { - return NumericUtils.asFiniteNumber(lhs + rhs); - } - - private static float processSub(float lhs, float rhs) { - return NumericUtils.asFiniteNumber(lhs - rhs); - } - - private static float processMul(float lhs, float rhs) { - return NumericUtils.asFiniteNumber(lhs * rhs); - } - - private static float processDiv(float lhs, float rhs) { - float result = lhs / rhs; - if (Double.isNaN(result) || Double.isInfinite(result)) { - throw new ArithmeticException("/ by zero"); - } - return result; - } - - static final class AddFactory implements Factory { - private final Source source; - private final Factory lhs; - private final Factory rhs; - - AddFactory(Source source, Factory lhs, Factory rhs) { - this.source = source; - this.lhs = lhs; - this.rhs = rhs; - } - - @Override - public DenseVectorsEvaluator get(DriverContext context) { - return new DenseVectorsEvaluator( - DenseVectorsEvaluator::processAdd, - ADD_DENSE_VECTOR_EVALUATOR, - source, - lhs.get(context), - rhs.get(context), - context - ); - } - - @Override - public String toString() { - return ADD_DENSE_VECTOR_EVALUATOR + "[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; - } - } - - static class SubFactory implements Factory { + static final class Factory implements ExpressionEvaluator.Factory { private final Source source; - private final Factory lhs; - private final Factory rhs; - - SubFactory(Source source, Factory lhs, Factory rhs) { - this.source = source; - this.lhs = lhs; - this.rhs = rhs; - } - - @Override - public DenseVectorsEvaluator get(DriverContext context) { - return new DenseVectorsEvaluator( - DenseVectorsEvaluator::processSub, - SUB_DENSE_VECTOR_EVALUATOR, - source, - lhs.get(context), - rhs.get(context), - context - ); - } - - @Override - public String toString() { - return SUB_DENSE_VECTOR_EVALUATOR + "[lhs=" + lhs + ", rhs=" + rhs + "]"; - } - } - - static class MulFactory implements Factory { - private final Source source; - private final Factory lhs; - private final Factory rhs; - - MulFactory(Source source, Factory lhs, Factory rhs) { - this.source = source; - this.lhs = lhs; - this.rhs = rhs; - } - - @Override - public DenseVectorsEvaluator get(DriverContext context) { - return new DenseVectorsEvaluator( - DenseVectorsEvaluator::processMul, - MUL_DENSE_VECTOR_EVALUATOR, - source, - lhs.get(context), - rhs.get(context), - context - ); - } - - @Override - public String toString() { - return MUL_DENSE_VECTOR_EVALUATOR + "[lhs=" + lhs + ", rhs=" + rhs + "]"; - } - } - - static class DivFactory implements Factory { - private final Source source; - private final Factory lhs; - private final Factory rhs; - - DivFactory(Source source, Factory lhs, Factory rhs) { + private final EvalOperator.ExpressionEvaluator.Factory lhs; + private final EvalOperator.ExpressionEvaluator.Factory rhs; + private final BiFunction op; + private final String opName; + + Factory( + Source source, + ExpressionEvaluator.Factory lhs, + ExpressionEvaluator.Factory rhs, + BiFunction op, + String opName + ) { this.source = source; this.lhs = lhs; this.rhs = rhs; + this.op = op; + this.opName = opName; } @Override public DenseVectorsEvaluator get(DriverContext context) { - return new DenseVectorsEvaluator( - DenseVectorsEvaluator::processDiv, - DIV_DENSE_VECTOR_EVALUATOR, - source, - lhs.get(context), - rhs.get(context), - context - ); + return new DenseVectorsEvaluator(op, opName, source, lhs.get(context), rhs.get(context), context); } @Override public String toString() { - return DIV_DENSE_VECTOR_EVALUATOR + "[lhs=" + lhs + ", rhs=" + rhs + "]"; + return "DenseVectorsEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + ", opName=" + opName + "]"; } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java index c8be6a3fe1985..f9118a8c23268 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.ann.Evaluator; +import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -25,6 +26,7 @@ public class Div extends DenseVectorArithmeticOperation implements BinaryComparisonInversible { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Div", Div::new); + public static final String OP_NAME = "Div"; private DataType type; @@ -53,7 +55,7 @@ public Div(Source source, Expression left, Expression right, DataType type) { DivLongsEvaluator.Factory::new, DivUnsignedLongsEvaluator.Factory::new, DivDoublesEvaluator.Factory::new, - DenseVectorsEvaluator.DivFactory::new + DIV_DENSE_VECTOR_EVALUATOR ); this.type = type; } @@ -66,7 +68,7 @@ private Div(StreamInput in) throws IOException { DivLongsEvaluator.Factory::new, DivUnsignedLongsEvaluator.Factory::new, DivDoublesEvaluator.Factory::new, - DenseVectorsEvaluator.DivFactory::new + DIV_DENSE_VECTOR_EVALUATOR ); } @@ -129,4 +131,41 @@ static double processDoubles(double lhs, double rhs) { } return value; } + + private static float divDenseVectorElements(float lhs, float rhs) { + float value = lhs / rhs; + if (Float.isNaN(value) || Float.isInfinite(value)) { + throw new ArithmeticException("/ by zero"); + } + return value; + } + + private static final DenseVectorBinaryEvaluator DIV_DENSE_VECTOR_EVALUATOR = new DenseVectorBinaryEvaluator() { + @Override + public EvalOperator.ExpressionEvaluator.Factory vectorsOperation( + Source source, + EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs + ) { + return new DenseVectorsEvaluator.Factory(source, lhs, rhs, Div::divDenseVectorElements, OP_NAME); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory scalarVectorOperation( + Source source, + float lhs, + EvalOperator.ExpressionEvaluator.Factory rhs + ) { + return new DenseVectorScalarEvaluator.Factory(source, lhs, rhs, Div::divDenseVectorElements, OP_NAME); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory vectorScalarOperation( + Source source, + EvalOperator.ExpressionEvaluator.Factory lhs, + float rhs + ) { + return new DenseVectorScalarEvaluator.Factory(source, lhs, rhs, Div::divDenseVectorElements, OP_NAME); + } + }; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mul.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mul.java index e1e110886ff77..615daf6aa6869 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mul.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mul.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.ann.Evaluator; +import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -24,6 +25,7 @@ public class Mul extends DenseVectorArithmeticOperation implements BinaryComparisonInversible { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Mul", Mul::new); + public static final String OP_NAME = "Mul"; @FunctionInfo(operator = "*", returnType = { "double", "integer", "long", "unsigned_long", "dense_vector" }, description = """ Multiply two values together. For numeric fields, if either field is <> @@ -52,7 +54,7 @@ public Mul( MulLongsEvaluator.Factory::new, MulUnsignedLongsEvaluator.Factory::new, MulDoublesEvaluator.Factory::new, - DenseVectorsEvaluator.MulFactory::new + MUL_DENSE_VECTOR_EVALUATOR ); } @@ -64,7 +66,7 @@ private Mul(StreamInput in) throws IOException { MulLongsEvaluator.Factory::new, MulUnsignedLongsEvaluator.Factory::new, MulDoublesEvaluator.Factory::new, - DenseVectorsEvaluator.MulFactory::new + MUL_DENSE_VECTOR_EVALUATOR ); } @@ -118,4 +120,36 @@ static double processDoubles(double lhs, double rhs) { return NumericUtils.asFiniteNumber(lhs * rhs); } + private static float mulDenseVectorElements(float lhs, float rhs) { + return NumericUtils.asFiniteNumber(lhs * rhs); + } + + private static final DenseVectorBinaryEvaluator MUL_DENSE_VECTOR_EVALUATOR = new DenseVectorBinaryEvaluator() { + @Override + public EvalOperator.ExpressionEvaluator.Factory vectorsOperation( + Source source, + EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs + ) { + return new DenseVectorsEvaluator.Factory(source, lhs, rhs, Mul::mulDenseVectorElements, OP_NAME); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory scalarVectorOperation( + Source source, + float lhs, + EvalOperator.ExpressionEvaluator.Factory rhs + ) { + return new DenseVectorScalarEvaluator.Factory(source, lhs, rhs, Mul::mulDenseVectorElements, OP_NAME); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory vectorScalarOperation( + Source source, + EvalOperator.ExpressionEvaluator.Factory lhs, + float rhs + ) { + return new DenseVectorScalarEvaluator.Factory(source, lhs, rhs, Mul::mulDenseVectorElements, OP_NAME); + } + }; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Sub.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Sub.java index 3190170a7741d..52697d37f988d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Sub.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Sub.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.time.DateUtils; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; +import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -38,6 +39,7 @@ public class Sub extends DateTimeArithmeticOperation implements BinaryComparisonInversible { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sub", Sub::new); + public static final String OP_NAME = "Sub"; private final Configuration configuration; @@ -73,7 +75,7 @@ public Sub( SubLongsEvaluator.Factory::new, SubUnsignedLongsEvaluator.Factory::new, SubDoublesEvaluator.Factory::new, - DenseVectorsEvaluator.SubFactory::new, + SUB_DENSE_VECTOR_EVALUATOR, SubDatetimesEvaluator.Factory::new, SubDateNanosEvaluator.Factory::new ); @@ -88,7 +90,7 @@ private Sub(StreamInput in) throws IOException { SubLongsEvaluator.Factory::new, SubUnsignedLongsEvaluator.Factory::new, SubDoublesEvaluator.Factory::new, - DenseVectorsEvaluator.SubFactory::new, + SUB_DENSE_VECTOR_EVALUATOR, SubDatetimesEvaluator.Factory::new, SubDateNanosEvaluator.Factory::new ); @@ -194,4 +196,37 @@ public Configuration configuration() { public Sub withConfiguration(Configuration configuration) { return new Sub(source(), left(), right(), configuration); } + + private static float subDenseVectorElements(float lhs, float rhs) { + return NumericUtils.asFiniteNumber(lhs - rhs); + } + + private static final DenseVectorBinaryEvaluator SUB_DENSE_VECTOR_EVALUATOR = new DenseVectorBinaryEvaluator() { + @Override + public EvalOperator.ExpressionEvaluator.Factory vectorsOperation( + Source source, + EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs + ) { + return new DenseVectorsEvaluator.Factory(source, lhs, rhs, Sub::subDenseVectorElements, OP_NAME); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory scalarVectorOperation( + Source source, + float lhs, + EvalOperator.ExpressionEvaluator.Factory rhs + ) { + return new DenseVectorScalarEvaluator.Factory(source, lhs, rhs, Sub::subDenseVectorElements, OP_NAME); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory vectorScalarOperation( + Source source, + EvalOperator.ExpressionEvaluator.Factory lhs, + float rhs + ) { + return new DenseVectorScalarEvaluator.Factory(source, lhs, rhs, Sub::subDenseVectorElements, OP_NAME); + } + }; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java index 504821f3a9a46..614335de568db 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java @@ -486,6 +486,9 @@ public static DataType commonType(DataType left, DataType right) { // Both TEXT and SEMANTIC_TEXT are processed as KEYWORD return KEYWORD; } + if ((left == DENSE_VECTOR && right.isNumeric()) || (right == DENSE_VECTOR && left.isNumeric())) { + return DENSE_VECTOR; + } if (left.isNumeric() && right.isNumeric()) { int lsize = left.estimatedSize(); int rsize = right.estimatedSize(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java index 54fe4ec6b7a67..d9e104f302a30 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractConfigurationFunctionTestCase; import org.elasticsearch.xpack.esql.session.Configuration; @@ -35,11 +36,13 @@ import static org.elasticsearch.test.ReadableMatchers.matchesDateMillis; import static org.elasticsearch.test.ReadableMatchers.matchesDateNanos; import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; import static org.elasticsearch.xpack.esql.core.util.DateUtils.asDateTime; import static org.elasticsearch.xpack.esql.core.util.DateUtils.asMillis; import static org.elasticsearch.xpack.esql.core.util.NumericUtils.asLongUnsigned; import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.TEST_SOURCE; import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.randomDenseVector; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DenseVectorTestCaseHelper.denseVectorScalarCases; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.startsWith; @@ -304,7 +307,7 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "AddDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Add]", DENSE_VECTOR, equalTo(expected) ); @@ -319,7 +322,7 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "AddDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Add]", DENSE_VECTOR, equalTo(null) ).withWarning("Line 1:1: evaluation of [source] failed, treating result as null. Only first 20 failures recorded.") @@ -337,13 +340,21 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "AddDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Add]", DENSE_VECTOR, equalTo(null) ).withWarning("Line 1:1: evaluation of [source] failed, treating result as null. Only first 20 failures recorded.") - .withWarning("Line 1:1: java.lang.ArithmeticException: not a finite double number: Infinity"); + .withWarning("Line 1:1: java.lang.ArithmeticException: not a finite float number: Infinity"); })); + suppliers.addAll( + denseVectorScalarCases( + "Add", + (v, s) -> v.stream().map(f -> f + (Float) DataTypeConverter.convert(s, FLOAT)).toList(), + (s, v) -> v.stream().map(f -> (Float) DataTypeConverter.convert(s, FLOAT) + f).toList() + ) + ); + // Set the timezone to UTC for test cases up to here suppliers = TestCaseSupplier.mapTestCases( suppliers, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorTestCaseHelper.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorTestCaseHelper.java new file mode 100644 index 0000000000000..b5e753ff7c1d9 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DenseVectorTestCaseHelper.java @@ -0,0 +1,139 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic; + +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiFunction; + +import static org.elasticsearch.test.ESTestCase.between; +import static org.elasticsearch.test.ESTestCase.randomDouble; +import static org.elasticsearch.test.ESTestCase.randomDoubleBetween; +import static org.elasticsearch.test.ESTestCase.randomInt; +import static org.elasticsearch.test.ESTestCase.randomLong; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.randomDenseVector; +import static org.hamcrest.Matchers.equalTo; + +public class DenseVectorTestCaseHelper { + + static List denseVectorScalarCases( + String opName, + BiFunction, Number, List> vectorScalarOp, + BiFunction, List> scalarVectorOp + ) { + List suppliers = new ArrayList<>(); + + // Vector + Integer + suppliers.add(new TestCaseSupplier(List.of(DENSE_VECTOR, DataType.INTEGER), () -> { + int dimensions = between(64, 128); + List vector = randomDenseVector(dimensions); + int scalar = randomInt(); + List expected = vectorScalarOp.apply(vector, scalar); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(vector, DENSE_VECTOR, "vector"), + new TestCaseSupplier.TypedData(scalar, DataType.INTEGER, "scalar").forceLiteral() + ), + "DenseVectorScalarEvaluator[lhs=Attribute[channel=0], rhs=scalar_constant, opName=" + opName + "]", + DENSE_VECTOR, + equalTo(expected) + ); + })); + + // Integer + Vector + suppliers.add(new TestCaseSupplier(List.of(DataType.INTEGER, DENSE_VECTOR), () -> { + int dimensions = between(64, 128); + List vector = randomDenseVector(dimensions); + int scalar = randomInt(); + List expected = scalarVectorOp.apply(scalar, vector); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(scalar, DataType.INTEGER, "scalar").forceLiteral(), + new TestCaseSupplier.TypedData(vector, DENSE_VECTOR, "vector") + ), + "DenseVectorScalarEvaluator[lhs=Attribute[channel=0], rhs=scalar_constant, opName=" + opName + "]", + DENSE_VECTOR, + equalTo(expected) + ); + })); + + // Vector + Long + suppliers.add(new TestCaseSupplier(List.of(DENSE_VECTOR, DataType.LONG), () -> { + int dimensions = between(64, 128); + List vector = randomDenseVector(dimensions); + long scalar = randomLong(); + List expected = vectorScalarOp.apply(vector, scalar); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(vector, DENSE_VECTOR, "vector"), + new TestCaseSupplier.TypedData(scalar, DataType.LONG, "scalar").forceLiteral() + ), + "DenseVectorScalarEvaluator[lhs=Attribute[channel=0], rhs=scalar_constant, opName=" + opName + "]", + DENSE_VECTOR, + equalTo(expected) + ); + })); + + // Long + Vector + suppliers.add(new TestCaseSupplier(List.of(DataType.LONG, DENSE_VECTOR), () -> { + int dimensions = between(64, 128); + List vector = randomDenseVector(dimensions); + long scalar = randomLong(); + List expected = scalarVectorOp.apply(scalar, vector); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(scalar, DataType.LONG, "scalar").forceLiteral(), + new TestCaseSupplier.TypedData(vector, DENSE_VECTOR, "vector") + ), + "DenseVectorScalarEvaluator[lhs=Attribute[channel=0], rhs=scalar_constant, opName=" + opName + "]", + DENSE_VECTOR, + equalTo(expected) + ); + })); + + // Vector + Double + suppliers.add(new TestCaseSupplier(List.of(DENSE_VECTOR, DataType.DOUBLE), () -> { + int dimensions = between(64, 128); + List vector = randomDenseVector(dimensions); + double scalar = randomDouble(); + List expected = vectorScalarOp.apply(vector, scalar); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(vector, DENSE_VECTOR, "vector"), + new TestCaseSupplier.TypedData(scalar, DataType.DOUBLE, "scalar").forceLiteral() + ), + "DenseVectorScalarEvaluator[lhs=Attribute[channel=0], rhs=scalar_constant, opName=" + opName + "]", + DENSE_VECTOR, + equalTo(expected) + ); + })); + + // Double + Vector + suppliers.add(new TestCaseSupplier(List.of(DataType.DOUBLE, DENSE_VECTOR), () -> { + int dimensions = between(64, 128); + List vector = randomDenseVector(dimensions); + double scalar = randomDoubleBetween(-1.0, 1.0, true); + List expected = scalarVectorOp.apply(scalar, vector); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(scalar, DataType.DOUBLE, "scalar").forceLiteral(), + new TestCaseSupplier.TypedData(vector, DENSE_VECTOR, "vector") + ), + "DenseVectorScalarEvaluator[lhs=Attribute[channel=0], rhs=scalar_constant, opName=" + opName + "]", + DENSE_VECTOR, + equalTo(expected) + ); + })); + + return suppliers; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java index 4f3e2de19b38a..9d7b65d73393a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.hamcrest.Matcher; @@ -25,7 +26,9 @@ import java.util.function.Supplier; import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.randomDenseVector; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DenseVectorTestCaseHelper.denseVectorScalarCases; import static org.hamcrest.Matchers.equalTo; public class DivTests extends AbstractScalarFunctionTestCase { @@ -173,7 +176,7 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "DivDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Div]", DENSE_VECTOR, equalTo(expected) ); @@ -187,7 +190,7 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "DivDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Div]", DENSE_VECTOR, equalTo(null) ).withWarning("Line 1:1: evaluation of [source] failed, treating result as null. Only first 20 failures recorded.") @@ -204,13 +207,21 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "DivDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Div]", DENSE_VECTOR, equalTo(null) ).withWarning("Line 1:1: evaluation of [source] failed, treating result as null. Only first 20 failures recorded.") .withWarning("Line 1:1: java.lang.ArithmeticException: / by zero"); })); + suppliers.addAll( + denseVectorScalarCases( + "Div", + (v, s) -> v.stream().map(f -> f / (Float) DataTypeConverter.convert(s, FLOAT)).toList(), + (s, v) -> v.stream().map(f -> ((Float) DataTypeConverter.convert(s, FLOAT) / f)).toList() + ) + ); + suppliers = errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers), DivTests::divErrorMessageString); // Cannot use parameterSuppliersFromTypedDataWithDefaultChecks as error messages are non-trivial diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/MulTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/MulTests.java index fb8333ca88707..02c7e1b34c9dd 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/MulTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/MulTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; @@ -22,8 +23,10 @@ import java.util.function.Supplier; import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; import static org.elasticsearch.xpack.esql.core.util.NumericUtils.asLongUnsigned; import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.randomDenseVector; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DenseVectorTestCaseHelper.denseVectorScalarCases; import static org.hamcrest.Matchers.equalTo; public class MulTests extends AbstractScalarFunctionTestCase { @@ -137,7 +140,7 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "MulDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Mul]", DENSE_VECTOR, equalTo(expected) ); @@ -151,7 +154,7 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "MulDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Mul]", DENSE_VECTOR, equalTo(null) ).withWarning("Line 1:1: evaluation of [source] failed, treating result as null. Only first 20 failures recorded.") @@ -168,13 +171,21 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "MulDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Mul]", DENSE_VECTOR, equalTo(null) ).withWarning("Line 1:1: evaluation of [source] failed, treating result as null. Only first 20 failures recorded.") - .withWarning("Line 1:1: java.lang.ArithmeticException: not a finite double number: Infinity"); + .withWarning("Line 1:1: java.lang.ArithmeticException: not a finite float number: Infinity"); })); + suppliers.addAll( + denseVectorScalarCases( + "Mul", + (v, s) -> v.stream().map(f -> f * (Float) DataTypeConverter.convert(s, FLOAT)).toList(), + (s, v) -> v.stream().map(f -> (Float) DataTypeConverter.convert(s, FLOAT) * f).toList() + ) + ); + suppliers = errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers), MulTests::mulErrorMessageString); // Cannot use parameterSuppliersFromTypedDataWithDefaultChecks as error messages are non-trivial diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/SubTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/SubTests.java index a28dff20cf457..ef714fdb19993 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/SubTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/SubTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractConfigurationFunctionTestCase; import org.elasticsearch.xpack.esql.session.Configuration; @@ -37,11 +38,13 @@ import static org.elasticsearch.test.ReadableMatchers.matchesDateNanos; import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; import static org.elasticsearch.xpack.esql.core.util.DateUtils.asDateTime; import static org.elasticsearch.xpack.esql.core.util.DateUtils.asMillis; import static org.elasticsearch.xpack.esql.core.util.NumericUtils.ZERO_AS_UNSIGNED_LONG; import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.TEST_SOURCE; import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.randomDenseVector; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DenseVectorTestCaseHelper.denseVectorScalarCases; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.startsWith; @@ -264,7 +267,7 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "SubDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Sub]", DENSE_VECTOR, equalTo(expected) ); @@ -278,7 +281,7 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "SubDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Sub]", DENSE_VECTOR, equalTo(null) ).withWarning("Line 1:1: evaluation of [source] failed, treating result as null. Only first 20 failures recorded.") @@ -295,13 +298,21 @@ public static Iterable parameters() { new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") ), - "SubDenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", + "DenseVectorsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1], opName=Sub]", DENSE_VECTOR, equalTo(null) ).withWarning("Line 1:1: evaluation of [source] failed, treating result as null. Only first 20 failures recorded.") - .withWarning("Line 1:1: java.lang.ArithmeticException: not a finite double number: -Infinity"); + .withWarning("Line 1:1: java.lang.ArithmeticException: not a finite float number: -Infinity"); })); + suppliers.addAll( + denseVectorScalarCases( + "Sub", + (v, s) -> v.stream().map(f -> f - (Float) DataTypeConverter.convert(s, FLOAT)).toList(), + (s, v) -> v.stream().map(f -> (Float) DataTypeConverter.convert(s, FLOAT) - f).toList() + ) + ); + // Set the timezone to UTC for test cases up to here suppliers = TestCaseSupplier.mapTestCases( suppliers, @@ -323,7 +334,7 @@ public static Iterable parameters() { } return original.expectedType(); }, (nullPosition, nullData, original) -> { - if (DataType.isTemporalAmount(nullData.type())) { + if (DataType.isTemporalAmount(nullData.type()) || nullData.isForceLiteral()) { return equalTo("LiteralsEvaluator[lit=null]"); } return original; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverterTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverterTests.java index e8e595cd2277a..871e8f3be97c6 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverterTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverterTests.java @@ -36,6 +36,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.COUNTER_LONG; import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; import static org.elasticsearch.xpack.esql.core.type.DataType.DOC_DATA_TYPE; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; @@ -180,6 +181,8 @@ public void testCommonTypeNumeric() { commonNumericType(FLOAT, List.of(NULL, BYTE, SHORT, INTEGER, LONG, UNSIGNED_LONG, FLOAT, HALF_FLOAT)); commonNumericType(DOUBLE, List.of(NULL, BYTE, SHORT, INTEGER, LONG, UNSIGNED_LONG, HALF_FLOAT, FLOAT, DOUBLE, SCALED_FLOAT)); commonNumericType(SCALED_FLOAT, List.of(NULL, BYTE, SHORT, INTEGER, LONG, UNSIGNED_LONG, HALF_FLOAT, FLOAT, SCALED_FLOAT, DOUBLE)); + // dense vectors + commonNumericType(DENSE_VECTOR, List.of(NULL, BYTE, SHORT, INTEGER, LONG, UNSIGNED_LONG, HALF_FLOAT, FLOAT, DOUBLE, SCALED_FLOAT)); } /** @@ -189,7 +192,9 @@ private static void commonNumericType(DataType numericType, List lower List NUMERICS = Arrays.stream(DataType.values()).filter(DataType::isNumeric).toList(); List DOUBLES = Arrays.stream(DataType.values()).filter(DataType::isRationalNumber).toList(); for (DataType dataType : DataType.values()) { - if (DOUBLES.containsAll(List.of(numericType, dataType)) && (dataType.estimatedSize() == numericType.estimatedSize())) { + if (dataType == DENSE_VECTOR) { + assertEquals(DENSE_VECTOR, commonType(dataType, numericType)); + } else if (DOUBLES.containsAll(List.of(numericType, dataType)) && (dataType.estimatedSize() == numericType.estimatedSize())) { assertEquals(numericType, commonType(dataType, numericType)); } else if (lowerTypes.contains(dataType)) { assertEqualsCommonType(numericType, dataType, numericType);