diff --git a/docs/changelog/131341.yaml b/docs/changelog/131341.yaml new file mode 100644 index 0000000000000..d89efddf9e014 --- /dev/null +++ b/docs/changelog/131341.yaml @@ -0,0 +1,5 @@ +pr: 131341 +summary: Consider min/max from predicates when transform date_trunc/bucket to `round_to` +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec index 49b16baf30f58..06ac461cb6c62 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/bucket.csv-spec @@ -884,3 +884,108 @@ c:long | b:datetime | yr:datetime 9 | 1989-01-01T00:00:00.000Z | 1988-01-01T00:00:00.000Z 13 | 1990-01-01T00:00:00.000Z | 1989-01-01T00:00:00.000Z ; + +bucketYearInAggWithGTOutOfRange#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date >= "2000-01-01T00:00:00Z" +| STATS COUNT(*) by bucket = BUCKET(hire_date, 1 month) +| SORT bucket; + +COUNT(*):long | bucket:date +; + +bucketYearInAggWithLTOutOfRange#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date <= "1980-01-01T00:00:00Z" +| STATS COUNT(*) by bucket = BUCKET(hire_date, 1 year) +| SORT bucket; + +COUNT(*):long | bucket:date +; + +bucketYearInAggWithGTLTOutOfRange#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date <= "1980-01-01T00:00:00Z" and hire_date >= "1970-01-01" +| STATS COUNT(*) by bucket = BUCKET(hire_date, 1 week) +| SORT bucket; + +COUNT(*):long | bucket:date +; + +bucketYearInAggWithEQOutOfRange#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| WHERE hire_date == "1980-01-01T00:00:00Z" +| STATS COUNT(*) by bucket = BUCKET(hire_date, 1 hour) +| SORT bucket; + +COUNT(*):long | bucket:date +; + +bucketWithRename#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| RENAME hire_date as x, x as y +| WHERE y >= "1980-01-01T00:00:00Z" +| STATS COUNT(*) by bucket = BUCKET(y, 1 hour) +| SORT bucket +| LIMIT 5 +; + +COUNT(*):long | bucket:datetime +1 | 1985-02-18T00:00:00.000Z +1 | 1985-02-24T00:00:00.000Z +1 | 1985-05-13T00:00:00.000Z +1 | 1985-07-09T00:00:00.000Z +1 | 1985-09-17T00:00:00.000Z +; + +bucketWithEval#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| EVAL x = hire_date +| WHERE x >= "1980-01-01T00:00:00Z" and hire_date <= "1990-01-01T00:00:00Z" +| STATS COUNT(*) by bucket = BUCKET(x, 1 hour) +| SORT bucket +| LIMIT 5 +; + +COUNT(*):long | bucket:datetime +1 | 1985-02-18T00:00:00.000Z +1 | 1985-02-24T00:00:00.000Z +1 | 1985-05-13T00:00:00.000Z +1 | 1985-07-09T00:00:00.000Z +1 | 1985-09-17T00:00:00.000Z +; + +bucketWithEvalExpression#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| EVAL x = hire_date + 1 year +| WHERE x >= "1980-01-01T00:00:00Z" +| STATS COUNT(*) by bucket = BUCKET(x, 1 hour) +| SORT bucket +| LIMIT 5 +; + +COUNT(*):long | bucket:datetime +1 | 1986-02-18T00:00:00.000Z +1 | 1986-02-24T00:00:00.000Z +1 | 1986-05-13T00:00:00.000Z +1 | 1986-07-09T00:00:00.000Z +1 | 1986-09-17T00:00:00.000Z +; + +bucketWithRenameEvalExpression#[skip:-8.13.99, reason:BUCKET renamed in 8.14] +FROM employees +| EVAL x = hire_date + 1 year +| RENAME x as y +| WHERE y >= "1980-01-01T00:00:00Z" +| STATS COUNT(*) by bucket = BUCKET(y, 1 hour) +| SORT bucket +| LIMIT 5 +; + +COUNT(*):long | bucket:datetime +1 | 1986-02-18T00:00:00.000Z +1 | 1986-02-24T00:00:00.000Z +1 | 1986-05-13T00:00:00.000Z +1 | 1986-07-09T00:00:00.000Z +1 | 1986-09-17T00:00:00.000Z +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date.csv-spec index 4b9e1512844b4..788a5f9877dea 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date.csv-spec @@ -1581,4 +1581,258 @@ x:date | y:date ; +evalDateTruncMonthIntervalWithGTEInRange +FROM employees +| SORT hire_date +| WHERE hire_date >= "1990-01-01" +| EVAL x = date_trunc(1 month, hire_date) +| KEEP emp_no, hire_date, x +| LIMIT 5; + +emp_no:integer | hire_date:date | x:date +10082 | 1990-01-03T00:00:00.000Z | 1990-01-01T00:00:00.000Z +10096 | 1990-01-14T00:00:00.000Z | 1990-01-01T00:00:00.000Z +10011 | 1990-01-22T00:00:00.000Z | 1990-01-01T00:00:00.000Z +10056 | 1990-02-01T00:00:00.000Z | 1990-02-01T00:00:00.000Z +10086 | 1990-02-16T00:00:00.000Z | 1990-02-01T00:00:00.000Z +; + +evalDateTruncHoursIntervalWithLTEInRange +FROM employees +| SORT hire_date desc +| WHERE hire_date <= "1990-01-01" +| EVAL x = date_trunc(240 hours, hire_date) +| KEEP emp_no, hire_date, x +| LIMIT 5; + +emp_no:integer | hire_date:date | x:date +10023 | 1989-12-17T00:00:00.000Z | 1989-12-17T00:00:00.000Z +10041 | 1989-11-12T00:00:00.000Z | 1989-11-07T00:00:00.000Z +10069 | 1989-11-05T00:00:00.000Z | 1989-10-28T00:00:00.000Z +10092 | 1989-09-22T00:00:00.000Z | 1989-09-18T00:00:00.000Z +10038 | 1989-09-20T00:00:00.000Z | 1989-09-18T00:00:00.000Z +; + +evalDateTruncWeeklyIntervalWithLTGTInRange +from employees +| SORT hire_date +| WHERE hire_date > "1986-01-01" and hire_date < "1988-01-01" +| EVAL x = date_trunc(1 week, hire_date) +| KEEP emp_no, hire_date, x +| LIMIT 5; + +emp_no:integer | hire_date:date | x:date +10053 | 1986-02-04T00:00:00.000Z | 1986-02-03T00:00:00.000Z +10066 | 1986-02-26T00:00:00.000Z | 1986-02-24T00:00:00.000Z +10090 | 1986-03-14T00:00:00.000Z | 1986-03-10T00:00:00.000Z +10079 | 1986-03-27T00:00:00.000Z | 1986-03-24T00:00:00.000Z +10001 | 1986-06-26T00:00:00.000Z | 1986-06-23T00:00:00.000Z +; + +evalDateTruncQuarterlyIntervalWithGTInRange +from employees +| SORT hire_date +| WHERE hire_date > "1980-01-01" +| EVAL x = date_trunc(3 month, hire_date) +| KEEP emp_no, hire_date, x +| LIMIT 5; + +emp_no:integer | hire_date:date | x:date +10009 | 1985-02-18T00:00:00.000Z | 1985-01-01T00:00:00.000Z +10048 | 1985-02-24T00:00:00.000Z | 1985-01-01T00:00:00.000Z +10098 | 1985-05-13T00:00:00.000Z | 1985-04-01T00:00:00.000Z +10076 | 1985-07-09T00:00:00.000Z | 1985-07-01T00:00:00.000Z +10061 | 1985-09-17T00:00:00.000Z | 1985-07-01T00:00:00.000Z +; + +dateTruncGroupingYearIntervalWithLTInRange +from employees +| WHERE hire_date < "2025-01-01" +| EVAL y = date_trunc(1 year, hire_date) +| stats c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +1985-01-01T00:00:00.000Z | 11 +1986-01-01T00:00:00.000Z | 11 +1987-01-01T00:00:00.000Z | 15 +1988-01-01T00:00:00.000Z | 9 +1989-01-01T00:00:00.000Z | 13 +; + +dateTruncGroupingYearIntervalWithLTOutOfRange +from employees +| WHERE hire_date < "1980-01-01" +| EVAL y = date_trunc(1 year, hire_date) +| stats c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +; + +dateTruncGroupingYearIntervalWithGTOutOfRange +from employees +| WHERE hire_date > "2000-01-01" +| EVAL y = date_trunc(1 year, hire_date) +| stats c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +; +dateTruncGroupingMonthIntervalWithLTGTInRange +from employees +| WHERE hire_date > "1987-01-01" and hire_date < "1988-01-01" +| EVAL y = date_trunc(1 month, hire_date) +| stats c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +1987-03-01T00:00:00.000Z | 5 +1987-04-01T00:00:00.000Z | 3 +1987-05-01T00:00:00.000Z | 1 +1987-07-01T00:00:00.000Z | 1 +1987-08-01T00:00:00.000Z | 2 +; + +dateTruncGroupingDayIntervalWithEQInRange +from employees +| WHERE hire_date == "1988-02-10" +| EVAL y = date_trunc(1 day, hire_date) +| stats c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +1988-02-10T00:00:00.000Z | 1 +; + +dateTruncGroupingDayIntervalWithEQOutOfRange +from employees +| WHERE hire_date == "2025-01-01" +| EVAL y = date_trunc(1 day, hire_date) +| stats c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +; + +dateTruncWithEval +from employees +| EVAL x = hire_date +| WHERE x > "1987-01-01" and hire_date < "1988-01-01" +| EVAL y = date_trunc(1 month, x) +| STATS c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +1987-03-01T00:00:00.000Z | 5 +1987-04-01T00:00:00.000Z | 3 +1987-05-01T00:00:00.000Z | 1 +1987-07-01T00:00:00.000Z | 1 +1987-08-01T00:00:00.000Z | 2 +; + +dateTruncWithEvalExpression +from employees +| EVAL x = hire_date + 1 year +| WHERE x > "1987-01-01" and x < "1988-01-01" +| EVAL y = date_trunc(1 month, x) +| STATS c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +1987-02-01T00:00:00.000Z | 2 +1987-03-01T00:00:00.000Z | 2 +1987-06-01T00:00:00.000Z | 1 +1987-07-01T00:00:00.000Z | 1 +1987-08-01T00:00:00.000Z | 2 +; + +dateTruncWithRename +FROM employees +| RENAME hire_date as x +| WHERE x > "1987-01-01" and x < "1988-01-01" +| EVAL y = date_trunc(1 month, x) +| STATS c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +1987-03-01T00:00:00.000Z | 5 +1987-04-01T00:00:00.000Z | 3 +1987-05-01T00:00:00.000Z | 1 +1987-07-01T00:00:00.000Z | 1 +1987-08-01T00:00:00.000Z | 2 +; + +dateTruncWithRenameChain +FROM employees +| RENAME hire_date as a, a as x +| WHERE x > "1987-01-01" and x < "1988-01-01" +| EVAL y = date_trunc(1 month, x) +| STATS c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +1987-03-01T00:00:00.000Z | 5 +1987-04-01T00:00:00.000Z | 3 +1987-05-01T00:00:00.000Z | 1 +1987-07-01T00:00:00.000Z | 1 +1987-08-01T00:00:00.000Z | 2 +; + +dateTruncWithRenameBack +FROM employees +| RENAME hire_date as x, x as hire_date +| WHERE hire_date > "1987-01-01" and hire_date < "1988-01-01" +| EVAL y = date_trunc(1 month, hire_date) +| STATS c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +1987-03-01T00:00:00.000Z | 5 +1987-04-01T00:00:00.000Z | 3 +1987-05-01T00:00:00.000Z | 1 +1987-07-01T00:00:00.000Z | 1 +1987-08-01T00:00:00.000Z | 2 +; + +dateTruncWithEvalRename +FROM employees +| EVAL a = hire_date +| RENAME hire_date as b +| WHERE a > "1987-01-01" and a < "1988-01-01" +| EVAL y = date_trunc(1 month, b) +| STATS c = count(emp_no) by y +| SORT y +| KEEP y, c +| LIMIT 5; + +y:date | c:long +1987-03-01T00:00:00.000Z | 5 +1987-04-01T00:00:00.000Z | 3 +1987-05-01T00:00:00.000Z | 1 +1987-07-01T00:00:00.000Z | 1 +1987-08-01T00:00:00.000Z | 2 +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/LocalSurrogateExpression.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/LocalSurrogateExpression.java index f0401ae1d4f05..99a6b0397f88e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/LocalSurrogateExpression.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/LocalSurrogateExpression.java @@ -8,15 +8,19 @@ package org.elasticsearch.xpack.esql.expression; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.stats.SearchStats; +import java.util.List; + /** * Interface signaling to the local logical plan optimizer that the declaring expression * has to be replaced by a different form. * Implement this on {@code Function}s when: * */ @@ -24,5 +28,10 @@ public interface LocalSurrogateExpression { /** * Returns the expression to be replaced by or {@code null} if this cannot be replaced. */ - Expression surrogate(SearchStats searchStats); + Expression surrogate(SearchStats searchStats, List binaryComparisons); + + /** + * Returns the field that can be used by {@code LocalSubstituteSurrogateExpressions} to check predicates against. + */ + Expression field(); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java index 01bc4dd2b4eec..c97e671de9c75 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.math.Floor; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.stats.SearchStats; @@ -498,7 +499,7 @@ public String toString() { } @Override - public Expression surrogate(SearchStats searchStats) { + public Expression surrogate(SearchStats searchStats, List binaryComparisons) { // LocalSubstituteSurrogateExpressions should make sure this doesn't happen assert searchStats != null : "SearchStats cannot be null"; return maybeSubstituteWithRoundTo( @@ -506,6 +507,7 @@ public Expression surrogate(SearchStats searchStats) { field(), buckets(), searchStats, + binaryComparisons, (interval, minValue, maxValue) -> getDateRounding(FoldContext.small(), minValue, maxValue) ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java index 9b4d312e9df42..d894ab9220550 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java @@ -17,6 +17,7 @@ import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -27,12 +28,19 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField; +import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.LocalSurrogateExpression; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.math.RoundTo; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.stats.SearchStats; @@ -125,7 +133,7 @@ Expression interval() { return interval; } - Expression field() { + public Expression field() { return timestampField; } @@ -287,7 +295,7 @@ public static ExpressionEvaluator.Factory evaluator( } @Override - public Expression surrogate(SearchStats searchStats) { + public Expression surrogate(SearchStats searchStats, List binaryComparisons) { // LocalSubstituteSurrogateExpressions should make sure this doesn't happen assert searchStats != null : "SearchStats cannot be null"; return maybeSubstituteWithRoundTo( @@ -295,6 +303,7 @@ public Expression surrogate(SearchStats searchStats) { field(), interval(), searchStats, + binaryComparisons, (interval, minValue, maxValue) -> createRounding(interval, DEFAULT_TZ, minValue, maxValue) ); } @@ -304,27 +313,42 @@ public static RoundTo maybeSubstituteWithRoundTo( Expression field, Expression foldableTimeExpression, SearchStats searchStats, + List binaryComparisons, TriFunction roundingFunction ) { if (field instanceof FieldAttribute fa && fa.field() instanceof MultiTypeEsField == false && isDateTime(fa.dataType())) { // Extract min/max from SearchStats DataType fieldType = fa.dataType(); FieldAttribute.FieldName fieldName = fa.fieldName(); - var min = searchStats.min(fieldName); - var max = searchStats.max(fieldName); + // Extract min/max from SearchStats + Object minFromSearchStats = searchStats.min(fieldName); + Object maxFromSearchStats = searchStats.max(fieldName); + Long min = toLong(minFromSearchStats); + Long max = toLong(maxFromSearchStats); + // Extract min/max from query + Tuple minMaxFromPredicates = minMaxFromPredicates(binaryComparisons); + Long minFromPredicates = minMaxFromPredicates.v1(); + Long maxFromPredicates = minMaxFromPredicates.v2(); + // Consolidate min/max from SearchStats and query + if (minFromPredicates != null) { + min = min != null ? Math.max(min, minFromPredicates) : minFromPredicates; + } + if (maxFromPredicates != null) { + max = max != null ? Math.min(max, maxFromPredicates) : maxFromPredicates; + } // If min/max is available create rounding with them - if (min instanceof Long minValue && max instanceof Long maxValue && foldableTimeExpression.foldable()) { + if (min != null && max != null && foldableTimeExpression.foldable() && min <= max) { Object foldedInterval = foldableTimeExpression.fold(FoldContext.small() /* TODO remove me */); - Rounding.Prepared rounding = roundingFunction.apply(foldedInterval, minValue, maxValue); + Rounding.Prepared rounding = roundingFunction.apply(foldedInterval, min, max); long[] roundingPoints = rounding.fixedRoundingPoints(); if (roundingPoints == null) { logger.trace( "Fixed rounding point is null for field {}, minValue {} in string format {} and maxValue {} in string format {}", fieldName, - minValue, - dateWithTypeToString(minValue, fieldType), - maxValue, - dateWithTypeToString(maxValue, fieldType) + min, + dateWithTypeToString(min, fieldType), + max, + dateWithTypeToString(max, fieldType) ); return null; } @@ -337,4 +361,35 @@ public static RoundTo maybeSubstituteWithRoundTo( } return null; } + + private static Tuple minMaxFromPredicates(List binaryComparisons) { + long[] min = new long[] { Long.MIN_VALUE }; + long[] max = new long[] { Long.MAX_VALUE }; + Holder foundMinValue = new Holder<>(false); + Holder foundMaxValue = new Holder<>(false); + for (EsqlBinaryComparison binaryComparison : binaryComparisons) { + if (binaryComparison.right() instanceof Literal l) { + long value = toLong(l.value()); + if (binaryComparison instanceof Equals) { + return new Tuple<>(value, value); + } + if (binaryComparison instanceof GreaterThan || binaryComparison instanceof GreaterThanOrEqual) { + if (value >= min[0]) { + min[0] = value; + foundMinValue.set(true); + } + } else if (binaryComparison instanceof LessThan || binaryComparison instanceof LessThanOrEqual) { + if (value <= max[0]) { + max[0] = value; + foundMaxValue.set(true); + } + } + } + } + return new Tuple<>(foundMinValue.get() ? min[0] : null, foundMaxValue.get() ? max[0] : null); + } + + private static Long toLong(Object value) { + return value instanceof Long l ? l : null; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressions.java index ff25be0c85258..dbdf834854077 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressions.java @@ -10,32 +10,64 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.expression.LocalSurrogateExpression; +import org.elasticsearch.xpack.esql.expression.predicate.Predicates; +import org.elasticsearch.xpack.esql.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.rule.ParameterizedRule; import org.elasticsearch.xpack.esql.stats.SearchStats; +import java.util.ArrayList; +import java.util.List; + public class LocalSubstituteSurrogateExpressions extends ParameterizedRule { @Override public LogicalPlan apply(LogicalPlan plan, LocalLogicalOptimizerContext context) { - return context.searchStats() != null - ? plan.transformUp(Eval.class, eval -> eval.transformExpressionsOnly(Function.class, f -> substitute(f, context.searchStats()))) - : plan; + return context.searchStats() != null ? plan.transformUp(Eval.class, eval -> substitute(eval, context.searchStats())) : plan; + } + + private LogicalPlan substitute(Eval eval, SearchStats searchStats) { + // check the filter in children plans + return eval.transformExpressionsOnly(Function.class, f -> substitute(f, eval, searchStats)); + } + + private List predicates(Eval eval, Expression field) { + List binaryComparisons = new ArrayList<>(); + eval.forEachUp(Filter.class, filter -> { + Expression condition = filter.condition(); + if (condition instanceof And and) { + Predicates.splitAnd(and).forEach(e -> addBinaryComparisonOnField(e, field, binaryComparisons)); + } else { + addBinaryComparisonOnField(condition, field, binaryComparisons); + } + }); + return binaryComparisons; + } + + private void addBinaryComparisonOnField(Expression expression, Expression field, List binaryComparisons) { + if (expression instanceof EsqlBinaryComparison esqlBinaryComparison + && esqlBinaryComparison.right().foldable() + && esqlBinaryComparison.left().semanticEquals(field)) { + binaryComparisons.add(esqlBinaryComparison); + } } /** - * Perform the actual substitution. + * Perform the actual substitution with {@code SearchStats} and predicates in the query. */ - private static Expression substitute(Expression e, SearchStats searchStats) { + private Expression substitute(Expression e, Eval eval, SearchStats searchStats) { if (e instanceof LocalSurrogateExpression s) { - Expression surrogate = s.surrogate(searchStats); + // extract relevant predicates from the query + List binaryComparisons = new ArrayList<>(predicates(eval, s.field())); + Expression surrogate = s.surrogate(searchStats, binaryComparisons); if (surrogate != null) { return surrogate; } } return e; } - } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressionTests.java index 74ec1b71cf824..9d9858420100b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressionTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalSubstituteSurrogateExpressionTests.java @@ -7,9 +7,14 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical.local; +import org.elasticsearch.common.logging.LoggerMessageFormat; +import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; +import org.elasticsearch.xpack.esql.expression.LocalSurrogateExpression; import org.elasticsearch.xpack.esql.expression.function.scalar.math.RoundTo; import org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizerTests; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; @@ -21,94 +26,249 @@ import org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.stats.SearchStats; +import java.util.HashMap; import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; +@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug") public class LocalSubstituteSurrogateExpressionTests extends LocalLogicalPlanOptimizerTests { - public void testSubstituteDateTruncInEvalWithRoundTo() { - var plan = plan(""" - from test - | sort hire_date - | eval x = date_trunc(1 day, hire_date) - | keep emp_no, hire_date, x - | limit 5 - """); + // Key is the predicate, + // Value is the number of items in the round_to function, if the number of item is 0, that means the min/max in predicates do not + // overlap with SearchStats, so the substitution does not happen. + private static final Map predicatesWithDateTruncBucket = new HashMap<>( + Map.ofEntries( + Map.entry("", 4), + Map.entry(" | where hire_date == \"2023-10-22\" ", 1), + Map.entry(" | where hire_date == \"2023-10-19\" ", 0), + Map.entry(" | where hire_date >= \"2023-10-20\" ", 4), + Map.entry(" | where hire_date >= \"2023-10-22\" ", 2), + Map.entry(" | where hire_date > \"2023-10-24\" ", 0), + Map.entry(" | where hire_date < \"2023-10-24\" ", 4), + Map.entry(" | where hire_date <= \"2023-10-22\" ", 3), + Map.entry(" | where hire_date <= \"2023-10-19\" ", 0), + Map.entry(" | where hire_date >= \"2023-10-20\" and hire_date <= \"2023-10-24\" ", 4), + Map.entry(" | where hire_date >= \"2023-10-21\" and hire_date <= \"2023-10-23\" ", 3), + Map.entry(" | where hire_date >= \"2023-10-24\" and hire_date <= \"2023-10-31\" ", 0) + ) + ); - // create a SearchStats with min and max millis - Map minValue = Map.of("hire_date", 1697804103360L); // 2023-10-20T12:15:03.360Z - Map maxValue = Map.of("hire_date", 1698069301543L); // 2023-10-23T13:55:01.543Z - SearchStats searchStats = new EsqlTestUtils.TestSearchStatsWithMinMax(minValue, maxValue); - - LogicalPlan localPlan = localPlan(plan, searchStats); - Project project = as(localPlan, Project.class); - TopN topN = as(project.child(), TopN.class); - Eval eval = as(topN.child(), Eval.class); - List fields = eval.fields(); - assertEquals(1, fields.size()); - Alias a = fields.get(0); - assertEquals("x", a.name()); - RoundTo roundTo = as(a.child(), RoundTo.class); - FieldAttribute fa = as(roundTo.field(), FieldAttribute.class); - assertEquals("hire_date", fa.name()); - assertEquals(DATETIME, fa.dataType()); - assertEquals(4, roundTo.points().size()); // 4 days - EsRelation relation = as(eval.child(), EsRelation.class); + private static final Map evalRenamePredicatesWithDateTruncBucket = new HashMap<>( + Map.ofEntries( + // ReplaceAliasingEvalWithProject replaces x with hire_date so that the DateTrunc can be transformed to RoundTo + Map.entry(" | eval x = hire_date ", 4), + // DateTrunc cannot be transformed to RoundTo if it references an expression + Map.entry(" | eval x = hire_date + 1 year ", -1), + // PushDownEval replaces the reference(x) in DateTrunc with the corresponding field hire_date + Map.entry(" | rename hire_date as x ", 4), + Map.entry(" | rename hire_date as a, a as x ", 4), + Map.entry(" | rename hire_date as x, x as hire_date ", 4), + Map.entry(" | eval a = hire_date | rename a as x ", 4), + Map.entry(" | eval x = hire_date | where x >= \"2023-10-22\" ", 2), + Map.entry(" | rename hire_date as x | where x >= \"2023-10-20\" ", 4), + Map.entry(" | rename hire_date as a, a as x | where x <= \"2023-10-22\" ", 3), + Map.entry(" | rename hire_date as x, x as hire_date | where hire_date >= \"2023-10-21\" and hire_date <= \"2023-10-23\" ", 3), + Map.entry(" | eval a = hire_date | rename a as x | where x <= \"2023-10-22\" ", 3) + ) + ); + + // The date range of SearchStats is from 2023-10-20 to 2023-10-23. + private static final SearchStats searchStats = searchStats(); + + public void testSubstituteDateTruncInEvalWithRoundTo() { + for (Map.Entry predicate : predicatesWithDateTruncBucket.entrySet()) { + String predicateString = predicate.getKey(); + int roundToPointsSize = predicate.getValue(); + String query = LoggerMessageFormat.format(null, """ + from test + | sort hire_date + | eval x = date_trunc(1 day, hire_date) + | keep emp_no, hire_date, x + {} + | limit 5 + """, predicateString); + LogicalPlan localPlan = localPlan(query, searchStats); + Project project = as(localPlan, Project.class); + TopN topN = as(project.child(), TopN.class); + Eval eval = as(topN.child(), Eval.class); + List fields = eval.fields(); + assertEquals(1, fields.size()); + Alias a = fields.get(0); + assertEquals("x", a.name()); + verifySubstitution(a, roundToPointsSize); + LogicalPlan subPlan = predicateString.isEmpty() ? eval : eval.child(); + EsRelation relation = as(subPlan.children().get(0), EsRelation.class); + } } public void testSubstituteDateTruncInAggWithRoundTo() { - var plan = plan(""" - from test - | stats count(*) by x = date_trunc(1 day, hire_date) - """); - - // create a SearchStats with min and max millis - Map minValue = Map.of("hire_date", 1697804103360L); // 2023-10-20T12:15:03.360Z - Map maxValue = Map.of("hire_date", 1698069301543L); // 2023-10-23T13:55:01.543Z - SearchStats searchStats = new EsqlTestUtils.TestSearchStatsWithMinMax(minValue, maxValue); - - LogicalPlan localPlan = localPlan(plan, searchStats); - Limit limit = as(localPlan, Limit.class); - Aggregate aggregate = as(limit.child(), Aggregate.class); - Eval eval = as(aggregate.child(), Eval.class); - List fields = eval.fields(); - assertEquals(1, fields.size()); - Alias a = fields.get(0); - assertEquals("x", a.name()); - RoundTo roundTo = as(a.child(), RoundTo.class); - FieldAttribute fa = as(roundTo.field(), FieldAttribute.class); - assertEquals("hire_date", fa.name()); - assertEquals(DATETIME, fa.dataType()); - assertEquals(4, roundTo.points().size()); // 4 days - EsRelation relation = as(eval.child(), EsRelation.class); + for (Map.Entry predicate : predicatesWithDateTruncBucket.entrySet()) { + String predicateString = predicate.getKey(); + int roundToPointsSize = predicate.getValue(); + String query = LoggerMessageFormat.format(null, """ + from test + {} + | stats count(*) by x = date_trunc(1 day, hire_date) + """, predicateString); + LogicalPlan localPlan = localPlan(query, searchStats); + Limit limit = as(localPlan, Limit.class); + Aggregate aggregate = as(limit.child(), Aggregate.class); + Eval eval = as(aggregate.child(), Eval.class); + List fields = eval.fields(); + assertEquals(1, fields.size()); + Alias a = fields.get(0); + assertEquals("x", a.name()); + verifySubstitution(a, roundToPointsSize); + LogicalPlan subPlan = predicateString.isEmpty() ? eval : eval.child(); + EsRelation relation = as(subPlan.children().get(0), EsRelation.class); + } } public void testSubstituteBucketInAggWithRoundTo() { - var plan = plan(""" - from test - | stats count(*) by x = bucket(hire_date, 1 day) - """); + for (Map.Entry predicate : predicatesWithDateTruncBucket.entrySet()) { + String predicateString = predicate.getKey(); + int roundToPointsSize = predicate.getValue(); + String query = LoggerMessageFormat.format(null, """ + from test + {} + | stats count(*) by x = bucket(hire_date, 1 day) + """, predicateString); + LogicalPlan localPlan = localPlan(query, searchStats); + Limit limit = as(localPlan, Limit.class); + Aggregate aggregate = as(limit.child(), Aggregate.class); + Eval eval = as(aggregate.child(), Eval.class); + List fields = eval.fields(); + assertEquals(1, fields.size()); + Alias a = fields.get(0); + assertEquals("x", a.name()); + verifySubstitution(a, roundToPointsSize); + LogicalPlan subPlan = predicateString.isEmpty() ? eval : eval.child(); + EsRelation relation = as(subPlan.children().get(0), EsRelation.class); + } + } + + public void testSubstituteDateTruncInEvalWithRoundToWithEvalRename() { + for (Map.Entry predicate : evalRenamePredicatesWithDateTruncBucket.entrySet()) { + String predicateString = predicate.getKey(); + int roundToPointsSize = predicate.getValue(); + boolean hasWhere = predicateString.contains("where"); + boolean renameBack = predicateString.contains("rename hire_date as x, x as hire_date"); + boolean dateTruncOnExpression = predicateString.contains("hire_date + 1 year"); + String fieldName = renameBack ? "hire_date" : "x"; + String query = LoggerMessageFormat.format(null, """ + from test + | sort hire_date + {} + | eval y = date_trunc(1 day, {}) + | keep emp_no, {}, y + | limit 5 + """, predicateString, fieldName, fieldName); + LogicalPlan localPlan = localPlan(query, searchStats); + Project project = as(localPlan, Project.class); + TopN topN = as(project.child(), TopN.class); + Eval eval = as(topN.child(), Eval.class); + List fields = eval.fields(); + assertEquals(dateTruncOnExpression ? 2 : 1, fields.size()); + Alias a = fields.get(dateTruncOnExpression ? 1 : 0); + assertEquals("y", a.name()); + verifySubstitution(a, roundToPointsSize); + LogicalPlan subPlan = hasWhere ? eval.child() : eval; + EsRelation relation = as(subPlan.children().get(0), EsRelation.class); + } + } + + public void testSubstituteBucketInAggWithRoundToWithEvalRename() { + for (Map.Entry predicate : evalRenamePredicatesWithDateTruncBucket.entrySet()) { + String predicateString = predicate.getKey(); + int roundToPointsSize = predicate.getValue(); + boolean hasWhere = predicateString.contains("where"); + boolean renameBack = predicateString.contains("rename hire_date as x, x as hire_date"); + boolean dateTruncOnExpression = predicateString.contains("hire_date + 1 year"); + String fieldName = renameBack ? "hire_date" : "x"; + String query = LoggerMessageFormat.format(null, """ + from test + {} + | stats count(*) by y = bucket({}, 1 day) + """, predicateString, fieldName); + LogicalPlan localPlan = localPlan(query, searchStats); + Limit limit = as(localPlan, Limit.class); + Aggregate aggregate = as(limit.child(), Aggregate.class); + Eval eval = as(aggregate.child(), Eval.class); + List fields = eval.fields(); + assertEquals(dateTruncOnExpression ? 2 : 1, fields.size()); + Alias a = fields.get(dateTruncOnExpression ? 1 : 0); + assertEquals("y", a.name()); + verifySubstitution(a, roundToPointsSize); + LogicalPlan subPlan = hasWhere ? eval.child() : eval; + EsRelation relation = as(subPlan.children().get(0), EsRelation.class); + } + } + + public void testSubstituteDateTruncInAggWithRoundToWithEvalRename() { + for (Map.Entry predicate : evalRenamePredicatesWithDateTruncBucket.entrySet()) { + String predicateString = predicate.getKey(); + int roundToPointsSize = predicate.getValue(); + boolean hasWhere = predicateString.contains("where"); + boolean renameBack = predicateString.contains("rename hire_date as x, x as hire_date"); + boolean dateTruncOnExpression = predicateString.contains("hire_date + 1 year"); + String fieldName = renameBack ? "hire_date" : "x"; + String query = LoggerMessageFormat.format(null, """ + from test + {} + | stats count(*) by y = date_trunc(1 day, {}) + """, predicateString, fieldName); + LogicalPlan localPlan = localPlan(query, searchStats); + Limit limit = as(localPlan, Limit.class); + Aggregate aggregate = as(limit.child(), Aggregate.class); + Eval eval = as(aggregate.child(), Eval.class); + List fields = eval.fields(); + assertEquals(dateTruncOnExpression ? 2 : 1, fields.size()); + Alias a = fields.get(dateTruncOnExpression ? 1 : 0); + assertEquals("y", a.name()); + verifySubstitution(a, roundToPointsSize); + LogicalPlan subPlan = hasWhere ? eval.child() : eval; + EsRelation relation = as(subPlan.children().get(0), EsRelation.class); + } + } + + private void verifySubstitution(Alias a, int roundToPointsSize) { + FieldAttribute fa = null; + Expression e = a.child(); + if (roundToPointsSize > 0) { + RoundTo roundTo = as(e, RoundTo.class); + fa = as(roundTo.field(), FieldAttribute.class); + assertEquals(roundToPointsSize, roundTo.points().size()); + } else if (roundToPointsSize == 0) { + if (e instanceof LocalSurrogateExpression lse) { + fa = as(lse.field(), FieldAttribute.class); + } else { + fail(e.getClass() + " is not supported"); + } + } else { + if (e instanceof LocalSurrogateExpression lse) { + assertTrue(lse.field() instanceof ReferenceAttribute); + } else { + fail(e.getClass() + " is not supported"); + } + } + if (roundToPointsSize >= 0) { + assertEquals("hire_date", fa.name()); + assertEquals(DATETIME, fa.dataType()); + } + } + + private LogicalPlan localPlan(String query, SearchStats searchStats) { + var plan = plan(query); + return localPlan(plan, searchStats); + } + + private static SearchStats searchStats() { // create a SearchStats with min and max millis Map minValue = Map.of("hire_date", 1697804103360L); // 2023-10-20T12:15:03.360Z Map maxValue = Map.of("hire_date", 1698069301543L); // 2023-10-23T13:55:01.543Z - SearchStats searchStats = new EsqlTestUtils.TestSearchStatsWithMinMax(minValue, maxValue); - - LogicalPlan localPlan = localPlan(plan, searchStats); - Limit limit = as(localPlan, Limit.class); - Aggregate aggregate = as(limit.child(), Aggregate.class); - Eval eval = as(aggregate.child(), Eval.class); - List fields = eval.fields(); - assertEquals(1, fields.size()); - Alias a = fields.get(0); - assertEquals("x", a.name()); - RoundTo roundTo = as(a.child(), RoundTo.class); - FieldAttribute fa = as(roundTo.field(), FieldAttribute.class); - assertEquals("hire_date", fa.name()); - assertEquals(DATETIME, fa.dataType()); - assertEquals(4, roundTo.points().size()); // 4 days - EsRelation relation = as(eval.child(), EsRelation.class); + return new EsqlTestUtils.TestSearchStatsWithMinMax(minValue, maxValue); } }