diff --git a/docs/changelog/133636.yaml b/docs/changelog/133636.yaml new file mode 100644 index 0000000000000..46202e2d1c2b1 --- /dev/null +++ b/docs/changelog/133636.yaml @@ -0,0 +1,5 @@ +pr: 133636 +summary: Esql `mv_contains` function +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/reference/query-languages/esql/_snippets/functions/description/mv_contains.md b/docs/reference/query-languages/esql/_snippets/functions/description/mv_contains.md new file mode 100644 index 0000000000000..da02211ed38db --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/mv_contains.md @@ -0,0 +1,6 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Description** + +Checks if all values yielded by the second multivalue expression are present in the values yielded by the first multivalue expression. Returns a boolean. Null values are treated as an empty set. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/mv_contains.md b/docs/reference/query-languages/esql/_snippets/functions/examples/mv_contains.md new file mode 100644 index 0000000000000..df77ca463c9be --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/mv_contains.md @@ -0,0 +1,34 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +**Examples** + +```esql +ROW set = ["a", "b", "c"], element = "a" +| EVAL set_contains_element = mv_contains(set, element) +``` + +| set:keyword | element:keyword | set_contains_element:boolean | +| --- | --- | --- | +| [a, b, c] | a | true | + +```esql +ROW setA = ["a","c"], setB = ["a", "b", "c"] +| EVAL a_subset_of_b = mv_contains(setB, setA) +| EVAL b_subset_of_a = mv_contains(setA, setB) +``` + +| setA:keyword | setB:keyword | a_subset_of_b:boolean | b_subset_of_a:boolean | +| --- | --- | --- | --- | +| [a, c] | [a, b, c] | true | false | + +```esql +FROM airports +| WHERE mv_contains(type, ["major","military"]) AND scalerank == 9 +| KEEP scalerank, name, country +``` + +| scalerank:integer | name:text | country:keyword | +| --- | --- | --- | +| 9 | Chandigarh Int'l | India | + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/mv_contains.md b/docs/reference/query-languages/esql/_snippets/functions/layout/mv_contains.md new file mode 100644 index 0000000000000..aee13481b435f --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/mv_contains.md @@ -0,0 +1,26 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +## `MV_CONTAINS` [esql-mv_contains] +```{applies_to} +stack: preview 9.2.0 +``` + +**Syntax** + +:::{image} ../../../images/functions/mv_contains.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/mv_contains.md +::: + +:::{include} ../description/mv_contains.md +::: + +:::{include} ../types/mv_contains.md +::: + +:::{include} ../examples/mv_contains.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/mv_contains.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/mv_contains.md new file mode 100644 index 0000000000000..9f002719dbd99 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/mv_contains.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Parameters** + +`superset` +: Multivalue expression. + +`subset` +: Multivalue expression. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/mv_contains.md b/docs/reference/query-languages/esql/_snippets/functions/types/mv_contains.md new file mode 100644 index 0000000000000..fe4b46ac5b280 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/types/mv_contains.md @@ -0,0 +1,24 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +**Supported types** + +| superset | subset | result | +| --- | --- | --- | +| boolean | boolean | boolean | +| cartesian_point | cartesian_point | boolean | +| cartesian_shape | cartesian_shape | boolean | +| date | date | boolean | +| date_nanos | date_nanos | boolean | +| double | double | boolean | +| geo_point | geo_point | boolean | +| geo_shape | geo_shape | boolean | +| integer | integer | boolean | +| ip | ip | boolean | +| keyword | keyword | boolean | +| keyword | text | boolean | +| long | long | boolean | +| text | keyword | boolean | +| text | text | boolean | +| unsigned_long | unsigned_long | boolean | +| version | version | boolean | + diff --git a/docs/reference/query-languages/esql/_snippets/lists/mv-functions.md b/docs/reference/query-languages/esql/_snippets/lists/mv-functions.md index a7b32dfb3835e..db2d1149e7f75 100644 --- a/docs/reference/query-languages/esql/_snippets/lists/mv-functions.md +++ b/docs/reference/query-languages/esql/_snippets/lists/mv-functions.md @@ -1,6 +1,7 @@ * [`MV_APPEND`](../../functions-operators/mv-functions.md#esql-mv_append) * [`MV_AVG`](../../functions-operators/mv-functions.md#esql-mv_avg) * [`MV_CONCAT`](../../functions-operators/mv-functions.md#esql-mv_concat) +* [preview] [`MV_CONTAINS`](../../functions-operators/mv-functions.md#esql-mv_contains) * [`MV_COUNT`](../../functions-operators/mv-functions.md#esql-mv_count) * [`MV_DEDUPE`](../../functions-operators/mv-functions.md#esql-mv_dedupe) * [`MV_FIRST`](../../functions-operators/mv-functions.md#esql-mv_first) diff --git a/docs/reference/query-languages/esql/functions-operators/mv-functions.md b/docs/reference/query-languages/esql/functions-operators/mv-functions.md index 7eca1a53ab8ff..acb0b882e6bfc 100644 --- a/docs/reference/query-languages/esql/functions-operators/mv-functions.md +++ b/docs/reference/query-languages/esql/functions-operators/mv-functions.md @@ -21,6 +21,9 @@ mapped_pages: :::{include} ../_snippets/functions/layout/mv_concat.md ::: 
+:::{include} ../_snippets/functions/layout/mv_contains.md +::: + :::{include} ../_snippets/functions/layout/mv_count.md ::: diff --git a/docs/reference/query-languages/esql/images/functions/mv_contains.svg b/docs/reference/query-languages/esql/images/functions/mv_contains.svg new file mode 100644 index 0000000000000..3a588496e392b --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/mv_contains.svg @@ -0,0 +1 @@ +MV_CONTAINS(superset,subset) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/mv_contains.json b/docs/reference/query-languages/esql/kibana/definition/functions/mv_contains.json new file mode 100644 index 0000000000000..0116939ba4f59 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/mv_contains.json @@ -0,0 +1,321 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "mv_contains", + "description" : "Checks if all values yielded by the second multivalue expression are present in the values yielded by the first multivalue expression. Returns a boolean. Null values are treated as an empty set.", + "signatures" : [ + { + "params" : [ + { + "name" : "superset", + "type" : "boolean", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "boolean", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "cartesian_point", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "cartesian_point", + "optional" : false, + "description" : "Multivalue expression." 
+ } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "cartesian_shape", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "cartesian_shape", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "date", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "date", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "date_nanos", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "date_nanos", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "double", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "double", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "geo_point", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "geo_point", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "geo_shape", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "geo_shape", + "optional" : false, + "description" : "Multivalue expression." 
+ } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "integer", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "integer", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "ip", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "ip", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "keyword", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "keyword", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "keyword", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "text", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "long", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "long", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "text", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "keyword", + "optional" : false, + "description" : "Multivalue expression." 
+ } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "text", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "text", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "unsigned_long", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "unsigned_long", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "superset", + "type" : "version", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "subset", + "type" : "version", + "optional" : false, + "description" : "Multivalue expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + } + ], + "examples" : [ + "ROW set = [\"a\", \"b\", \"c\"], element = \"a\"\n| EVAL set_contains_element = mv_contains(set, element)", + "ROW setA = [\"a\",\"c\"], setB = [\"a\", \"b\", \"c\"]\n| EVAL a_subset_of_b = mv_contains(setB, setA)\n| EVAL b_subset_of_a = mv_contains(setA, setB)", + "FROM airports\n| WHERE mv_contains(type, [\"major\",\"military\"]) AND scalerank == 9\n| KEEP scalerank, name, country" + ], + "preview" : false, + "snapshot_only" : false +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/mv_contains.md b/docs/reference/query-languages/esql/kibana/docs/functions/mv_contains.md new file mode 100644 index 0000000000000..4bc82881dc292 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/mv_contains.md @@ -0,0 +1,9 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +### MV CONTAINS +Checks if all values yielded by the second multivalue expression are present in the values yielded by the first multivalue expression. Returns a boolean. Null values are treated as an empty set. + +```esql +ROW set = ["a", "b", "c"], element = "a" +| EVAL set_contains_element = mv_contains(set, element) +``` diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java index 2c9bf74fb8b0a..983b1cc7cfcbf 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/EvalOperator.java @@ -157,34 +157,114 @@ default boolean eagerEvalSafeInLazy() { long baseRamBytesUsed(); } - public static final ExpressionEvaluator.Factory CONSTANT_NULL_FACTORY = new ExpressionEvaluator.Factory() { + private record ConstantNullEvaluator(DriverContext context) implements ExpressionEvaluator { + private static final String NAME = "ConstantNull"; + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ConstantNullEvaluator.class); + + @Override + public Block eval(Page page) { + return context.blockFactory().newConstantNullBlock(page.getPositionCount()); + } + @Override - public ExpressionEvaluator get(DriverContext driverContext) { - return new ExpressionEvaluator() { - @Override - public Block eval(Page page) { - return driverContext.blockFactory().newConstantNullBlock(page.getPositionCount()); - } - - @Override - public void close() {} - - @Override - public String toString() { - return CONSTANT_NULL_NAME; - } - - @Override - public long baseRamBytesUsed() { - return 0; - } + public void close() {} + + @Override + public String toString() { + return NAME; + } + + @Override + public long baseRamBytesUsed() { + return BASE_RAM_BYTES_USED; + } + + record Factory() 
implements ExpressionEvaluator.Factory { + @Override + public ConstantNullEvaluator get(DriverContext context) { + return new ConstantNullEvaluator(context); }; + + @Override + public String toString() { + return NAME; + } + }; + } + + public static final ExpressionEvaluator.Factory CONSTANT_NULL_FACTORY = new ConstantNullEvaluator.Factory(); + + private record ConstantTrueEvaluator(DriverContext context) implements ExpressionEvaluator { + private static final String NAME = "ConstantTrue"; + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ConstantTrueEvaluator.class); + + @Override + public Block eval(Page page) { + return context.blockFactory().newConstantBooleanBlockWith(true, page.getPositionCount()); } + @Override + public void close() {} + @Override public String toString() { - return CONSTANT_NULL_NAME; + return NAME; + } + + @Override + public long baseRamBytesUsed() { + return BASE_RAM_BYTES_USED; } - }; - private static final String CONSTANT_NULL_NAME = "ConstantNull"; + + record Factory() implements ExpressionEvaluator.Factory { + @Override + public ConstantTrueEvaluator get(DriverContext context) { + return new ConstantTrueEvaluator(context); + }; + + @Override + public String toString() { + return NAME; + } + }; + } + + public static final ExpressionEvaluator.Factory CONSTANT_TRUE_FACTORY = new ConstantTrueEvaluator.Factory(); + + private record ConstantFalseEvaluator(DriverContext context) implements ExpressionEvaluator { + private static final String NAME = "ConstantFalse"; + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ConstantFalseEvaluator.class); + + @Override + public Block eval(Page page) { + return context.blockFactory().newConstantBooleanBlockWith(false, page.getPositionCount()); + } + + @Override + public void close() {} + + @Override + public String toString() { + return NAME; + } + + @Override + public long baseRamBytesUsed() { + return BASE_RAM_BYTES_USED; + 
} + + record Factory() implements ExpressionEvaluator.Factory { + @Override + public ConstantFalseEvaluator get(DriverContext context) { + return new ConstantFalseEvaluator(context); + }; + + @Override + public String toString() { + return NAME; + } + }; + } + + public static final ExpressionEvaluator.Factory CONSTANT_FALSE_FACTORY = new ConstantFalseEvaluator.Factory(); } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec index d89f5a52f9899..d2fbc849318de 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec @@ -2026,6 +2026,92 @@ l1:integer | l2:integer null | 0 ; +mvContains +required_capability: fn_mv_contains +// tag::mv_contains[] +ROW set = ["a", "b", "c"], element = "a" +| EVAL set_contains_element = mv_contains(set, element) +// end::mv_contains[] +; + +// tag::mv_contains-result[] +set:keyword | element:keyword | set_contains_element:boolean +[a, b, c] | a | true +// end::mv_contains-result[] +; + +mvContains_bothsides +required_capability: fn_mv_contains +// tag::mv_contains_bothsides[] +ROW setA = ["a","c"], setB = ["a", "b", "c"] +| EVAL a_subset_of_b = mv_contains(setB, setA) +| EVAL b_subset_of_a = mv_contains(setA, setB) +// end::mv_contains_bothsides[] +; + +// tag::mv_contains_bothsides-result[] +setA:keyword | setB:keyword | a_subset_of_b:boolean | b_subset_of_a:boolean +[a, c] | [a, b, c] | true | false +// end::mv_contains_bothsides-result[] +; + +mvContainsCombinations +required_capability: fn_mv_contains + +ROW a = "a", b = ["a", "b", "c"], n = null +| EVAL aa = mv_contains(a, a), + bb = mv_contains(b, b), + ab = mv_contains(a, b), + ba = mv_contains(b,a), + na = mv_contains(n, a), + an = mv_contains(a, n), + nn = mv_contains(n,n) +; + +a:keyword | b:keyword | n:null | aa:boolean | bb:boolean | ab:boolean | ba:boolean | na:boolean | 
an:boolean | nn:boolean +a | [a, b, c] | null | true | true | false | true | false | true | true +; + +mvContainsCombinations_multirow +required_capability: fn_mv_contains + +ROW row_number = [1,2,3,4,5], element = "e", n = null, setA = ["b","d"], setB = ["a", "c", "e"] +| MV_EXPAND row_number +| EVAL superset = CASE( + row_number == 1, ["a","e"], + row_number == 2, ["b","d"], + row_number == 3, null, + row_number == 4, ["a","e","c","b","d"], + row_number == 5, ["a","d","c","b","e"], + null) +| EVAL contains_element = mv_contains(superset, element), + contains_null = mv_contains(superset, n), + contains_setA = mv_contains(superset, setA), + contains_setB = mv_contains(superset, setB) +; + +row_number:INTEGER | element:keyword | n:null | setA:keyword | setB:keyword | superset:keyword |contains_element:boolean | contains_null:boolean | contains_setA:boolean | contains_setB:boolean +1 | "e" | null | ["b","d"] | ["a", "c", "e"] | ["a","e"] | true | true | false | false +2 | "e" | null | ["b","d"] | ["a", "c", "e"] | ["b","d"] | false | true | true | false +3 | "e" | null | ["b","d"] | ["a", "c", "e"] | null | false | true | false | false +4 | "e" | null | ["b","d"] | ["a", "c", "e"] | ["a","e","c","b","d"] | true | true | true | true +5 | "e" | null | ["b","d"] | ["a", "c", "e"] | ["a","d","c","b","e"] | true | true | true | true +; + +mvContains_where +required_capability: fn_mv_contains +// tag::mv_contains_where[] +FROM airports +| WHERE mv_contains(type, ["major","military"]) AND scalerank == 9 +| KEEP scalerank, name, country +// end::mv_contains_where[] +; + +// tag::mv_contains_where-result[] +scalerank:integer | name:text | country:keyword +9 | Chandigarh Int'l | India +// end::mv_contains_where-result[] +; mvAppend required_capability: fn_mv_append diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 
c3c4121b095f4..0035ea9e07a3a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -250,6 +250,12 @@ public enum Cap { */ FN_MONTH_NAME, + /** + * support for MV_CONTAINS function + * Add MV_CONTAINS function #133099 + */ + FN_MV_CONTAINS, + /** * Fixes for multiple functions not serializing their source, and emitting warnings with wrong line number and text. */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 03ef45d7127ed..bb98f4b2fbb20 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -125,6 +125,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAppend; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvConcat; +import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvContains; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvDedupe; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvFirst; @@ -455,6 +456,7 @@ private static FunctionDefinition[][] functions() { def(MvAppend.class, MvAppend::new, "mv_append"), def(MvAvg.class, MvAvg::new, "mv_avg"), def(MvConcat.class, MvConcat::new, "mv_concat"), + def(MvContains.class, MvContains::new, "mv_contains"), def(MvCount.class, MvCount::new, "mv_count"), def(MvDedupe.class, MvDedupe::new, "mv_dedupe"), 
def(MvFirst.class, MvFirst::new, "mv_first"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContains.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContains.java new file mode 100644 index 0000000000000..dfdec763e330a --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContains.java @@ -0,0 +1,791 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Nullability; +import 
org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull; +import org.elasticsearch.xpack.esql.planner.PlannerUtils; + +import java.io.IOException; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isRepresentableExceptCounters; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; + +/** + * Function that takes two multivalued expressions and checks if values of one expression are all present(equals) in the other. + *

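+ * Given Set A = {"a","b","c"} and Set B = {"b","c"}, the relationship between first (row) and second (column) arguments is:
+ *      | null  | A     | B + * null | true  | false | false + * A    | true  | true  | true + * B    | true  | false | true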
+ */ +public class MvContains extends BinaryScalarFunction implements EvaluatorMapper { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "MvContains", + MvContains::new + ); + + @FunctionInfo( + returnType = "boolean", + description = "Checks if all values yielded by the second multivalue expression are present in the values yielded by " + + "the first multivalue expression. Returns a boolean. Null values are treated as an empty set.", + examples = { + @Example(file = "string", tag = "mv_contains"), + @Example(file = "string", tag = "mv_contains_bothsides"), + @Example(file = "string", tag = "mv_contains_where"), }, + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.PREVIEW, version = "9.2.0") } + ) + public MvContains( + Source source, + @Param( + name = "superset", + type = { + "boolean", + "cartesian_point", + "cartesian_shape", + "date", + "date_nanos", + "double", + "geo_point", + "geo_shape", + "integer", + "ip", + "keyword", + "long", + "text", + "unsigned_long", + "version" }, + description = "Multivalue expression." + ) Expression superset, + @Param( + name = "subset", + type = { + "boolean", + "cartesian_point", + "cartesian_shape", + "date", + "date_nanos", + "double", + "geo_point", + "geo_shape", + "integer", + "ip", + "keyword", + "long", + "text", + "unsigned_long", + "version" }, + description = "Multivalue expression." 
+ ) Expression subset + ) { + super(source, superset, subset); + } + + private MvContains(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + TypeResolution resolution = isRepresentableExceptCounters(left(), sourceText(), FIRST); + if (resolution.unresolved()) { + return resolution; + } + if (left().dataType() == DataType.NULL) { + return isRepresentableExceptCounters(right(), sourceText(), SECOND); + } + return isType(right(), t -> t.noText() == left().dataType().noText(), sourceText(), SECOND, left().dataType().noText().typeName()); + } + + @Override + public DataType dataType() { + return DataType.BOOLEAN; + } + + @Override + public Nullability nullable() { + return Nullability.FALSE; + } + + @Override + protected MvContains replaceChildren(Expression newLeft, Expression newRight) { + return new MvContains(source(), newLeft, newRight); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, MvContains::new, left(), right()); + } + + @Override + public Object fold(FoldContext ctx) { + return EvaluatorMapper.super.fold(source(), ctx); + } + + @Override + public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { + var supersetType = PlannerUtils.toElementType(left().dataType()); + var subsetType = PlannerUtils.toElementType(right().dataType()); + + if (subsetType == ElementType.NULL) { + return EvalOperator.CONSTANT_TRUE_FACTORY; + } + + if (supersetType != ElementType.NULL && supersetType != subsetType) { + throw new EsqlIllegalArgumentException( + "Incompatible data types for MvContains, superset type({}) value({}) and subset type({}) value({}) don't match.", + supersetType, + left(), + subsetType, + right() + ); + } + + return switch (supersetType) { + case BOOLEAN -> new 
MvContainsBooleanEvaluator.Factory(source(), toEvaluator.apply(left()), toEvaluator.apply(right())); + case BYTES_REF -> new MvContainsBytesRefEvaluator.Factory(source(), toEvaluator.apply(left()), toEvaluator.apply(right())); + case DOUBLE -> new MvContainsDoubleEvaluator.Factory(source(), toEvaluator.apply(left()), toEvaluator.apply(right())); + case INT -> new MvContainsIntEvaluator.Factory(source(), toEvaluator.apply(left()), toEvaluator.apply(right())); + case LONG -> new MvContainsLongEvaluator.Factory(source(), toEvaluator.apply(left()), toEvaluator.apply(right())); + case NULL -> new IsNull.IsNullEvaluatorFactory(toEvaluator.apply(right())); + default -> throw EsqlIllegalArgumentException.illegalDataType(dataType()); + }; + } + + // @Evaluator(extraName = "Int") see end of file. + static void process(BooleanBlock.Builder builder, int position, IntBlock field1, IntBlock field2) { + appendTo(builder, containsAll(field1, field2, position, IntBlock::getInt)); + } + + // @Evaluator(extraName = "Boolean") see end of file. + static void process(BooleanBlock.Builder builder, int position, BooleanBlock field1, BooleanBlock field2) { + appendTo(builder, containsAll(field1, field2, position, BooleanBlock::getBoolean)); + } + + // @Evaluator(extraName = "Long") see end of file. + static void process(BooleanBlock.Builder builder, int position, LongBlock field1, LongBlock field2) { + appendTo(builder, containsAll(field1, field2, position, LongBlock::getLong)); + } + + // @Evaluator(extraName = "Double") see end of file. + static void process(BooleanBlock.Builder builder, int position, DoubleBlock field1, DoubleBlock field2) { + appendTo(builder, containsAll(field1, field2, position, DoubleBlock::getDouble)); + } + + // @Evaluator(extraName = "BytesRef") see end of file. 
+    static void process(BooleanBlock.Builder builder, int position, BytesRefBlock field1, BytesRefBlock field2) {
+        appendTo(builder, containsAll(field1, field2, position, (block, index) -> {
+            // A fresh scratch object per extraction, so a subset value and a superset value never share one instance.
+            var ref = new BytesRef();
+            // we pass in a reference, but sometimes we only get a return value, see ConstantBytesRefVector.getBytesRef
+            ref = block.getBytesRef(index, ref);
+            // pass empty ref as null
+            if (ref.length == 0) {
+                return null;
+            }
+            return ref;
+        }));
+    }
+
+    /**
+     * Appends one position to the builder: a null entry when {@code bool} is {@code null}, otherwise the boolean value.
+     */
+    static void appendTo(BooleanBlock.Builder builder, Boolean bool) {
+        if (bool == null) {
+            builder.appendNull();
+        } else {
+            builder.beginPositionEntry().appendBoolean(bool).endPositionEntry();
+        }
+    }
+
+    /**
+     * A block is considered a subset if the superset contains values that test equal for all the values in the subset, independent of
+     * order. Duplicates are ignored in the sense that for each duplicate in the subset, we will search/match against the first/any value
+     * in the superset.
+     * <p>
+     * Values that the extractor maps to {@code null} are skipped, i.e. they count as contained. The result is also {@code true} when
+     * both arguments are the same block instance or when {@code subset.areAllValuesNull()} holds.
+     *
+     * @param superset block to check against
+     * @param subset block containing values that should be present in the other block.
+     * @param position position whose values are compared in both blocks
+     * @param valueExtractor reads the value stored at a value index of a block
+     * @return {@code true} if every subset value at the position was found in the superset, {@code false} if not; never {@code null}
+     */
+    static <BlockType extends Block, Type> Boolean containsAll(
+        BlockType superset,
+        BlockType subset,
+        final int position,
+        ValueExtractor<BlockType, Type> valueExtractor
+    ) {
+        if (superset == subset) {
+            return true;
+        }
+        if (subset.areAllValuesNull()) {
+            return true;
+        }
+
+        final var valueCount = subset.getValueCount(position);
+        final var startIndex = subset.getFirstValueIndex(position);
+        for (int valueIndex = startIndex; valueIndex < startIndex + valueCount; valueIndex++) {
+            var value = valueExtractor.extractValue(subset, valueIndex);
+            if (value == null) { // null entries are considered to always be an element in the superset.
+                continue;
+            }
+            if (hasValue(superset, position, value, valueExtractor) == false) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * Checks whether the block holds the given value among its values at the given position.
+     *
+     * @param superset block to search
+     * @param position position whose values are searched
+     * @param value value to search for
+     * @param valueExtractor reads the value stored at a value index of the block
+     * @return {@code true} if an extracted, non-null element equals {@code value}
+     */
+    static <BlockType extends Block, Type> boolean hasValue(
+        BlockType superset,
+        final int position,
+        Type value,
+        ValueExtractor<BlockType, Type> valueExtractor
+    ) {
+        final var supersetCount = superset.getValueCount(position);
+        final var startIndex = superset.getFirstValueIndex(position);
+        for (int supersetIndex = startIndex; supersetIndex < startIndex + supersetCount; supersetIndex++) {
+            var element = valueExtractor.extractValue(superset, supersetIndex);
+            if (element != null && element.equals(value)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Reads one value out of a block; may return {@code null} to signal "no value" (see the callers' null handling).
+     */
+    @FunctionalInterface
+    interface ValueExtractor<BlockType extends Block, Type> {
+        Type extractValue(BlockType block, int position);
+    }
+
+    /**
+     * Evaluator that always returns true for all values in the block (~column)
+     */
+    public static final class ConstantBooleanTrueEvaluator implements ExpressionEvaluator.Factory {
+        @Override
+        public ExpressionEvaluator get(DriverContext driverContext) {
+            return new ExpressionEvaluator() {
+                @Override
+                public Block eval(Page page) {
+                    return driverContext.blockFactory().newConstantBooleanBlockWith(true, page.getPositionCount());
+                }
+
+                @Override
+                public void close() {}
+
+                @Override
+                public String toString() {
+                    return "ConstantBooleanTrueEvaluator";
+                }
+
+                @Override
+                public long baseRamBytesUsed() {
+                    return 0;
+                }
+            };
+        }
+
+        @Override
+        public String toString() {
+            return "ConstantBooleanTrueEvaluator";
+        }
+    }
+
+    /**
+     * Currently {@code EvaluatorImplementer} generates:
+     * if (allBlocksAreNulls) {
+     *   result.appendNull();
+     *   continue position;
+     * }
+     * when all params are null, this violates our contract of always returning a boolean.
+     * It should probably also generate the warnings method conditionally - omitted here.
+     * TODO extend code generation to handle this case
+     */
+    public static class MvContainsBooleanEvaluator implements EvalOperator.ExpressionEvaluator {
+        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MvContainsBooleanEvaluator.class);
+        private final EvalOperator.ExpressionEvaluator field1;
+        private final EvalOperator.ExpressionEvaluator field2;
+        private final DriverContext driverContext;
+
+        public MvContainsBooleanEvaluator(
+            EvalOperator.ExpressionEvaluator field1,
+            EvalOperator.ExpressionEvaluator field2,
+            DriverContext driverContext
+        ) {
+            this.driverContext = driverContext;
+            this.field1 = field1;
+            this.field2 = field2;
+        }
+
+        /** Evaluates both children on the page; their blocks are closed again once the result has been built. */
+        @Override
+        public Block eval(Page page) {
+            try (
+                BooleanBlock superset = (BooleanBlock) field1.eval(page);
+                BooleanBlock subset = (BooleanBlock) field2.eval(page)
+            ) {
+                return eval(page.getPositionCount(), superset, subset);
+            }
+        }
+
+        /** Calls {@code MvContains.process} once per position and returns the built block. */
+        public BooleanBlock eval(int positionCount, BooleanBlock field1Block, BooleanBlock field2Block) {
+            try (BooleanBlock.Builder builder = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) {
+                for (int position = 0; position < positionCount; position++) {
+                    MvContains.process(builder, position, field1Block, field2Block);
+                }
+                return builder.build();
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "MvContainsBooleanEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]";
+        }
+
+        @Override
+        public void close() {
+            Releasables.closeExpectNoException(field1, field2);
+        }
+
+        /** Shallow size of this evaluator plus the base size reported by both children. */
+        @Override
+        public long baseRamBytesUsed() {
+            return BASE_RAM_BYTES_USED + field1.baseRamBytesUsed() + field2.baseRamBytesUsed();
+        }
+
+        /** Builds a {@link MvContainsBooleanEvaluator} from the evaluators produced by the two child factories. */
+        public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
+            private final Source source;
+            private final EvalOperator.ExpressionEvaluator.Factory field1;
+            private final EvalOperator.ExpressionEvaluator.Factory field2;
+
+            public Factory(
+                Source source,
+                EvalOperator.ExpressionEvaluator.Factory field1,
+                EvalOperator.ExpressionEvaluator.Factory field2
+            ) {
+                this.source = source;
+                this.field1 = field1;
+                this.field2 = field2;
+            }
+
+            @Override
+            public MvContainsBooleanEvaluator get(DriverContext context) {
+                EvalOperator.ExpressionEvaluator superset = field1.get(context);
+                EvalOperator.ExpressionEvaluator subset = field2.get(context);
+                return new MvContainsBooleanEvaluator(superset, subset, context);
+            }
+
+            @Override
+            public String toString() {
+                return "MvContainsBooleanEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]";
+            }
+        }
+    }
+
+    /**
+     * Hand-written {@code BytesRef} evaluator for {@link MvContains}.
+     * <p>
+     * Currently {@code EvaluatorImplementer} generates:
+     * <pre>{@code
+     * if (allBlocksAreNulls) {
+     *   result.appendNull();
+     *   continue position;
+     * }
+     * }</pre>
+     * when all params are null, which violates our contract of always returning a boolean.
+     * It should probably also generate the warnings method conditionally - omitted here.
+     * TODO extend code generation to handle this case
+     */
+    public static class MvContainsBytesRefEvaluator implements EvalOperator.ExpressionEvaluator {
+        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MvContainsBytesRefEvaluator.class);
+        private final EvalOperator.ExpressionEvaluator field1;
+        private final EvalOperator.ExpressionEvaluator field2;
+        private final DriverContext driverContext;
+
+        public MvContainsBytesRefEvaluator(
+            EvalOperator.ExpressionEvaluator field1,
+            EvalOperator.ExpressionEvaluator field2,
+            DriverContext driverContext
+        ) {
+            this.driverContext = driverContext;
+            this.field1 = field1;
+            this.field2 = field2;
+        }
+
+        /** Evaluates both children on the page; their blocks are closed again once the result has been built. */
+        @Override
+        public Block eval(Page page) {
+            try (
+                BytesRefBlock superset = (BytesRefBlock) field1.eval(page);
+                BytesRefBlock subset = (BytesRefBlock) field2.eval(page)
+            ) {
+                return eval(page.getPositionCount(), superset, subset);
+            }
+        }
+
+        /** Calls {@code MvContains.process} once per position and returns the built block. */
+        public BooleanBlock eval(int positionCount, BytesRefBlock field1Block, BytesRefBlock field2Block) {
+            try (BooleanBlock.Builder builder = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) {
+                for (int position = 0; position < positionCount; position++) {
+                    MvContains.process(builder, position, field1Block, field2Block);
+                }
+                return builder.build();
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "MvContainsBytesRefEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]";
+        }
+
+        @Override
+        public void close() {
+            Releasables.closeExpectNoException(field1, field2);
+        }
+
+        /** Shallow size of this evaluator plus the base size reported by both children. */
+        @Override
+        public long baseRamBytesUsed() {
+            return BASE_RAM_BYTES_USED + field1.baseRamBytesUsed() + field2.baseRamBytesUsed();
+        }
+
+        /** Builds a {@link MvContainsBytesRefEvaluator} from the evaluators produced by the two child factories. */
+        public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
+            private final Source source;
+            private final EvalOperator.ExpressionEvaluator.Factory field1;
+            private final EvalOperator.ExpressionEvaluator.Factory field2;
+
+            public Factory(
+                Source source,
+                EvalOperator.ExpressionEvaluator.Factory field1,
+                EvalOperator.ExpressionEvaluator.Factory field2
+            ) {
+                this.source = source;
+                this.field1 = field1;
+                this.field2 = field2;
+            }
+
+            @Override
+            public MvContainsBytesRefEvaluator get(DriverContext context) {
+                EvalOperator.ExpressionEvaluator superset = field1.get(context);
+                EvalOperator.ExpressionEvaluator subset = field2.get(context);
+                return new MvContainsBytesRefEvaluator(superset, subset, context);
+            }
+
+            @Override
+            public String toString() {
+                return "MvContainsBytesRefEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]";
+            }
+        }
+    }
+
+    /**
+     * Currently {@code EvaluatorImplementer} generates:
+     * if (allBlocksAreNulls) {
+     *   result.appendNull();
+     *   continue position;
+     * }
+     * when all params are null, this violates our contract of always returning a boolean.
+     * It should probably also generate the warnings method conditionally - omitted here.
+     * TODO extend code generation to handle this case
+     */
+    public static class MvContainsDoubleEvaluator implements EvalOperator.ExpressionEvaluator {
+        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MvContainsDoubleEvaluator.class);
+        private final EvalOperator.ExpressionEvaluator field1;
+        private final EvalOperator.ExpressionEvaluator field2;
+        private final DriverContext driverContext;
+
+        public MvContainsDoubleEvaluator(
+            EvalOperator.ExpressionEvaluator field1,
+            EvalOperator.ExpressionEvaluator field2,
+            DriverContext driverContext
+        ) {
+            this.driverContext = driverContext;
+            this.field1 = field1;
+            this.field2 = field2;
+        }
+
+        /** Evaluates both children on the page; their blocks are closed again once the result has been built. */
+        @Override
+        public Block eval(Page page) {
+            try (
+                DoubleBlock superset = (DoubleBlock) field1.eval(page);
+                DoubleBlock subset = (DoubleBlock) field2.eval(page)
+            ) {
+                return eval(page.getPositionCount(), superset, subset);
+            }
+        }
+
+        /** Calls {@code MvContains.process} once per position and returns the built block. */
+        public BooleanBlock eval(int positionCount, DoubleBlock field1Block, DoubleBlock field2Block) {
+            try (BooleanBlock.Builder builder = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) {
+                for (int position = 0; position < positionCount; position++) {
+                    MvContains.process(builder, position, field1Block, field2Block);
+                }
+                return builder.build();
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "MvContainsDoubleEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]";
+        }
+
+        @Override
+        public void close() {
+            Releasables.closeExpectNoException(field1, field2);
+        }
+
+        /** Shallow size of this evaluator plus the base size reported by both children. */
+        @Override
+        public long baseRamBytesUsed() {
+            return BASE_RAM_BYTES_USED + field1.baseRamBytesUsed() + field2.baseRamBytesUsed();
+        }
+
+        /** Builds a {@link MvContainsDoubleEvaluator} from the evaluators produced by the two child factories. */
+        public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
+            private final Source source;
+            private final EvalOperator.ExpressionEvaluator.Factory field1;
+            private final EvalOperator.ExpressionEvaluator.Factory field2;
+
+            public Factory(
+                Source source,
+                EvalOperator.ExpressionEvaluator.Factory field1,
+                EvalOperator.ExpressionEvaluator.Factory field2
+            ) {
+                this.source = source;
+                this.field1 = field1;
+                this.field2 = field2;
+            }
+
+            @Override
+            public MvContainsDoubleEvaluator get(DriverContext context) {
+                EvalOperator.ExpressionEvaluator superset = field1.get(context);
+                EvalOperator.ExpressionEvaluator subset = field2.get(context);
+                return new MvContainsDoubleEvaluator(superset, subset, context);
+            }
+
+            @Override
+            public String toString() {
+                return "MvContainsDoubleEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]";
+            }
+        }
+    }
+
+    /**
+     * Hand-written {@code int} evaluator for {@link MvContains}.
+     * <p>
+     * Currently {@code EvaluatorImplementer} generates:
+     * <pre>{@code
+     * if (allBlocksAreNulls) {
+     *   result.appendNull();
+     *   continue position;
+     * }
+     * }</pre>
+     * when all params are null, which violates our contract of always returning a boolean.
+     * It should probably also generate the warnings method conditionally - omitted here.
+     * TODO extend code generation to handle this case
+     */
+    public static class MvContainsIntEvaluator implements EvalOperator.ExpressionEvaluator {
+        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MvContainsIntEvaluator.class);
+        private final EvalOperator.ExpressionEvaluator field1;
+        private final EvalOperator.ExpressionEvaluator field2;
+        private final DriverContext driverContext;
+
+        public MvContainsIntEvaluator(
+            EvalOperator.ExpressionEvaluator field1,
+            EvalOperator.ExpressionEvaluator field2,
+            DriverContext driverContext
+        ) {
+            this.driverContext = driverContext;
+            this.field1 = field1;
+            this.field2 = field2;
+        }
+
+        /** Evaluates both children on the page; their blocks are closed again once the result has been built. */
+        @Override
+        public Block eval(Page page) {
+            try (
+                IntBlock superset = (IntBlock) field1.eval(page);
+                IntBlock subset = (IntBlock) field2.eval(page)
+            ) {
+                return eval(page.getPositionCount(), superset, subset);
+            }
+        }
+
+        /** Calls {@code MvContains.process} once per position and returns the built block. */
+        public BooleanBlock eval(int positionCount, IntBlock field1Block, IntBlock field2Block) {
+            try (BooleanBlock.Builder builder = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) {
+                for (int position = 0; position < positionCount; position++) {
+                    MvContains.process(builder, position, field1Block, field2Block);
+                }
+                return builder.build();
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "MvContainsIntEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]";
+        }
+
+        @Override
+        public void close() {
+            Releasables.closeExpectNoException(field1, field2);
+        }
+
+        /** Shallow size of this evaluator plus the base size reported by both children. */
+        @Override
+        public long baseRamBytesUsed() {
+            return BASE_RAM_BYTES_USED + field1.baseRamBytesUsed() + field2.baseRamBytesUsed();
+        }
+
+        /** Builds a {@link MvContainsIntEvaluator} from the evaluators produced by the two child factories. */
+        public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
+            private final Source source;
+            private final EvalOperator.ExpressionEvaluator.Factory field1;
+            private final EvalOperator.ExpressionEvaluator.Factory field2;
+
+            public Factory(
+                Source source,
+                EvalOperator.ExpressionEvaluator.Factory field1,
+                EvalOperator.ExpressionEvaluator.Factory field2
+            ) {
+                this.source = source;
+                this.field1 = field1;
+                this.field2 = field2;
+            }
+
+            @Override
+            public MvContainsIntEvaluator get(DriverContext context) {
+                EvalOperator.ExpressionEvaluator superset = field1.get(context);
+                EvalOperator.ExpressionEvaluator subset = field2.get(context);
+                return new MvContainsIntEvaluator(superset, subset, context);
+            }
+
+            @Override
+            public String toString() {
+                return "MvContainsIntEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]";
+            }
+        }
+    }
+
+    /**
+     * Currently {@code EvaluatorImplementer} generates:
+     * if (allBlocksAreNulls) {
+     *   result.appendNull();
+     *   continue position;
+     * }
+     * when all params are null, this violates our contract of always returning a boolean.
+     * It should probably also generate the warnings method conditionally - omitted here.
+     * TODO extend code generation to handle this case
+     */
+    public static class MvContainsLongEvaluator implements EvalOperator.ExpressionEvaluator {
+        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MvContainsLongEvaluator.class);
+        private final EvalOperator.ExpressionEvaluator field1;
+        private final EvalOperator.ExpressionEvaluator field2;
+        private final DriverContext driverContext;
+
+        public MvContainsLongEvaluator(
+            EvalOperator.ExpressionEvaluator field1,
+            EvalOperator.ExpressionEvaluator field2,
+            DriverContext driverContext
+        ) {
+            this.driverContext = driverContext;
+            this.field1 = field1;
+            this.field2 = field2;
+        }
+
+        /** Evaluates both children on the page; their blocks are closed again once the result has been built. */
+        @Override
+        public Block eval(Page page) {
+            try (
+                LongBlock superset = (LongBlock) field1.eval(page);
+                LongBlock subset = (LongBlock) field2.eval(page)
+            ) {
+                return eval(page.getPositionCount(), superset, subset);
+            }
+        }
+
+        /** Calls {@code MvContains.process} once per position and returns the built block. */
+        public BooleanBlock eval(int positionCount, LongBlock field1Block, LongBlock field2Block) {
+            try (BooleanBlock.Builder builder = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) {
+                for (int position = 0; position < positionCount; position++) {
+                    MvContains.process(builder, position, field1Block, field2Block);
+                }
+                return builder.build();
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "MvContainsLongEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]";
+        }
+
+        @Override
+        public void close() {
+            Releasables.closeExpectNoException(field1, field2);
+        }
+
+        /** Shallow size of this evaluator plus the base size reported by both children. */
+        @Override
+        public long baseRamBytesUsed() {
+            return BASE_RAM_BYTES_USED + field1.baseRamBytesUsed() + field2.baseRamBytesUsed();
+        }
+
+        /** Builds a {@link MvContainsLongEvaluator} from the evaluators produced by the two child factories. */
+        public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
+            private final Source source;
+            private final EvalOperator.ExpressionEvaluator.Factory field1;
+            private final EvalOperator.ExpressionEvaluator.Factory field2;
+
+            public Factory(
+                Source source,
+                EvalOperator.ExpressionEvaluator.Factory field1,
+                EvalOperator.ExpressionEvaluator.Factory field2
+            ) {
+                this.source = source;
+                this.field1 = field1;
+                this.field2 = field2;
+            }
+
+            @Override
+            public MvContainsLongEvaluator get(DriverContext context) {
+                EvalOperator.ExpressionEvaluator superset = field1.get(context);
+                EvalOperator.ExpressionEvaluator subset = field2.get(context);
+                return new MvContainsLongEvaluator(superset, subset, context);
+            }
+
+            @Override
+            public String toString() {
+                return "MvContainsLongEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]";
+            }
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFunctionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFunctionWritables.java
index 7f8fcd910ad6d..8dafc630e0e02 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFunctionWritables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFunctionWritables.java
@@ -17,6 +17,7 @@ public static List getNamedWriteables() {
             MvAppend.ENTRY,
             MvAvg.ENTRY,
             MvConcat.ENTRY,
+            MvContains.ENTRY,
             MvCount.ENTRY,
             MvDedupe.ENTRY,
             MvFirst.ENTRY,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/nulls/IsNull.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/nulls/IsNull.java
index 6f91c83940b34..d0baee75f817b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/nulls/IsNull.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/nulls/IsNull.java
@@ -12,7 +12,7 @@
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.compute.operator.EvalOperator;
+import
org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -103,7 +103,7 @@ public Object fold(FoldContext ctx) { } @Override - public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { + public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { return new IsNullEvaluatorFactory(toEvaluator.apply(field())); } @@ -138,9 +138,9 @@ public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHand return new NotQuery(source(), new ExistsQuery(source(), handler.nameOf(field()))); } - record IsNullEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory field) implements EvalOperator.ExpressionEvaluator.Factory { + public record IsNullEvaluatorFactory(ExpressionEvaluator.Factory field) implements ExpressionEvaluator.Factory { @Override - public EvalOperator.ExpressionEvaluator get(DriverContext context) { + public ExpressionEvaluator get(DriverContext context) { return new IsNullEvaluator(context, field.get(context)); } @@ -150,9 +150,7 @@ public String toString() { } } - record IsNullEvaluator(DriverContext driverContext, EvalOperator.ExpressionEvaluator field) - implements - EvalOperator.ExpressionEvaluator { + record IsNullEvaluator(DriverContext driverContext, ExpressionEvaluator field) implements ExpressionEvaluator { private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(IsNullEvaluator.class); @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsErrorTests.java new file mode 100644 index 0000000000000..c64cb8eef8d6f --- /dev/null +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsErrorTests.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;
+
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matcher;
+
+import java.util.List;
+import java.util.Set;
+
+import static org.hamcrest.Matchers.equalTo;
+
+/**
+ * Type-error tests for {@link MvContains}, driven by the cases declared in {@link MvContainsTests#parameters()}.
+ */
+public class MvContainsErrorTests extends ErrorsForCasesWithoutExamplesTestCase {
+    @Override
+    protected List<TestCaseSupplier> cases() {
+        return paramsToSuppliers(MvContainsTests.parameters());
+    }
+
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        return new MvContains(source, args.get(0), args.get(1));
+    }
+
+    /** The expected message names the first argument's {@code noText()} type as the type required of the second argument. */
+    @Override
+    protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
+        return equalTo(
+            "second argument of ["
+                + sourceForSignature(signature)
+                + "] must be ["
+                + signature.get(0).noText().typeName()
+                + "], found value [] type ["
+                + signature.get(1).typeName()
+                + "]"
+        );
+    }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsSerializationTests.java
new file mode 100644
index 0000000000000..3f190b0250671
--- /dev/null
+++
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsSerializationTests.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;
+
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests;
+
+import java.io.IOException;
+
+/**
+ * Serialization tests for {@link MvContains}.
+ */
+public class MvContainsSerializationTests extends AbstractExpressionSerializationTests<MvContains> {
+    @Override
+    protected MvContains createTestInstance() {
+        Source source = randomSource();
+        Expression field1 = randomChild();
+        Expression field2 = randomChild();
+        return new MvContains(source, field1, field2);
+    }
+
+    /**
+     * Returns a copy of {@code instance} in which exactly one child was swapped for another random child.
+     */
+    @Override
+    protected MvContains mutateInstance(MvContains instance) throws IOException {
+        Source source = instance.source();
+        Expression field1 = instance.left();
+        Expression field2 = instance.right();
+        if (randomBoolean()) {
+            field1 = randomValueOtherThan(field1, AbstractExpressionSerializationTests::randomChild);
+        } else {
+            field2 = randomValueOtherThan(field2, AbstractExpressionSerializationTests::randomChild);
+        }
+        return new MvContains(source, field1, field2);
+    }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsTests.java
new file mode 100644
index 0000000000000..57da6d5c2fe9a
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvContainsTests.java
@@ -0,0 +1,371 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.geo.GeometryTestUtils; +import org.elasticsearch.geo.ShapeTestUtils; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; +import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.CARTESIAN; +import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.GEO; +import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.TypedData.MULTI_ROW_NULL; +import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.TypedData.NULL; +import static org.hamcrest.Matchers.equalTo; + +public class MvContainsTests extends AbstractScalarFunctionTestCase { + public MvContainsTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable 
parameters() { + List suppliers = new ArrayList<>(); + booleans(suppliers); + ints(suppliers); + longs(suppliers); + doubles(suppliers); + bytesRefs(suppliers); + + return parameterSuppliersFromTypedData( + anyNullIsNull( + suppliers, + (nullPosition, nullValueDataType, original) -> original.expectedType(), + (nullPosition, nullData, original) -> original + ) + ); + } + + @Override + protected Expression build(Source source, List args) { + return new MvContains(source, args.get(0), args.get(1)); + } + + private static void booleans(List suppliers) { + suppliers.add(new TestCaseSupplier(List.of(DataType.BOOLEAN, DataType.BOOLEAN), () -> { + List field1 = randomList(1, 10, ESTestCase::randomBoolean); + List field2 = randomList(1, 2, ESTestCase::randomBoolean); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.BOOLEAN, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.BOOLEAN, "field2") + ), + "MvContainsBooleanEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + } + + private static void ints(List suppliers) { + suppliers.add(new TestCaseSupplier(List.of(DataType.INTEGER, DataType.INTEGER), () -> { + List field1 = randomList(1, 10, ESTestCase::randomInt); + List field2 = randomList(1, 10, ESTestCase::randomInt); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.INTEGER, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.INTEGER, "field2") + ), + "MvContainsIntEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + } + + private static void longs(List suppliers) { + suppliers.add(new TestCaseSupplier(List.of(DataType.LONG, DataType.LONG), () -> { + List field1 = randomList(1, 10, ESTestCase::randomLong); + List field2 = 
randomList(1, 10, ESTestCase::randomLong); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.LONG, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.LONG, "field2") + ), + "MvContainsLongEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + suppliers.add(new TestCaseSupplier(List.of(DataType.UNSIGNED_LONG, DataType.UNSIGNED_LONG), () -> { + List field1 = randomList(1, 10, ESTestCase::randomLong); + List field2 = randomList(1, 10, ESTestCase::randomLong); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.UNSIGNED_LONG, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.UNSIGNED_LONG, "field2") + ), + "MvContainsLongEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + suppliers.add(new TestCaseSupplier(List.of(DataType.DATETIME, DataType.DATETIME), () -> { + List field1 = randomList(1, 10, ESTestCase::randomLong); + List field2 = randomList(1, 10, ESTestCase::randomLong); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.DATETIME, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.DATETIME, "field2") + ), + "MvContainsLongEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + suppliers.add(new TestCaseSupplier(List.of(DataType.DATE_NANOS, DataType.DATE_NANOS), () -> { + List field1 = randomList(1, 10, ESTestCase::randomNonNegativeLong); + List field2 = randomList(1, 10, ESTestCase::randomNonNegativeLong); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.DATE_NANOS, 
"field1"), + new TestCaseSupplier.TypedData(field2, DataType.DATE_NANOS, "field2") + ), + "MvContainsLongEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + } + + private static void doubles(List suppliers) { + suppliers.add(new TestCaseSupplier(List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field1 = randomList(1, 10, ESTestCase::randomDouble); + List field2 = randomList(1, 10, ESTestCase::randomDouble); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.DOUBLE, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.DOUBLE, "field2") + ), + "MvContainsDoubleEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + } + + private static void bytesRefs(List suppliers) { + for (DataType lhs : new DataType[] { DataType.KEYWORD, DataType.TEXT }) { + for (DataType rhs : new DataType[] { DataType.KEYWORD, DataType.TEXT }) { + suppliers.add(new TestCaseSupplier(List.of(lhs, rhs), () -> { + List field1 = randomList(1, 10, () -> randomLiteral(lhs).value()); + List field2 = randomList(1, 10, () -> randomLiteral(rhs).value()); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, lhs, "field1"), + new TestCaseSupplier.TypedData(field2, rhs, "field2") + ), + "MvContainsBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + } + } + suppliers.add(new TestCaseSupplier(List.of(DataType.IP, DataType.IP), () -> { + List field1 = randomList(1, 10, () -> randomLiteral(DataType.IP).value()); + List field2 = randomList(1, 10, () -> randomLiteral(DataType.IP).value()); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new 
TestCaseSupplier.TypedData(field1, DataType.IP, "field1"), /* arguments are labelled field1/field2, matching every other supplier in this class */ + new TestCaseSupplier.TypedData(field2, DataType.IP, "field2") + ), + "MvContainsBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + + /* VERSION vs VERSION: values come from randomLiteral(DataType.VERSION).value() */ suppliers.add(new TestCaseSupplier(List.of(DataType.VERSION, DataType.VERSION), () -> { + List field1 = randomList(1, 10, () -> randomLiteral(DataType.VERSION).value()); + List field2 = randomList(1, 10, () -> randomLiteral(DataType.VERSION).value()); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.VERSION, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.VERSION, "field2") + ), + "MvContainsBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + + /* GEO_POINT vs GEO_POINT: each value is the WKT text of a random point wrapped in a BytesRef */ suppliers.add(new TestCaseSupplier(List.of(DataType.GEO_POINT, DataType.GEO_POINT), () -> { + List field1 = randomList(1, 10, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomPoint()))); + List field2 = randomList(1, 10, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomPoint()))); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.GEO_POINT, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.GEO_POINT, "field2") + ), + "MvContainsBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + + /* CARTESIAN_POINT vs CARTESIAN_POINT: same construction using the CARTESIAN WKT writer */ suppliers.add(new TestCaseSupplier(List.of(DataType.CARTESIAN_POINT, DataType.CARTESIAN_POINT), () -> { + List field1 = randomList(1, 10, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomPoint()))); + List field2 = randomList(1, 10, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomPoint()))); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new 
TestCaseSupplier.TypedData(field1, DataType.CARTESIAN_POINT, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.CARTESIAN_POINT, "field2") + ), + "MvContainsBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + + /* The two shape suppliers below call randomList(1, 3, ...) where the other suppliers call randomList(1, 10, ...). */ suppliers.add(new TestCaseSupplier(List.of(DataType.GEO_SHAPE, DataType.GEO_SHAPE), () -> { + var field1 = randomList(1, 3, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomGeometry(randomBoolean(), 500)))); + var field2 = randomList(1, 3, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomGeometry(randomBoolean(), 500)))); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.GEO_SHAPE, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.GEO_SHAPE, "field2") + ), + "MvContainsBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + + suppliers.add(new TestCaseSupplier(List.of(DataType.CARTESIAN_SHAPE, DataType.CARTESIAN_SHAPE), () -> { + var field1 = randomList(1, 3, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomGeometry(randomBoolean(), 500)))); + var field2 = randomList(1, 3, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomGeometry(randomBoolean(), 500)))); + var result = field1.containsAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.CARTESIAN_SHAPE, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.CARTESIAN_SHAPE, "field2") + ), + "MvContainsBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(result) + ); + })); + } + + // Adjusted from the static method {@code AbstractFunctionTestCase#anyNullIsNull}: + // - changed logic to expect a Boolean as an outcome and alternative evaluators (IsNullEvaluator, ConstantTrue) + // - constructor TestCase 
that's used has default access which we can't access, using public constructor variant as a replacement + // - Added prefix to generated tests for my sanity. + // - changed construction of new lists by copying them and updating the entries where necessary instead of regenerating + protected static List<TestCaseSupplier> anyNullIsNull( + List<TestCaseSupplier> testCaseSuppliers, + ExpectedType expectedType, + ExpectedEvaluatorToString evaluatorToString + ) { + List<TestCaseSupplier> suppliers = new ArrayList<>(testCaseSuppliers); + + /* + * For each original test case, add as many copies as there were + * arguments, replacing one of the arguments with null and keeping + * the others. + * + * Also, if this was the first time we saw the signature we copy it + * *again*, replacing the argument with null, but annotating the + * argument’s type as `null` explicitly. + */ + /* signatures already emitted, keyed by their list of argument types */ Set<List<DataType>> uniqueSignatures = new HashSet<>(); + for (TestCaseSupplier original : testCaseSuppliers) { + boolean firstTimeSeenSignature = uniqueSignatures.add(original.types()); + for (int typeIndex = 0; typeIndex < original.types().size(); typeIndex++) { + int nullPosition = typeIndex; + + /* G1: declared types stay as in the original, only the data at nullPosition becomes null */ suppliers.add(new TestCaseSupplier("G1: " + original.name() + " null in " + nullPosition, original.types(), () -> { + TestCaseSupplier.TestCase originalTestCase = original.get(); + List<TestCaseSupplier.TypedData> typeDataWithNull = new ArrayList<>(originalTestCase.getData()); + var data = typeDataWithNull.get(nullPosition); + typeDataWithNull.set(nullPosition, data.withData(data.isMultiRow() ? 
Collections.singletonList(null) : null)); + TestCaseSupplier.TypedData nulledData = originalTestCase.getData().get(nullPosition); + return new TestCaseSupplier.TestCase( + typeDataWithNull, + evaluatorToString.evaluatorToString(nullPosition, nulledData, originalTestCase.evaluatorToString()), + expectedType.expectedType(nullPosition, DataType.BOOLEAN, originalTestCase), + /* true only when the second (subset) argument is the null one; a null first argument yields false */ equalTo(nullPosition == 1) + ); + })); + + if (firstTimeSeenSignature) { + var typesWithNull = new ArrayList<>(original.types()); + typesWithNull.set(nullPosition, DataType.NULL); + boolean newSignature = uniqueSignatures.add(typesWithNull); + if (newSignature) { + /* G2: same case again, but the null argument is also declared with DataType.NULL */ suppliers.add( + new TestCaseSupplier( + "G2: " + toSpaceSeparatedString(typesWithNull) + " null in " + nullPosition, + typesWithNull, + () -> { + TestCaseSupplier.TestCase originalTestCase = original.get(); + var typeDataWithNull = new ArrayList<>(originalTestCase.getData()); + typeDataWithNull.set( + nullPosition, + typeDataWithNull.get(nullPosition).isMultiRow() ? MULTI_ROW_NULL : NULL + ); + return new TestCaseSupplier.TestCase( + typeDataWithNull, + nullPosition == 0 ? "IsNullEvaluator[field=Attribute[channel=1]]" : "ConstantTrue", + expectedType.expectedType(nullPosition, DataType.BOOLEAN, originalTestCase), + equalTo(nullPosition == 1) + ); + } + ) + ); + } + } + } + } + + return suppliers; + } + + /** Joins the string form of each type with a single space; used to build the names of the generated "G2" suppliers. */ private static String toSpaceSeparatedString(List<DataType> typesWithNull) { + return typesWithNull.stream().map(Objects::toString).collect(Collectors.joining(" ")); + } + + // When all arguments are null: the 2nd arg (subset) will be `null` and the 1st is invariant (null,null) => true. + @Override + protected Matcher allNullsMatcher() { + return equalTo(true); + } +}