diff --git a/docs/changelog/112401.yaml b/docs/changelog/112401.yaml new file mode 100644 index 0000000000000..65e9e76ac25f6 --- /dev/null +++ b/docs/changelog/112401.yaml @@ -0,0 +1,6 @@ +pr: 112401 +summary: "ESQL: Fix CASE when conditions are multivalued" +area: ES|QL +type: bug +issues: + - 112359 diff --git a/docs/reference/esql/functions/kibana/definition/case.json b/docs/reference/esql/functions/kibana/definition/case.json index 27705cd3897f9..ab10460f48b25 100644 --- a/docs/reference/esql/functions/kibana/definition/case.json +++ b/docs/reference/esql/functions/kibana/definition/case.json @@ -22,6 +22,30 @@ "variadic" : true, "returnType" : "boolean" }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "boolean", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "boolean", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." + } + ], + "variadic" : true, + "returnType" : "boolean" + }, { "params" : [ { @@ -40,6 +64,90 @@ "variadic" : true, "returnType" : "cartesian_point" }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "cartesian_point", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "cartesian_point", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." + } + ], + "variadic" : true, + "returnType" : "cartesian_point" + }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "cartesian_shape", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + } + ], + "variadic" : true, + "returnType" : "cartesian_shape" + }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "cartesian_shape", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "cartesian_shape", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." + } + ], + "variadic" : true, + "returnType" : "cartesian_shape" + }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "date", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + } + ], + "variadic" : true, + "returnType" : "date" + }, { "params" : [ { @@ -53,6 +161,12 @@ "type" : "date", "optional" : false, "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "date", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." } ], "variadic" : true, @@ -76,6 +190,30 @@ "variadic" : true, "returnType" : "double" }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "double", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "double", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." + } + ], + "variadic" : true, + "returnType" : "double" + }, { "params" : [ { @@ -94,6 +232,90 @@ "variadic" : true, "returnType" : "geo_point" }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "geo_point", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "geo_point", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." + } + ], + "variadic" : true, + "returnType" : "geo_point" + }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "geo_shape", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + } + ], + "variadic" : true, + "returnType" : "geo_shape" + }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "geo_shape", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "geo_shape", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." + } + ], + "variadic" : true, + "returnType" : "geo_shape" + }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "integer", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + } + ], + "variadic" : true, + "returnType" : "integer" + }, { "params" : [ { @@ -107,6 +329,12 @@ "type" : "integer", "optional" : false, "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "integer", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." } ], "variadic" : true, @@ -130,6 +358,30 @@ "variadic" : true, "returnType" : "ip" }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "ip", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "ip", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." + } + ], + "variadic" : true, + "returnType" : "ip" + }, { "params" : [ { @@ -143,12 +395,30 @@ "type" : "keyword", "optional" : false, "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + } + ], + "variadic" : true, + "returnType" : "keyword" + }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." }, { - "name" : "falseValue", + "name" : "trueValue", "type" : "keyword", - "optional" : true, + "optional" : false, "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "keyword", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." } ], "variadic" : true, @@ -172,6 +442,30 @@ "variadic" : true, "returnType" : "long" }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "long", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "long", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." + } + ], + "variadic" : true, + "returnType" : "long" + }, { "params" : [ { @@ -190,6 +484,48 @@ "variadic" : true, "returnType" : "text" }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "text", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "text", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." + } + ], + "variadic" : true, + "returnType" : "text" + }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "unsigned_long", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + } + ], + "variadic" : true, + "returnType" : "unsigned_long" + }, { "params" : [ { @@ -203,6 +539,12 @@ "type" : "unsigned_long", "optional" : false, "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "unsigned_long", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." } ], "variadic" : true, @@ -225,6 +567,30 @@ ], "variadic" : true, "returnType" : "version" + }, + { + "params" : [ + { + "name" : "condition", + "type" : "boolean", + "optional" : false, + "description" : "A condition." + }, + { + "name" : "trueValue", + "type" : "version", + "optional" : false, + "description" : "The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches." + }, + { + "name" : "elseValue", + "type" : "version", + "optional" : true, + "description" : "The value that's returned when no condition evaluates to `true`." + } + ], + "variadic" : true, + "returnType" : "version" } ], "examples" : [ diff --git a/docs/reference/esql/functions/parameters/case.asciidoc b/docs/reference/esql/functions/parameters/case.asciidoc index ee6f7e499b3b3..f12eade4d5780 100644 --- a/docs/reference/esql/functions/parameters/case.asciidoc +++ b/docs/reference/esql/functions/parameters/case.asciidoc @@ -7,3 +7,6 @@ A condition. `trueValue`:: The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches. + +`elseValue`:: +The value that's returned when no condition evaluates to `true`. diff --git a/docs/reference/esql/functions/signature/case.svg b/docs/reference/esql/functions/signature/case.svg index d6fd7da38aca6..0d51a0647627d 100644 --- a/docs/reference/esql/functions/signature/case.svg +++ b/docs/reference/esql/functions/signature/case.svg @@ -1 +1 @@ -CASE(condition,trueValue) \ No newline at end of file +CASE(condition,trueValueelseValue) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/case.asciidoc b/docs/reference/esql/functions/types/case.asciidoc index f6c8cfe9361d1..e8aa3eaf5daae 100644 --- a/docs/reference/esql/functions/types/case.asciidoc +++ b/docs/reference/esql/functions/types/case.asciidoc @@ -4,16 +4,33 @@ [%header.monospaced.styled,format=dsv,separator=|] |=== -condition | trueValue | result -boolean | boolean | boolean -boolean | cartesian_point | cartesian_point -boolean | date | date -boolean | double | double -boolean | geo_point | geo_point -boolean | integer | integer -boolean | ip | ip -boolean | long | long -boolean | text | text -boolean | unsigned_long | unsigned_long -boolean | version | version +condition | trueValue | elseValue | result +boolean | boolean | boolean | boolean +boolean | boolean | | boolean +boolean | cartesian_point | cartesian_point | cartesian_point +boolean | cartesian_point | | cartesian_point +boolean | cartesian_shape | cartesian_shape | cartesian_shape +boolean | cartesian_shape | | cartesian_shape +boolean | date | date | date +boolean | date | | date +boolean | double | double | double +boolean | double | | double +boolean | geo_point | geo_point | geo_point +boolean | geo_point | | geo_point +boolean | geo_shape | geo_shape | geo_shape +boolean | geo_shape | | geo_shape +boolean | integer | integer | integer +boolean | integer | | integer +boolean | ip | ip | ip +boolean | ip | | ip +boolean | keyword | keyword | keyword +boolean | keyword | | keyword +boolean | long | long | long +boolean | long | | long +boolean | text | text | text +boolean | text | | text +boolean | unsigned_long | unsigned_long | unsigned_long +boolean | unsigned_long | | unsigned_long +boolean | version | version | version +boolean | version | | version |=== diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/conditional.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/conditional.csv-spec index d4b45ca37fc2d..996b2b5030d82 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/conditional.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/conditional.csv-spec @@ -94,6 +94,82 @@ M |10 M |10 ; +caseOnMv +required_capability: case_mv + +FROM employees +| WHERE emp_no == 10010 +| EVAL foo = CASE(still_hired, "still", is_rehired, "rehired", "not") +| KEEP still_hired, is_rehired, foo; +warning:Line 3:41: evaluation of [is_rehired] failed, treating result as false. Only first 20 failures recorded. +warning:Line 3:41: java.lang.IllegalArgumentException: CASE expects a single-valued boolean + +still_hired:boolean | is_rehired:boolean | foo:keyword + false | [false, false, true, true] | not +; + +caseOnConstantMvFalseTrue +required_capability: case_mv + +ROW foo = CASE([false, true], "a", "b"); +warning:Line 1:16: evaluation of [[false, true]] failed, treating result as false. Only first 20 failures recorded. +warning:Line 1:16: java.lang.IllegalArgumentException: CASE expects a single-valued boolean + +foo:keyword +b +; + +caseOnConstantMvTrueTrue +required_capability: case_mv + +ROW foo = CASE([true, true], "a", "b"); +warning:Line 1:16: evaluation of [[true, true]] failed, treating result as false. Only first 20 failures recorded. +warning:Line 1:16: java.lang.IllegalArgumentException: CASE expects a single-valued boolean + +foo:keyword +b +; + +caseOnMvSliceMv +required_capability: case_mv + +ROW foo = [true, false, false] | EVAL foo = CASE(MV_SLICE(foo, 0, 1), "a", "b"); +warning:Line 1:50: evaluation of [MV_SLICE(foo, 0, 1)] failed, treating result as false. Only first 20 failures recorded. +warning:Line 1:50: java.lang.IllegalArgumentException: CASE expects a single-valued boolean + +foo:keyword +b +; + +caseOnMvSliceSv +required_capability: case_mv + +ROW foo = [true, false, false] | EVAL foo = CASE(MV_SLICE(foo, 0), "a", "b"); + +foo:keyword +a +; + +caseOnConvertMvSliceMv +required_capability: case_mv + +ROW foo = ["true", "false", "false"] | EVAL foo = CASE(MV_SLICE(foo::BOOLEAN, 0, 1), "a", "b"); +warning:Line 1:56: evaluation of [MV_SLICE(foo::BOOLEAN, 0, 1)] failed, treating result as false. Only first 20 failures recorded. +warning:Line 1:56: java.lang.IllegalArgumentException: CASE expects a single-valued boolean + +foo:keyword +b +; + +caseOnConvertMvSliceSv +required_capability: case_mv + +ROW foo = ["true", "false", "false"] | EVAL foo = CASE(MV_SLICE(foo::BOOLEAN, 0), "a", "b"); + +foo:keyword +a +; + docsCaseSuccessRate // tag::docsCaseSuccessRate[] FROM sample_data diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec index cd3ecfc367ddd..bc90f7f616631 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec @@ -11,7 +11,7 @@ synopsis:keyword "double avg(number:double|integer|long)" "double|date bin(field:integer|long|double|date, buckets:integer|long|double|date_period|time_duration, ?from:integer|long|double|date|keyword|text, ?to:integer|long|double|date|keyword|text)" "double|date bucket(field:integer|long|double|date, buckets:integer|long|double|date_period|time_duration, ?from:integer|long|double|date|keyword|text, ?to:integer|long|double|date|keyword|text)" -"boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version case(condition:boolean, trueValue...:boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version)" +"boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version case(condition:boolean, trueValue...:boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version)" "double cbrt(number:double|integer|long|unsigned_long)" "double|integer|long|unsigned_long ceil(number:double|integer|long|unsigned_long)" "boolean cidr_match(ip:ip, blockX...:keyword|text)" @@ -137,7 +137,7 @@ atan2 |[y_coordinate, x_coordinate] |["double|integer|long|unsign avg |number |"double|integer|long" |[""] bin |[field, buckets, from, to] |["integer|long|double|date", "integer|long|double|date_period|time_duration", "integer|long|double|date|keyword|text", "integer|long|double|date|keyword|text"] |[Numeric or date expression from which to derive buckets., Target number of buckets\, or desired bucket size if `from` and `to` parameters are omitted., Start of the range. Can be a number\, a date or a date expressed as a string., End of the range. Can be a number\, a date or a date expressed as a string.] bucket |[field, buckets, from, to] |["integer|long|double|date", "integer|long|double|date_period|time_duration", "integer|long|double|date|keyword|text", "integer|long|double|date|keyword|text"] |[Numeric or date expression from which to derive buckets., Target number of buckets\, or desired bucket size if `from` and `to` parameters are omitted., Start of the range. Can be a number\, a date or a date expressed as a string., End of the range. Can be a number\, a date or a date expressed as a string.] -case |[condition, trueValue] |[boolean, "boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version"] |[A condition., The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches.] +case |[condition, trueValue] |[boolean, "boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version"] |[A condition., The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches.] cbrt |number |"double|integer|long|unsigned_long" |"Numeric expression. If `null`, the function returns `null`." ceil |number |"double|integer|long|unsigned_long" |Numeric expression. If `null`, the function returns `null`. cidr_match |[ip, blockX] |[ip, "keyword|text"] |[IP address of type `ip` (both IPv4 and IPv6 are supported)., CIDR block to test the IP against.] @@ -391,7 +391,7 @@ atan2 |double avg |double |false |false |true bin |"double|date" |[false, false, true, true] |false |false bucket |"double|date" |[false, false, true, true] |false |false -case |"boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version" |[false, false] |true |false +case |"boolean|cartesian_point|cartesian_shape|date|date_nanos|double|geo_point|geo_shape|integer|ip|keyword|long|text|unsigned_long|version" |[false, false] |true |false cbrt |double |false |false |false ceil |"double|integer|long|unsigned_long" |false |false |false cidr_match |boolean |[false, false] |true |false diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 6e8d64edb6c86..858e2a3332bf8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -97,6 +97,11 @@ public enum Cap { */ AGG_TOP_IP_SUPPORT, + /** + * {@code CASE} properly handling multivalue conditions. + */ + CASE_MV, + /** * Optimization for ST_CENTROID changed some results in cartesian data. #108713 */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Warnings.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Warnings.java index 630cf62d0030a..87809ba536879 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Warnings.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/Warnings.java @@ -32,30 +32,53 @@ public void registerException(Exception exception) { }; /** - * Create a new warnings object based on the given mode + * Create a new warnings object based on the given mode which warns that + * it treats the result as {@code null}. * @param warningsMode The warnings collection strategy to use - * @param source used to indicate where in the query the warning occured + * @param source used to indicate where in the query the warning occurred * @return A warnings collector object */ public static Warnings createWarnings(DriverContext.WarningsMode warningsMode, Source source) { - switch (warningsMode) { - case COLLECT -> { - return new Warnings(source); - } - case IGNORE -> { - return NOOP_WARNINGS; - } - } - throw new IllegalStateException("Unreachable"); + return createWarnings(warningsMode, source, "treating result as null"); + } + + /** + * Create a new warnings object based on the given mode which warns that + * it treats the result as {@code false}. + * @param warningsMode The warnings collection strategy to use + * @param source used to indicate where in the query the warning occurred + * @return A warnings collector object + */ + public static Warnings createWarningsTreatedAsFalse(DriverContext.WarningsMode warningsMode, Source source) { + return createWarnings(warningsMode, source, "treating result as false"); + } + + /** + * Create a new warnings object based on the given mode + * @param warningsMode The warnings collection strategy to use + * @param source used to indicate where in the query the warning occurred + * @param first warning message attached to the first result + * @return A warnings collector object + */ + public static Warnings createWarnings(DriverContext.WarningsMode warningsMode, Source source, String first) { + return switch (warningsMode) { + case COLLECT -> new Warnings(source, first); + case IGNORE -> NOOP_WARNINGS; + }; } public Warnings(Source source) { + this(source, "treating result as null"); + } + + public Warnings(Source source, String first) { location = format("Line {}:{}: ", source.source().getLineNumber(), source.source().getColumnNumber()); - first = format( + this.first = format( null, - "{}evaluation of [{}] failed, treating result as null. Only first {} failures recorded.", + "{}evaluation of [{}] failed, {}. Only first {} failures recorded.", location, source.text(), + first, MAX_ADDED_WARNINGS ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java index 3239afabf6a24..979f681a7fbd0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; @@ -29,6 +30,7 @@ import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.Warnings; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.planner.PlannerUtils; @@ -46,7 +48,11 @@ public final class Case extends EsqlScalarFunction { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Case", Case::new); - record Condition(Expression condition, Expression value) {} + record Condition(Expression condition, Expression value) { + ConditionEvaluatorSupplier toEvaluator(Function toEvaluator) { + return new ConditionEvaluatorSupplier(condition.source(), toEvaluator.apply(condition), toEvaluator.apply(value)); + } + } private final List conditions; private final Expression elseValue; @@ -56,9 +62,12 @@ record Condition(Expression condition, Expression value) {} returnType = { "boolean", "cartesian_point", + "cartesian_shape", "date", + "date_nanos", "double", "geo_point", + "geo_shape", "integer", "ip", "keyword", @@ -94,9 +103,12 @@ public Case( type = { "boolean", "cartesian_point", + "cartesian_shape", "date", + "date_nanos", "double", "geo_point", + "geo_shape", "integer", "ip", "keyword", @@ -215,25 +227,26 @@ public boolean foldable() { if (condition.condition.foldable() == false) { return false; } - Boolean b = (Boolean) condition.condition.fold(); - if (b != null && b) { + if (Boolean.TRUE.equals(condition.condition.fold())) { + /* + * `fold` can make four things here: + * 1. `TRUE` + * 2. `FALSE` + * 3. null + * 4. A list with more than one `TRUE` or `FALSE` in it. + * + * In the first case, we're foldable if the condition is foldable. + * The multivalued field will make a warning, but eventually + * become null. And null will become false. So cases 2-4 are + * the same. In those cases we are foldable only if the *rest* + * of the condition is foldable. + */ return condition.value.foldable(); } } return elseValue.foldable(); } - @Override - public Object fold() { - for (Condition condition : conditions) { - Boolean b = (Boolean) condition.condition.fold(); - if (b != null && b) { - return condition.value.fold(); - } - } - return elseValue.fold(); - } - /** * Fold the arms of {@code CASE} statements. *
    @@ -261,8 +274,20 @@ public Expression partiallyFold() { continue; } modified = true; - Boolean b = (Boolean) condition.condition.fold(); - if (b != null && b) { + if (Boolean.TRUE.equals(condition.condition.fold())) { + /* + * `fold` can make four things here: + * 1. `TRUE` + * 2. `FALSE` + * 3. null + * 4. A list with more than one `TRUE` or `FALSE` in it. + * + * In the first case, we fold to the value of the condition. + * The multivalued field will make a warning, but eventually + * become null. And null will become false. So cases 2-4 are + * the same. In those cases we fold the entire condition + * away, returning just what ever's remaining in the CASE. + */ newChildren.add(condition.value); return finishPartialFold(newChildren); } @@ -277,24 +302,23 @@ public Expression partiallyFold() { } private Expression finishPartialFold(List newChildren) { - if (newChildren.size() == 1) { - return newChildren.get(0); - } - return replaceChildren(newChildren); + return switch (newChildren.size()) { + case 0 -> new Literal(source(), null, dataType()); + case 1 -> newChildren.get(0); + default -> replaceChildren(newChildren); + }; } @Override public ExpressionEvaluator.Factory toEvaluator(Function toEvaluator) { ElementType resultType = PlannerUtils.toElementType(dataType()); - List conditionsFactories = conditions.stream() - .map(c -> new ConditionEvaluatorSupplier(toEvaluator.apply(c.condition), toEvaluator.apply(c.value))) - .toList(); + List conditionsFactories = conditions.stream().map(c -> c.toEvaluator(toEvaluator)).toList(); ExpressionEvaluator.Factory elseValueFactory = toEvaluator.apply(elseValue); return new ExpressionEvaluator.Factory() { @Override public ExpressionEvaluator get(DriverContext context) { return new CaseEvaluator( - context, + context.blockFactory(), resultType, conditionsFactories.stream().map(x -> x.apply(context)).toList(), elseValueFactory.get(context) @@ -303,40 +327,58 @@ public ExpressionEvaluator get(DriverContext context) { @Override public String toString() { - return "CaseEvaluator[resultType=" - + resultType - + ", conditions=" - + conditionsFactories - + ", elseVal=" - + elseValueFactory - + ']'; + return "CaseEvaluator[conditions=" + conditionsFactories + ", elseVal=" + elseValueFactory + ']'; } }; } - record ConditionEvaluatorSupplier(ExpressionEvaluator.Factory condition, ExpressionEvaluator.Factory value) + record ConditionEvaluatorSupplier(Source conditionSource, ExpressionEvaluator.Factory condition, ExpressionEvaluator.Factory value) implements Function { @Override public ConditionEvaluator apply(DriverContext driverContext) { - return new ConditionEvaluator(condition.get(driverContext), value.get(driverContext)); + return new ConditionEvaluator( + /* + * We treat failures as null just like any other failure. + * It's just that we then *immediately* convert it to + * true or false using the tri-valued boolean logic stuff. + * And that makes it into false. This is, *exactly* what + * happens in PostgreSQL and MySQL and SQLite: + * > SELECT CASE WHEN null THEN 1 ELSE 2 END; + * 2 + * Rather than go into depth about this in the warning message, + * we just say "false". + */ + Warnings.createWarningsTreatedAsFalse(driverContext.warningsMode(), conditionSource), + condition.get(driverContext), + value.get(driverContext) + ); } @Override public String toString() { - return "ConditionEvaluator[" + "condition=" + condition + ", value=" + value + ']'; + return "ConditionEvaluator[condition=" + condition + ", value=" + value + ']'; } } - record ConditionEvaluator(EvalOperator.ExpressionEvaluator condition, EvalOperator.ExpressionEvaluator value) implements Releasable { + record ConditionEvaluator( + Warnings conditionWarnings, + EvalOperator.ExpressionEvaluator condition, + EvalOperator.ExpressionEvaluator value + ) implements Releasable { @Override public void close() { Releasables.closeExpectNoException(condition, value); } + + @Override + public String toString() { + return "ConditionEvaluator[condition=" + condition + ", value=" + value + ']'; + } } private record CaseEvaluator( - DriverContext driverContext, + BlockFactory blockFactory, ElementType resultType, List conditions, EvalOperator.ExpressionEvaluator elseVal @@ -353,10 +395,11 @@ public Block eval(Page page) { * a time - but it's not at all fast. */ int positionCount = page.getPositionCount(); - try (Block.Builder result = resultType.newBlockBuilder(positionCount, driverContext.blockFactory())) { + try (Block.Builder result = resultType.newBlockBuilder(positionCount, blockFactory)) { position: for (int p = 0; p < positionCount; p++) { int[] positions = new int[] { p }; Page limited = new Page( + 1, IntStream.range(0, page.getBlockCount()).mapToObj(b -> page.getBlock(b).filter(positions)).toArray(Block[]::new) ); try (Releasable ignored = limited::releaseBlocks) { @@ -365,6 +408,12 @@ public Block eval(Page page) { if (b.isNull(0)) { continue; } + if (b.getValueCount(0) > 1) { + condition.conditionWarnings.registerException( + new IllegalArgumentException("CASE expects a single-valued boolean") + ); + continue; + } if (false == b.getBoolean(b.getFirstValueIndex(0))) { continue; } @@ -390,7 +439,7 @@ public void close() { @Override public String toString() { - return "CaseEvaluator[resultType=" + resultType + ", conditions=" + conditions + ", elseVal=" + elseVal + ']'; + return "CaseEvaluator[conditions=" + conditions + ", elseVal=" + elseVal + ']'; } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java index 4e26baddd013b..54db9afa291ad 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java @@ -105,8 +105,10 @@ protected static List withNoRowsExpectingNull(List anyNullIsNull( expectedType.expectedType(finalNullPosition, nulledData.type(), oc), nullValue(), null, + null, oc.getExpectedTypeError(), null, + null, null ); })); @@ -246,8 +248,10 @@ protected static List anyNullIsNull( expectedType.expectedType(finalNullPosition, DataType.NULL, oc), nullValue(), null, + null, oc.getExpectedTypeError(), null, + null, null ); })); @@ -642,9 +646,11 @@ protected static List randomizeBytesRefsOffset(List args = description.args(); @@ -707,7 +711,7 @@ public static void testFunctionInfo() { ); List> typesFromSignature = new ArrayList<>(); - Set returnFromSignature = new HashSet<>(); + Set returnFromSignature = new TreeSet<>(); for (int i = 0; i < args.size(); i++) { typesFromSignature.add(new HashSet<>()); } @@ -828,6 +832,28 @@ public static void renderDocs() throws IOException { FunctionDefinition definition = definition(name); if (definition != null) { EsqlFunctionRegistry.FunctionDescription description = EsqlFunctionRegistry.description(definition); + if (name.equals("case")) { + /* + * Hack the description, so we render a proper one for case. + */ + // TODO build the description properly *somehow* + EsqlFunctionRegistry.ArgSignature trueValue = description.args().get(1); + EsqlFunctionRegistry.ArgSignature falseValue = new EsqlFunctionRegistry.ArgSignature( + "elseValue", + trueValue.type(), + "The value that's returned when no condition evaluates to `true`.", + true, + EsqlFunctionRegistry.getTargetType(trueValue.type()) + ); + description = new EsqlFunctionRegistry.FunctionDescription( + description.name(), + List.of(description.args().get(0), trueValue, falseValue), + description.returnType(), + description.description(), + description.variadic(), + description.isAggregation() + ); + } renderTypes(description.argNames()); renderParametersList(description.argNames(), description.argDescriptions()); FunctionInfo info = EsqlFunctionRegistry.functionInfo(definition); @@ -836,22 +862,7 @@ public static void renderDocs() throws IOException { boolean hasAppendix = renderAppendix(info.appendix()); renderFullLayout(name, info.preview(), hasExamples, hasAppendix); renderKibanaInlineDocs(name, info); - List args = description.args(); - if (name.equals("case")) { - EsqlFunctionRegistry.ArgSignature falseValue = args.get(1); - args = List.of( - args.get(0), - falseValue, - new EsqlFunctionRegistry.ArgSignature( - "falseValue", - falseValue.type(), - falseValue.description(), - true, - EsqlFunctionRegistry.getTargetType(falseValue.type()) - ) - ); - } - renderKibanaFunctionDefinition(name, info, args, description.variadic()); + renderKibanaFunctionDefinition(name, info, description.args(), description.variadic()); return; } LogManager.getLogger(getTestClass()).info("Skipping rendering types because the function '" + name + "' isn't registered"); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java index fed81d4260bcd..85db73901352b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java @@ -38,7 +38,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; import static org.hamcrest.Matchers.either; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -120,6 +119,9 @@ public final void testEvaluate() { Object result; try (ExpressionEvaluator evaluator = evaluator(expression).get(driverContext())) { + if (testCase.getExpectedBuildEvaluatorWarnings() != null) { + assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); + } try (Block block = evaluator.eval(row(testCase.getDataValues()))) { assertThat(block.getPositionCount(), is(1)); result = toJavaObjectUnsignedLongAware(block, 0); @@ -177,6 +179,10 @@ public final void testEvaluateBlockWithNulls() { */ public final void testCrankyEvaluateBlockWithoutNulls() { assumeTrue("sometimes the cranky breaker silences warnings, just skip these cases", testCase.getExpectedWarnings() == null); + assumeTrue( + "sometimes the cranky breaker silences warnings, just skip these cases", + testCase.getExpectedBuildEvaluatorWarnings() == null + ); try { testEvaluateBlock(driverContext().blockFactory(), crankyContext(), false); } catch (CircuitBreakingException ex) { @@ -190,6 +196,10 @@ public final void testCrankyEvaluateBlockWithoutNulls() { */ public final void testCrankyEvaluateBlockWithNulls() { assumeTrue("sometimes the cranky breaker silences warnings, just skip these cases", testCase.getExpectedWarnings() == null); + assumeTrue( + "sometimes the cranky breaker silences warnings, just skip these cases", + testCase.getExpectedBuildEvaluatorWarnings() == null + ); try { testEvaluateBlock(driverContext().blockFactory(), crankyContext(), true); } catch (CircuitBreakingException ex) { @@ -242,10 +252,13 @@ private void testEvaluateBlock(BlockFactory inputBlockFactory, DriverContext con ExpressionEvaluator eval = evaluator(expression).get(context); Block block = eval.eval(new Page(positions, manyPositionsBlocks)) ) { + if (testCase.getExpectedBuildEvaluatorWarnings() != null) { + assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); + } assertThat(block.getPositionCount(), is(positions)); for (int p = 0; p < positions; p++) { if (nullPositions.contains(p)) { - assertThat(toJavaObject(block, p), allNullsMatcher()); + assertThat(toJavaObjectUnsignedLongAware(block, p), allNullsMatcher()); continue; } assertThat(toJavaObjectUnsignedLongAware(block, p), testCase.getMatcher()); @@ -275,6 +288,9 @@ public final void testEvaluateInManyThreads() throws ExecutionException, Interru int count = 10_000; int threads = 5; var evalSupplier = evaluator(expression); + if (testCase.getExpectedBuildEvaluatorWarnings() != null) { + assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); + } ExecutorService exec = Executors.newFixedThreadPool(threads); try { List> futures = new ArrayList<>(); @@ -310,6 +326,9 @@ public final void testEvaluatorToString() { assumeTrue("Can't build evaluator", testCase.canBuildEvaluator()); var factory = evaluator(expression); try (ExpressionEvaluator ev = factory.get(driverContext())) { + if (testCase.getExpectedBuildEvaluatorWarnings() != null) { + assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); + } assertThat(ev.toString(), testCase.evaluatorToString()); } } @@ -322,6 +341,9 @@ public final void testFactoryToString() { } assumeTrue("Can't build evaluator", testCase.canBuildEvaluator()); var factory = evaluator(buildFieldExpression(testCase)); + if (testCase.getExpectedBuildEvaluatorWarnings() != null) { + assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); + } assertThat(factory.toString(), testCase.evaluatorToString()); } @@ -342,6 +364,9 @@ public final void testFold() { result = NumericUtils.unsignedLongAsBigInteger((Long) result); } assertThat(result, testCase.getMatcher()); + if (testCase.getExpectedBuildEvaluatorWarnings() != null) { + assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); + } if (testCase.getExpectedWarnings() != null) { assertWarnings(testCase.getExpectedWarnings()); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/RailRoadDiagram.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/RailRoadDiagram.java index 4e00fa9f41fbd..df0737feadd8d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/RailRoadDiagram.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/RailRoadDiagram.java @@ -49,18 +49,28 @@ static String functionSignature(FunctionDefinition definition) throws IOExceptio List expressions = new ArrayList<>(); expressions.add(new SpecialSequence(definition.name().toUpperCase(Locale.ROOT))); expressions.add(new Syntax("(")); - boolean first = true; - List args = EsqlFunctionRegistry.description(definition).argNames(); - for (String arg : args) { - if (arg.endsWith("...")) { - expressions.add(new Repetition(new Sequence(new Syntax(","), new Literal(arg.substring(0, arg.length() - 3))), 0, null)); - } else { - if (first) { - first = false; + + if (definition.name().equals("case")) { + // CASE is so weird let's just hack this together manually + Sequence seq = new Sequence(new Literal("condition"), new Syntax(","), new Literal("trueValue")); + expressions.add(new Repetition(seq, 1, null)); + expressions.add(new Repetition(new Literal("elseValue"), 0, 1)); + } else { + boolean first = true; + List args = EsqlFunctionRegistry.description(definition).argNames(); + for (String arg : args) { + if (arg.endsWith("...")) { + expressions.add( + new Repetition(new Sequence(new Syntax(","), new Literal(arg.substring(0, arg.length() - 3))), 0, null) + ); } else { - expressions.add(new Syntax(",")); + if (first) { + first = false; + } else { + expressions.add(new Syntax(",")); + } + expressions.add(new Literal(arg)); } - expressions.add(new Literal(arg)); } } expressions.add(new Syntax(")")); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index a1caa784c9787..e44ea907518b4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -67,6 +67,21 @@ public static String nameFromTypes(List types) { return types.stream().map(t -> "<" + t.typeName() + ">").collect(Collectors.joining(", ")); } + /** + * Build a name for the test case based on objects likely to describe it. + */ + public static String nameFrom(List paramDescriptors) { + return paramDescriptors.stream().map(p -> { + if (p == null) { + return "null"; + } + if (p instanceof DataType t) { + return "<" + t.typeName() + ">"; + } + return p.toString(); + }).collect(Collectors.joining(", ")); + } + public static List stringCases( BinaryOperator expected, BiFunction evaluatorToString, @@ -1305,7 +1320,7 @@ public static String castToDoubleEvaluator(String original, DataType current) { throw new UnsupportedOperationException(); } - public static class TestCase { + public static final class TestCase { /** * The {@link Source} this test case should be run with */ @@ -1333,22 +1348,34 @@ public static class TestCase { */ private final String[] expectedWarnings; + /** + * Warnings that are added by calling {@link AbstractFunctionTestCase#evaluator} + * or {@link Expression#fold()} on the expression built by this. + */ + private final String[] expectedBuildEvaluatorWarnings; + private final String expectedTypeError; private final boolean canBuildEvaluator; private final Class foldingExceptionClass; private final String foldingExceptionMessage; + /** + * Extra data embedded in the test case. Test subclasses can cast + * as needed and extra whatever helps them. + */ + private final Object extra; + public TestCase(List data, String evaluatorToString, DataType expectedType, Matcher matcher) { this(data, equalTo(evaluatorToString), expectedType, matcher); } public TestCase(List data, Matcher evaluatorToString, DataType expectedType, Matcher matcher) { - this(data, evaluatorToString, expectedType, matcher, null, null, null, null); + this(data, evaluatorToString, expectedType, matcher, null, null, null, null, null, null); } public static TestCase typeError(List data, String expectedTypeError) { - return new TestCase(data, null, null, null, null, expectedTypeError, null, null); + return new TestCase(data, null, null, null, null, null, expectedTypeError, null, null, null); } TestCase( @@ -1357,9 +1384,11 @@ public static TestCase typeError(List data, String expectedTypeError) DataType expectedType, Matcher matcher, String[] expectedWarnings, + String[] expectedBuildEvaluatorWarnings, String expectedTypeError, Class foldingExceptionClass, - String foldingExceptionMessage + String foldingExceptionMessage, + Object extra ) { this.source = Source.EMPTY; this.data = data; @@ -1369,10 +1398,12 @@ public static TestCase typeError(List data, String expectedTypeError) Matcher downcast = (Matcher) matcher; this.matcher = downcast; this.expectedWarnings = expectedWarnings; + this.expectedBuildEvaluatorWarnings = expectedBuildEvaluatorWarnings; this.expectedTypeError = expectedTypeError; this.canBuildEvaluator = data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type)); this.foldingExceptionClass = foldingExceptionClass; this.foldingExceptionMessage = foldingExceptionMessage; + this.extra = extra; } public Source getSource() { @@ -1419,6 +1450,14 @@ public String[] getExpectedWarnings() { return expectedWarnings; } + /** + * Warnings that are added by calling {@link AbstractFunctionTestCase#evaluator} + * or {@link Expression#fold()} on the expression built by this. + */ + public String[] getExpectedBuildEvaluatorWarnings() { + return expectedBuildEvaluatorWarnings; + } + public Class foldingExceptionClass() { return foldingExceptionClass; } @@ -1431,28 +1470,88 @@ public String getExpectedTypeError() { return expectedTypeError; } + /** + * Extra data embedded in the test case. Test subclasses can cast + * as needed and extra whatever helps them. + */ + public Object extra() { + return extra; + } + + /** + * Build a new {@link TestCase} with new {@link #extra()}. + */ + public TestCase withExtra(Object extra) { + return new TestCase( + data, + evaluatorToString, + expectedType, + matcher, + expectedWarnings, + expectedBuildEvaluatorWarnings, + expectedTypeError, + foldingExceptionClass, + foldingExceptionMessage, + extra + ); + } + public TestCase withWarning(String warning) { - String[] newWarnings; - if (expectedWarnings != null) { - newWarnings = Arrays.copyOf(expectedWarnings, expectedWarnings.length + 1); - newWarnings[expectedWarnings.length] = warning; - } else { - newWarnings = new String[] { warning }; - } return new TestCase( data, evaluatorToString, expectedType, matcher, - newWarnings, + addWarning(expectedWarnings, warning), + expectedBuildEvaluatorWarnings, + expectedTypeError, + foldingExceptionClass, + foldingExceptionMessage, + extra + ); + } + + /** + * Warnings that are added by calling {@link AbstractFunctionTestCase#evaluator} + * or {@link Expression#fold()} on the expression built by this. + */ + public TestCase withBuildEvaluatorWarning(String warning) { + return new TestCase( + data, + evaluatorToString, + expectedType, + matcher, + expectedWarnings, + addWarning(expectedBuildEvaluatorWarnings, warning), expectedTypeError, foldingExceptionClass, - foldingExceptionMessage + foldingExceptionMessage, + extra ); } + private String[] addWarning(String[] warnings, String warning) { + if (warnings == null) { + return new String[] { warning }; + } + String[] newWarnings = Arrays.copyOf(warnings, warnings.length + 1); + newWarnings[warnings.length] = warning; + return newWarnings; + } + public TestCase withFoldingException(Class clazz, String message) { - return new TestCase(data, evaluatorToString, expectedType, matcher, expectedWarnings, expectedTypeError, clazz, message); + return new TestCase( + data, + evaluatorToString, + expectedType, + matcher, + expectedWarnings, + expectedBuildEvaluatorWarnings, + expectedTypeError, + clazz, + message, + extra + ); } public DataType expectedType() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java index f2c4625f5a3cb..de84086e3cb4e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java @@ -72,6 +72,19 @@ public void testPartialFoldDropsFirstFalse() { ); } + public void testPartialFoldMv() { + Case c = new Case( + Source.synthetic("case"), + new Literal(Source.EMPTY, List.of(true, true), DataType.BOOLEAN), + List.of(field("first", DataType.LONG), field("last_cond", DataType.BOOLEAN), field("last", DataType.LONG)) + ); + assertThat(c.foldable(), equalTo(false)); + assertThat( + c.partiallyFold(), + equalTo(new Case(Source.synthetic("case"), field("last_cond", DataType.BOOLEAN), List.of(field("last", DataType.LONG)))) + ); + } + public void testPartialFoldNoop() { Case c = new Case( Source.synthetic("case"), diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java index 97515db85e8c3..7b26ac8c099dc 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java @@ -10,22 +10,48 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; -import org.apache.lucene.util.BytesRef; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.NumericUtils; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; -import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Locale; +import java.util.function.Function; import java.util.function.Supplier; import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; +import static org.hamcrest.Matchers.startsWith; public class CaseTests extends AbstractScalarFunctionTestCase { + private static final List TYPES = List.of( + DataType.KEYWORD, + DataType.TEXT, + DataType.BOOLEAN, + DataType.DATETIME, + DataType.DATE_NANOS, + DataType.DOUBLE, + DataType.INTEGER, + DataType.LONG, + DataType.UNSIGNED_LONG, + DataType.IP, + DataType.VERSION, + DataType.CARTESIAN_POINT, + DataType.GEO_POINT, + DataType.CARTESIAN_SHAPE, + DataType.GEO_SHAPE, + DataType.NULL + ); public CaseTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); @@ -36,168 +62,755 @@ public CaseTests(@Name("TestCase") Supplier testCaseS */ @ParametersFactory public static Iterable parameters() { - // TODO this needs lots of stuff flipped to parameters - return parameterSuppliersFromTypedData( - List.of(new TestCaseSupplier("keyword", List.of(DataType.BOOLEAN, DataType.KEYWORD, DataType.KEYWORD), () -> { - List typedData = List.of( - new TestCaseSupplier.TypedData(true, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(new BytesRef("a"), DataType.KEYWORD, "a"), - new TestCaseSupplier.TypedData(new BytesRef("b"), DataType.KEYWORD, "b") - ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=BYTES_REF, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=Attribute[channel=2]]", - DataType.KEYWORD, - equalTo(new BytesRef("a")) - ); - }), new TestCaseSupplier("text", List.of(DataType.BOOLEAN, DataType.TEXT), () -> { - List typedData = List.of( - new TestCaseSupplier.TypedData(false, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(new BytesRef("a"), DataType.TEXT, "trueValue") - ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=BYTES_REF, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.TEXT, - nullValue() - ); - }), new TestCaseSupplier("boolean", List.of(DataType.BOOLEAN, DataType.BOOLEAN), () -> { - List typedData = List.of( - new TestCaseSupplier.TypedData(false, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(false, DataType.BOOLEAN, "trueValue") - ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=BOOLEAN, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.BOOLEAN, - nullValue() - ); - }), new TestCaseSupplier("date", List.of(DataType.BOOLEAN, DataType.DATETIME), () -> { - long value = randomNonNegativeLong(); - List typedData = List.of( - new TestCaseSupplier.TypedData(true, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(value, DataType.DATETIME, "trueValue") - ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=LONG, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.DATETIME, - equalTo(value) - ); - }), new TestCaseSupplier("double", List.of(DataType.BOOLEAN, DataType.DOUBLE), () -> { - double value = randomDouble(); - List typedData = List.of( - new TestCaseSupplier.TypedData(true, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(value, DataType.DOUBLE, "trueValue") - ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=DOUBLE, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.DOUBLE, - equalTo(value) - ); - }), new TestCaseSupplier("integer", List.of(DataType.BOOLEAN, DataType.INTEGER), () -> { - int value = randomInt(); + List suppliers = new ArrayList<>(); + for (DataType type : TYPES) { + twoAndThreeArgs(suppliers, true, true, type, List.of()); + twoAndThreeArgs(suppliers, false, false, type, List.of()); + twoAndThreeArgs(suppliers, null, false, type, List.of()); + twoAndThreeArgs( + suppliers, + randomMultivaluedCondition(), + false, + type, + List.of( + "Line -1:-1: evaluation of [cond] failed, treating result as false. Only first 20 failures recorded.", + "Line -1:-1: java.lang.IllegalArgumentException: CASE expects a single-valued boolean" + ) + ); + } + suppliers = errorsForCasesWithoutExamples( + suppliers, + (includeOrdinal, validPerPosition, types) -> typeErrorMessage(includeOrdinal, types) + ); + + for (DataType type : TYPES) { + fourAndFiveArgs(suppliers, true, randomSingleValuedCondition(), 0, type, List.of()); + fourAndFiveArgs(suppliers, false, true, 1, type, List.of()); + fourAndFiveArgs(suppliers, false, false, 2, type, List.of()); + fourAndFiveArgs(suppliers, null, true, 1, type, List.of()); + fourAndFiveArgs(suppliers, null, false, 2, type, List.of()); + fourAndFiveArgs( + suppliers, + randomMultivaluedCondition(), + true, + 1, + type, + List.of( + "Line -1:-1: evaluation of [cond1] failed, treating result as false. Only first 20 failures recorded.", + "Line -1:-1: java.lang.IllegalArgumentException: CASE expects a single-valued boolean" + ) + ); + fourAndFiveArgs( + suppliers, + false, + randomMultivaluedCondition(), + 2, + type, + List.of( + "Line -1:-1: evaluation of [cond2] failed, treating result as false. Only first 20 failures recorded.", + "Line -1:-1: java.lang.IllegalArgumentException: CASE expects a single-valued boolean" + ) + ); + } + return + + parameterSuppliersFromTypedData(suppliers); + } + + private static void twoAndThreeArgs( + List suppliers, + Object cond, + boolean lhsOrRhs, + DataType type, + List warnings + ) { + suppliers.add(new TestCaseSupplier(TestCaseSupplier.nameFrom(Arrays.asList(cond, type)), List.of(DataType.BOOLEAN, type), () -> { + Object lhs = randomLiteral(type).value(); + List typedData = List.of(cond(cond, "cond"), new TestCaseSupplier.TypedData(lhs, type, "lhs")); + return testCase(type, typedData, lhsOrRhs ? lhs : null, toStringMatcher(1, true), false, null, addWarnings(warnings)); + })); + suppliers.add( + new TestCaseSupplier(TestCaseSupplier.nameFrom(Arrays.asList(cond, type, type)), List.of(DataType.BOOLEAN, type, type), () -> { + Object lhs = randomLiteral(type).value(); + Object rhs = randomLiteral(type).value(); List typedData = List.of( - new TestCaseSupplier.TypedData(false, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(value, DataType.INTEGER, "trueValue") - ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=INT, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.INTEGER, - nullValue() + cond(cond, "cond"), + new TestCaseSupplier.TypedData(lhs, type, "lhs"), + new TestCaseSupplier.TypedData(rhs, type, "rhs") ); - }), new TestCaseSupplier("long", List.of(DataType.BOOLEAN, DataType.LONG), () -> { - long value = randomLong(); - List typedData = List.of( - new TestCaseSupplier.TypedData(false, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(value, DataType.LONG, "trueValue") + return testCase(type, typedData, lhsOrRhs ? lhs : rhs, toStringMatcher(1, false), false, null, addWarnings(warnings)); + }) + ); + if (lhsOrRhs) { + suppliers.add( + new TestCaseSupplier( + "foldable " + TestCaseSupplier.nameFrom(Arrays.asList(cond, type, type)), + List.of(DataType.BOOLEAN, type, type), + () -> { + Object lhs = randomLiteral(type).value(); + Object rhs = randomLiteral(type).value(); + List typedData = List.of( + cond(cond, "cond").forceLiteral(), + new TestCaseSupplier.TypedData(lhs, type, "lhs").forceLiteral(), + new TestCaseSupplier.TypedData(rhs, type, "rhs") + ); + return testCase( + type, + typedData, + lhs, + startsWith("LiteralsEvaluator[lit="), + true, + null, + addBuildEvaluatorWarnings(warnings) + ); + } + ) + ); + suppliers.add( + new TestCaseSupplier( + "partial foldable " + TestCaseSupplier.nameFrom(Arrays.asList(cond, type)), + List.of(DataType.BOOLEAN, type), + () -> { + Object lhs = randomLiteral(type).value(); + List typedData = List.of( + cond(cond, "cond").forceLiteral(), + new TestCaseSupplier.TypedData(lhs, type, "lhs") + ); + return testCase( + type, + typedData, + lhs, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator"), + false, + List.of(typedData.get(1)), + addBuildEvaluatorWarnings(warnings) + ); + } + ) + ); + } else { + suppliers.add( + new TestCaseSupplier( + "foldable " + TestCaseSupplier.nameFrom(Arrays.asList(cond, type)), + List.of(DataType.BOOLEAN, type), + () -> { + Object lhs = randomLiteral(type).value(); + List typedData = List.of( + cond(cond, "cond").forceLiteral(), + new TestCaseSupplier.TypedData(lhs, type, "lhs") + ); + return testCase( + type, + typedData, + null, + startsWith("LiteralsEvaluator[lit="), + true, + List.of(new TestCaseSupplier.TypedData(null, type, "null").forceLiteral()), + addBuildEvaluatorWarnings(warnings) + ); + } + ) + ); + } + + suppliers.add( + new TestCaseSupplier( + "partial foldable " + TestCaseSupplier.nameFrom(Arrays.asList(cond, type, type)), + List.of(DataType.BOOLEAN, type, type), + () -> { + Object lhs = randomLiteral(type).value(); + Object rhs = randomLiteral(type).value(); + List typedData = List.of( + cond(cond, "cond").forceLiteral(), + new TestCaseSupplier.TypedData(lhs, type, "lhs"), + new TestCaseSupplier.TypedData(rhs, type, "rhs") + ); + return testCase( + type, + typedData, + lhsOrRhs ? lhs : rhs, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator"), + false, + List.of(typedData.get(lhsOrRhs ? 1 : 2)), + addWarnings(warnings) + ); + } + ) + ); + + // Fill in some cases with null conditions or null values + if (cond == null) { + suppliers.add( + new TestCaseSupplier(TestCaseSupplier.nameFrom(Arrays.asList(DataType.NULL, type)), List.of(DataType.NULL, type), () -> { + Object lhs = randomLiteral(type).value(); + List typedData = List.of( + new TestCaseSupplier.TypedData(null, DataType.NULL, "cond"), + new TestCaseSupplier.TypedData(lhs, type, "lhs") + ); + return testCase( + type, + typedData, + lhsOrRhs ? lhs : null, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition="), + false, + null, + addWarnings(warnings) + ); + }) + ); + suppliers.add( + new TestCaseSupplier( + TestCaseSupplier.nameFrom(Arrays.asList(DataType.NULL, type, type)), + List.of(DataType.NULL, type, type), + () -> { + Object lhs = randomLiteral(type).value(); + Object rhs = randomLiteral(type).value(); + List typedData = List.of( + new TestCaseSupplier.TypedData(null, DataType.NULL, "cond"), + new TestCaseSupplier.TypedData(lhs, type, "lhs"), + new TestCaseSupplier.TypedData(rhs, type, "rhs") + ); + return testCase( + type, + typedData, + lhsOrRhs ? lhs : rhs, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition="), + false, + null, + addWarnings(warnings) + ); + } + ) + ); + } + suppliers.add( + new TestCaseSupplier( + TestCaseSupplier.nameFrom(Arrays.asList(DataType.BOOLEAN, DataType.NULL, type)), + List.of(DataType.BOOLEAN, DataType.NULL, type), + () -> { + Object rhs = randomLiteral(type).value(); + List typedData = List.of( + cond(cond, "cond"), + new TestCaseSupplier.TypedData(null, DataType.NULL, "lhs"), + new TestCaseSupplier.TypedData(rhs, type, "rhs") + ); + return testCase( + type, + typedData, + lhsOrRhs ? null : rhs, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition="), + false, + null, + addWarnings(warnings) + ); + } + ) + ); + suppliers.add( + new TestCaseSupplier( + TestCaseSupplier.nameFrom(Arrays.asList(DataType.BOOLEAN, type, DataType.NULL)), + List.of(DataType.BOOLEAN, type, DataType.NULL), + () -> { + Object lhs = randomLiteral(type).value(); + List typedData = List.of( + cond(cond, "cond"), + new TestCaseSupplier.TypedData(lhs, type, "lhs"), + new TestCaseSupplier.TypedData(null, DataType.NULL, "rhs") + ); + return testCase( + type, + typedData, + lhsOrRhs ? lhs : null, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition="), + false, + null, + addWarnings(warnings) + ); + } + ) + ); + } + + private static void fourAndFiveArgs( + List suppliers, + Object cond1, + Object cond2, + int result, + DataType type, + List warnings + ) { + suppliers.add( + new TestCaseSupplier( + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1"), + new TestCaseSupplier.TypedData(r1, type, "r1"), + cond(cond2, "cond2"), + new TestCaseSupplier.TypedData(r2, type, "r2") + ); + return testCase(type, typedData, switch (result) { + case 0 -> r1; + case 1 -> r2; + case 2 -> null; + default -> throw new AssertionError("unsupported result " + result); + }, toStringMatcher(2, true), false, null, addWarnings(warnings)); + } + ) + ); + suppliers.add( + new TestCaseSupplier( + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + Object r3 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1"), + new TestCaseSupplier.TypedData(r1, type, "r1"), + cond(cond2, "cond2"), + new TestCaseSupplier.TypedData(r2, type, "r2"), + new TestCaseSupplier.TypedData(r3, type, "r3") + ); + return testCase(type, typedData, switch (result) { + case 0 -> r1; + case 1 -> r2; + case 2 -> r3; + default -> throw new AssertionError("unsupported result " + result); + }, toStringMatcher(2, false), false, null, addWarnings(warnings)); + } + ) + ); + // Add some foldable and partially foldable cases. This isn't every combination of fold-ability, but it's many. + switch (result) { + case 0 -> { + suppliers.add( + new TestCaseSupplier( + "foldable " + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + Object r3 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1").forceLiteral(), + new TestCaseSupplier.TypedData(r1, type, "r1").forceLiteral(), + cond(cond2, "cond2"), + new TestCaseSupplier.TypedData(r2, type, "r2"), + new TestCaseSupplier.TypedData(r3, type, "r3") + ); + return testCase( + type, + typedData, + r1, + startsWith("LiteralsEvaluator[lit="), + true, + null, + addBuildEvaluatorWarnings(warnings) + ); + } + ) ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=LONG, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.LONG, - nullValue() + suppliers.add( + new TestCaseSupplier( + "partial foldable " + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + Object r3 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1").forceLiteral(), + new TestCaseSupplier.TypedData(r1, type, "r1"), + cond(cond2, "cond2"), + new TestCaseSupplier.TypedData(r2, type, "r2"), + new TestCaseSupplier.TypedData(r3, type, "r3") + ); + return testCase( + type, + typedData, + r1, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + false, + List.of(typedData.get(1)), + addBuildEvaluatorWarnings(warnings) + ); + } + ) ); - }), new TestCaseSupplier("unsigned_long", List.of(DataType.BOOLEAN, DataType.UNSIGNED_LONG), () -> { - BigInteger value = randomUnsignedLongBetween(BigInteger.ZERO, UNSIGNED_LONG_MAX); - List typedData = List.of( - new TestCaseSupplier.TypedData(true, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(value, DataType.UNSIGNED_LONG, "trueValue") + } + case 1 -> { + suppliers.add( + new TestCaseSupplier( + "foldable " + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + Object r3 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1").forceLiteral(), + new TestCaseSupplier.TypedData(r1, type, "r1").forceLiteral(), + cond(cond2, "cond2").forceLiteral(), + new TestCaseSupplier.TypedData(r2, type, "r2").forceLiteral(), + new TestCaseSupplier.TypedData(r3, type, "r3") + ); + return testCase( + type, + typedData, + r2, + startsWith("LiteralsEvaluator[lit="), + true, + null, + addBuildEvaluatorWarnings(warnings) + ); + } + ) ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=LONG, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.UNSIGNED_LONG, - equalTo(value) + suppliers.add( + new TestCaseSupplier( + "partial foldable 1 " + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + Object r3 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1").forceLiteral(), + new TestCaseSupplier.TypedData(r1, type, "r1").forceLiteral(), + cond(cond2, "cond2").forceLiteral(), + new TestCaseSupplier.TypedData(r2, type, "r2"), + new TestCaseSupplier.TypedData(r3, type, "r3") + ); + return testCase( + type, + typedData, + r2, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + false, + List.of(typedData.get(3)), + addWarnings(warnings) + ); + } + ) ); - }), new TestCaseSupplier("ip", List.of(DataType.BOOLEAN, DataType.IP), () -> { - BytesRef value = (BytesRef) randomLiteral(DataType.IP).value(); - List typedData = List.of( - new TestCaseSupplier.TypedData(true, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(value, DataType.IP, "trueValue") + suppliers.add( + new TestCaseSupplier( + "partial foldable 2 " + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1").forceLiteral(), + new TestCaseSupplier.TypedData(r1, type, "r1").forceLiteral(), + cond(cond2, "cond2").forceLiteral(), + new TestCaseSupplier.TypedData(r2, type, "r2") + ); + return testCase( + type, + typedData, + r2, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + false, + List.of(typedData.get(3)), + addWarnings(warnings) + ); + } + ) ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=BYTES_REF, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.IP, - equalTo(value) + suppliers.add( + new TestCaseSupplier( + "partial foldable 3 " + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1").forceLiteral(), + new TestCaseSupplier.TypedData(r1, type, "r1").forceLiteral(), + cond(cond2, "cond2"), + new TestCaseSupplier.TypedData(r2, type, "r2") + ); + return testCase( + type, + typedData, + r2, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + false, + typedData.subList(2, 4), + addWarnings(warnings) + ); + } + ) ); - }), new TestCaseSupplier("version", List.of(DataType.BOOLEAN, DataType.VERSION), () -> { - BytesRef value = (BytesRef) randomLiteral(DataType.VERSION).value(); - List typedData = List.of( - new TestCaseSupplier.TypedData(false, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(value, DataType.VERSION, "trueValue") + } + case 2 -> { + suppliers.add( + new TestCaseSupplier( + "foldable " + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + Object r3 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1").forceLiteral(), + new TestCaseSupplier.TypedData(r1, type, "r1"), + cond(cond2, "cond2").forceLiteral(), + new TestCaseSupplier.TypedData(r2, type, "r2"), + new TestCaseSupplier.TypedData(r3, type, "r3").forceLiteral() + ); + return testCase( + type, + typedData, + r3, + startsWith("LiteralsEvaluator[lit="), + true, + null, + addBuildEvaluatorWarnings(warnings) + ); + } + ) ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=BYTES_REF, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.VERSION, - nullValue() + suppliers.add( + new TestCaseSupplier( + "partial foldable 1 " + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + Object r3 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1").forceLiteral(), + new TestCaseSupplier.TypedData(r1, type, "r1"), + cond(cond2, "cond2").forceLiteral(), + new TestCaseSupplier.TypedData(r2, type, "r2"), + new TestCaseSupplier.TypedData(r3, type, "r3") + ); + return testCase( + type, + typedData, + r3, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + false, + List.of(typedData.get(4)), + addWarnings(warnings) + ); + } + ) ); - }), new TestCaseSupplier("cartesian_point", List.of(DataType.BOOLEAN, DataType.CARTESIAN_POINT), () -> { - BytesRef value = (BytesRef) randomLiteral(DataType.CARTESIAN_POINT).value(); - List typedData = List.of( - new TestCaseSupplier.TypedData(false, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(value, DataType.CARTESIAN_POINT, "trueValue") + suppliers.add( + new TestCaseSupplier( + "partial foldable 2 " + TestCaseSupplier.nameFrom(Arrays.asList(cond1, type, cond2, type, type)), + List.of(DataType.BOOLEAN, type, DataType.BOOLEAN, type, type), + () -> { + Object r1 = randomLiteral(type).value(); + Object r2 = randomLiteral(type).value(); + Object r3 = randomLiteral(type).value(); + List typedData = List.of( + cond(cond1, "cond1").forceLiteral(), + new TestCaseSupplier.TypedData(r1, type, "r1"), + cond(cond2, "cond2"), + new TestCaseSupplier.TypedData(r2, type, "r2"), + new TestCaseSupplier.TypedData(r3, type, "r3") + ); + return testCase( + type, + typedData, + r3, + startsWith("CaseEvaluator[conditions=[ConditionEvaluator[condition=LiteralsEvaluator[lit="), + false, + typedData.subList(2, 5), + addWarnings(warnings) + ); + } + ) ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=BYTES_REF, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.CARTESIAN_POINT, - nullValue() - ); - }), new TestCaseSupplier("geo_point", List.of(DataType.BOOLEAN, DataType.GEO_POINT), () -> { - BytesRef value = (BytesRef) randomLiteral(DataType.GEO_POINT).value(); - List typedData = List.of( - new TestCaseSupplier.TypedData(true, DataType.BOOLEAN, "cond"), - new TestCaseSupplier.TypedData(value, DataType.GEO_POINT, "trueValue") - ); - return new TestCaseSupplier.TestCase( - typedData, - "CaseEvaluator[resultType=BYTES_REF, conditions=[ConditionEvaluator[condition=Attribute[channel=0], " - + "value=Attribute[channel=1]]], elseVal=LiteralsEvaluator[lit=null]]", - DataType.GEO_POINT, - equalTo(value) - ); - })) + } + default -> throw new IllegalArgumentException("unsupported " + result); + } + } + + private static Matcher toStringMatcher(int conditions, boolean trailingNull) { + StringBuilder result = new StringBuilder("CaseEvaluator[conditions=["); + int channel = 0; + for (int i = 0; i < conditions; i++) { + if (i != 0) { + result.append(", "); + } + result.append("ConditionEvaluator[condition=Attribute[channel=").append(channel++); + result.append("], value=Attribute[channel=").append(channel++).append("]]"); + } + if (trailingNull) { + result.append("], elseVal=LiteralsEvaluator[lit=null]]"); + } else { + result.append("], elseVal=Attribute[channel=").append(channel).append("]]"); + } + return equalTo(result.toString()); + } + + private static TestCaseSupplier.TypedData cond(Object cond, String name) { + return new TestCaseSupplier.TypedData(cond instanceof Supplier s ? s.get() : cond, DataType.BOOLEAN, name); + } + + private static TestCaseSupplier.TestCase testCase( + DataType type, + List typedData, + Object result, + Matcher evaluatorToString, + boolean foldable, + @Nullable List partialFold, + Function decorate + ) { + if (type == DataType.UNSIGNED_LONG && result != null) { + result = NumericUtils.unsignedLongAsBigInteger((Long) result); + } + return decorate.apply( + new TestCaseSupplier.TestCase(typedData, evaluatorToString, type, equalTo(result)).withExtra(new Extra(foldable, partialFold)) ); } @Override - protected Expression build(Source source, List args) { + protected Case build(Source source, List args) { return new Case(Source.EMPTY, args.get(0), args.subList(1, args.size())); } + + private static Supplier randomSingleValuedCondition() { + return new Supplier<>() { + @Override + public Boolean get() { + return randomBoolean(); + } + + @Override + public String toString() { + return "multivalue"; + } + }; + } + + private static Supplier> randomMultivaluedCondition() { + return new Supplier<>() { + @Override + public List get() { + return randomList(2, 100, ESTestCase::randomBoolean); + } + + @Override + public String toString() { + return "multivalue"; + } + }; + } + + public void testFancyFolding() { + if (testCase.getExpectedTypeError() != null) { + // Nothing to do + return; + } + Expression e = buildFieldExpression(testCase); + if (extra().foldable == false) { + assertThat(e.foldable(), equalTo(false)); + return; + } + assertThat(e.foldable(), equalTo(true)); + Object result = e.fold(); + if (testCase.getExpectedBuildEvaluatorWarnings() != null) { + assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); + } + if (testCase.expectedType() == DataType.UNSIGNED_LONG && result != null) { + result = NumericUtils.unsignedLongAsBigInteger((Long) result); + } + assertThat(result, testCase.getMatcher()); + if (testCase.getExpectedWarnings() != null) { + assertWarnings(testCase.getExpectedWarnings()); + } + } + + public void testPartialFold() { + if (testCase.getExpectedTypeError() != null || extra().foldable()) { + // Nothing to do + return; + } + Case c = (Case) buildFieldExpression(testCase); + if (extra().expectedPartialFold == null) { + assertThat(c.partiallyFold(), sameInstance(c)); + return; + } + if (extra().expectedPartialFold.size() == 1) { + assertThat(c.partiallyFold(), equalTo(extra().expectedPartialFold.get(0).asField())); + return; + } + Case expected = build( + Source.synthetic("expected"), + extra().expectedPartialFold.stream().map(TestCaseSupplier.TypedData::asField).toList() + ); + assertThat(c.partiallyFold(), equalTo(expected)); + } + + private static Function addWarnings(List warnings) { + return c -> { + TestCaseSupplier.TestCase r = c; + for (String warning : warnings) { + r = r.withWarning(warning); + } + return r; + }; + } + + private static Function addBuildEvaluatorWarnings(List warnings) { + return c -> { + TestCaseSupplier.TestCase r = c; + for (String warning : warnings) { + r = r.withBuildEvaluatorWarning(warning); + } + return r; + }; + } + + private record Extra(boolean foldable, List expectedPartialFold) {} + + private Extra extra() { + return (Extra) testCase.extra(); + } + + @Override + protected Matcher allNullsMatcher() { + if (extra().foldable) { + return testCase.getMatcher(); + } + return super.allNullsMatcher(); + } + + private static String typeErrorMessage(boolean includeOrdinal, List types) { + if (types.get(0) != DataType.BOOLEAN && types.get(0) != DataType.NULL) { + return typeErrorMessage(includeOrdinal, types, 0, "boolean"); + } + DataType mainType = types.get(1); + for (int i = 2; i < types.size(); i++) { + if (i % 2 == 0 && i != types.size() - 1) { + // condition + if (types.get(i) != DataType.BOOLEAN && types.get(i) != DataType.NULL) { + return typeErrorMessage(includeOrdinal, types, i, "boolean"); + } + } else { + // value + if (types.get(i) != mainType) { + return typeErrorMessage(includeOrdinal, types, i, mainType.typeName()); + } + } + } + throw new IllegalStateException("can't find bad arg for " + types); + } + + private static String typeErrorMessage(boolean includeOrdinal, List types, int badArgPosition, String expectedTypeString) { + String ordinal = includeOrdinal ? TypeResolutions.ParamOrdinal.fromIndex(badArgPosition).name().toLowerCase(Locale.ROOT) + " " : ""; + String name = types.get(badArgPosition).typeName(); + return ordinal + "argument of [] must be [" + expectedTypeString + "], found value [" + name + "] type [" + name + "]"; + } } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml index 92b3f4d1b084d..359ac40bc3672 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml @@ -6,7 +6,7 @@ setup: parameters: [method, path, parameters, capabilities] capabilities: [union_types, union_types_remove_fields, casting_operator] reason: "Union types and casting operator introduced in 8.15.0" - test_runner_features: [capabilities, allowed_warnings_regex] + test_runner_features: [capabilities, allowed_warnings_regex, warnings_regex] - do: indices.create: @@ -830,3 +830,65 @@ load four indices with multiple conversion functions TO_LONG and TO_IP: - match: { values.21.2: "172.21.3.15" } - match: { values.21.3: 1756467 } - match: { values.21.4: "Connected to 10.1.0.1" } + +--- +CASE: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: [case_mv] + reason: "CASE support for multivalue conditions introduced in 8.16.0" + + - do: + indices.create: + index: b1 + body: + mappings: + properties: + f: + type: keyword + - do: + indices.create: + index: b2 + body: + mappings: + properties: + f: + type: boolean + - do: + bulk: + refresh: true + body: + - '{"index": {"_index": "b1"}}' + - '{"a": 1, "f": false}' + - '{"index": {"_index": "b1"}}' + - '{"a": 2, "f": [true, false]}' + - '{"index": {"_index": "b2"}}' + - '{"a": 3, "f": true}' + + - do: + warnings_regex: + - ".+evaluation of \\[f?\\] failed, treating result as false. Only first 20 failures recorded." + - ".+java.lang.IllegalArgumentException: CASE expects a single-valued boolean" + esql.query: + body: + query: 'FROM b* | EVAL c = CASE(f::BOOLEAN, "a", "b") | SORT a ASC | LIMIT 10' + + - match: { columns.0.name: "a" } + - match: { columns.0.type: "long" } + - match: { columns.1.name: "f" } + - match: { columns.1.type: "unsupported" } + - match: { columns.2.name: "c" } + - match: { columns.2.type: "keyword" } + - length: { values: 3 } + - match: { values.0.0: 1 } + - match: { values.0.1: null } + - match: { values.0.2: "b" } + - match: { values.1.0: 2 } + - match: { values.1.1: null } + - match: { values.1.2: "b" } + - match: { values.2.0: 3 } + - match: { values.2.1: null } + - match: { values.2.2: "a" }