Skip to content

Commit 88dbdf8

Browse files
Support Bucket in SubstituteSurrogateExpressionsWithSearchStats
1 parent f78fa18 commit 88dbdf8

File tree

4 files changed

+95
-14
lines changed

4 files changed

+95
-14
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,36 @@
1818
import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware;
1919
import org.elasticsearch.xpack.esql.common.Failures;
2020
import org.elasticsearch.xpack.esql.core.expression.Expression;
21+
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
2122
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
2223
import org.elasticsearch.xpack.esql.core.expression.Foldables;
2324
import org.elasticsearch.xpack.esql.core.expression.Literal;
2425
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
2526
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
2627
import org.elasticsearch.xpack.esql.core.tree.Source;
2728
import org.elasticsearch.xpack.esql.core.type.DataType;
29+
import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField;
30+
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
2831
import org.elasticsearch.xpack.esql.expression.function.Example;
2932
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
3033
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
3134
import org.elasticsearch.xpack.esql.expression.function.Param;
3235
import org.elasticsearch.xpack.esql.expression.function.TwoOptionalArguments;
3336
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc;
3437
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Floor;
38+
import org.elasticsearch.xpack.esql.expression.function.scalar.math.RoundTo;
3539
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
3640
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
3741
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
42+
import org.elasticsearch.xpack.esql.stats.SearchStats;
3843

3944
import java.io.IOException;
4045
import java.time.ZoneId;
4146
import java.time.ZoneOffset;
4247
import java.util.ArrayList;
48+
import java.util.Arrays;
4349
import java.util.List;
50+
import java.util.stream.Collectors;
4451

4552
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
4653
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
@@ -49,6 +56,7 @@
4956
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
5057
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNumeric;
5158
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
59+
import static org.elasticsearch.xpack.esql.core.type.DataType.isDateTime;
5260
import static org.elasticsearch.xpack.esql.expression.Validations.isFoldable;
5361
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateTimeToLong;
5462

@@ -61,7 +69,8 @@
6169
public class Bucket extends GroupingFunction.EvaluatableGroupingFunction
6270
implements
6371
PostOptimizationVerificationAware,
64-
TwoOptionalArguments {
72+
TwoOptionalArguments,
73+
SurrogateExpression {
6574
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Bucket", Bucket::new);
6675

6776
// TODO maybe we should just cover the whole of representable dates here - like ten years, 100 years, 1000 years, all the way up.
@@ -301,15 +310,22 @@ public Rounding.Prepared getDateRoundingOrNull(FoldContext foldCtx) {
301310
}
302311

303312
private Rounding.Prepared getDateRounding(FoldContext foldContext) {
313+
return getDateRounding(foldContext, null, null);
314+
}
315+
316+
private Rounding.Prepared getDateRounding(FoldContext foldContext, Long min, Long max) {
304317
assert field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS : "expected date type; got " + field;
305318
if (buckets.dataType().isWholeNumber()) {
306319
int b = ((Number) buckets.fold(foldContext)).intValue();
307320
long f = foldToLong(foldContext, from);
308321
long t = foldToLong(foldContext, to);
322+
if (min != null && max != null) {
323+
return new DateRoundingPicker(b, f, t).pickRounding().prepare(min, max);
324+
}
309325
return new DateRoundingPicker(b, f, t).pickRounding().prepareForUnknown();
310326
} else {
311327
assert DataType.isTemporalAmount(buckets.dataType()) : "Unexpected span data type [" + buckets.dataType() + "]";
312-
return DateTrunc.createRounding(buckets.fold(foldContext), DEFAULT_TZ);
328+
return DateTrunc.createRounding(buckets.fold(foldContext), DEFAULT_TZ, min, max);
313329
}
314330
}
315331

@@ -488,4 +504,40 @@ public Expression to() {
488504
public String toString() {
489505
return "Bucket{" + "field=" + field + ", buckets=" + buckets + ", from=" + from + ", to=" + to + '}';
490506
}
507+
508+
@Override
509+
public Expression surrogate() {
510+
return null;
511+
}
512+
513+
@Override
514+
public Expression surrogate(SearchStats searchStats) {
515+
if (field() instanceof FieldAttribute fa && fa.field() instanceof MultiTypeEsField == false && isDateTime(fa.dataType())) {
516+
// Extract min/max from SearchStats
517+
DataType fieldType = fa.dataType();
518+
String fieldName = fa.fieldName();
519+
var min = searchStats.min(fieldName);
520+
var max = searchStats.max(fieldName);
521+
// If min/max is available create rounding with them
522+
if (min != null && max != null && buckets().foldable()) {
523+
// System.out.println("field: " + fieldName + ", min string: " + dateWithTypeToString((Long) min, fieldType));
524+
// System.out.println("field: " + fieldName + ", max string: " + dateWithTypeToString((Long) max, fieldType));
525+
Rounding.Prepared rounding = getDateRounding(FoldContext.small(), (Long) min, (Long) max);
526+
// createRounding(foldedInterval, DEFAULT_TZ, (Long) min, (Long) max);
527+
long[] roundingPoints = rounding.fixedRoundingPoints();
528+
// TODO do we support date_nanos? It seems like prepare(long minUtcMillis, long maxUtcMillis) takes millis only
529+
// the min/max long values for date and date_nanos are correct, however the roundingPoints for date_nanos is null
530+
// System.out.println("roundingPoints = " + Arrays.toString(roundingPoints));
531+
if (roundingPoints == null) {
532+
return null; // TODO log this case
533+
}
534+
// Convert to round_to function with the roundings
535+
List<Expression> points = Arrays.stream(roundingPoints)
536+
.mapToObj(l -> new Literal(Source.EMPTY, l, fieldType))
537+
.collect(Collectors.toList());
538+
return new RoundTo(source(), field(), points);
539+
}
540+
}
541+
return null;
542+
}
491543
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
5050
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
5151
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
52+
import static org.elasticsearch.xpack.esql.core.type.DataType.isDateTime;
5253

5354
public class DateTrunc extends EsqlScalarFunction implements SurrogateExpression {
5455
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
@@ -277,21 +278,21 @@ public Expression surrogate() { // there is no substitute without SearchStats
277278

278279
@Override
279280
public Expression surrogate(SearchStats searchStats) {
280-
if (field() instanceof FieldAttribute fa && fa.field() instanceof MultiTypeEsField == false) {
281+
if (field() instanceof FieldAttribute fa && fa.field() instanceof MultiTypeEsField == false && isDateTime(fa.dataType())) {
281282
// Extract min/max from SearchStats
282283
DataType fieldType = fa.dataType();
283284
String fieldName = fa.fieldName();
284285
var min = searchStats.min(fieldName);
285286
var max = searchStats.max(fieldName);
286287
// If min/max is available create rounding with them
287288
if (min != null && max != null && interval().foldable()) {
289+
// System.out.println("field: "+ fieldName + ", min string: " + dateWithTypeToString((Long) min, fieldType));
290+
// System.out.println("field: "+ fieldName + ", max string: " + dateWithTypeToString((Long) max, fieldType));
288291
Object foldedInterval = interval().fold(FoldContext.small() /* TODO remove me */);
289292
Rounding.Prepared rounding = createRounding(foldedInterval, DEFAULT_TZ, (Long) min, (Long) max);
290293
long[] roundingPoints = rounding.fixedRoundingPoints();
291294
// TODO do we support date_nanos? It seems like prepare(long minUtcMillis, long maxUtcMillis) takes millis only
292295
// the min/max long values for date and date_nanos are correct, however the roundingPoints for date_nanos is null
293-
// System.out.println("min string: " + dateWithTypeToString((Long) min, fieldType));
294-
// System.out.println("max string: " + dateWithTypeToString((Long) max, fieldType));
295296
// System.out.println("field name = " + fieldName + ", min = " + min + ", max = " + max + ", roundingPoints = " +
296297
// Arrays.toString(roundingPoints));
297298
if (roundingPoints == null) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/SubstituteSurrogateExpressionsWithSearchStats.java

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,31 @@
88
package org.elasticsearch.xpack.esql.optimizer.rules.logical.local;
99

1010
import org.elasticsearch.xpack.esql.core.expression.Expression;
11+
import org.elasticsearch.xpack.esql.core.expression.function.Function;
1112
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
12-
import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc;
1313
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
14-
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
14+
import org.elasticsearch.xpack.esql.plan.logical.Eval;
1515
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
16+
import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
1617
import org.elasticsearch.xpack.esql.stats.SearchStats;
1718

18-
public class SubstituteSurrogateExpressionsWithSearchStats extends OptimizerRules.ParameterizedOptimizerRule<
19+
public class SubstituteSurrogateExpressionsWithSearchStats extends ParameterizedRule<
20+
LogicalPlan,
1921
LogicalPlan,
2022
LocalLogicalOptimizerContext> {
21-
public SubstituteSurrogateExpressionsWithSearchStats() {
22-
super(OptimizerRules.TransformDirection.UP);
23-
}
2423

2524
@Override
26-
protected LogicalPlan rule(LogicalPlan plan, LocalLogicalOptimizerContext context) {
27-
return plan.transformExpressionsUp(DateTrunc.class, e -> rule(e, context.searchStats()));
25+
public LogicalPlan apply(LogicalPlan plan, LocalLogicalOptimizerContext context) {
26+
return plan.transformUp(
27+
Eval.class,
28+
eval -> eval.transformExpressionsOnly(Function.class, f -> substituteDateTruncBucketWithRoundTo(f, context.searchStats()))
29+
);
2830
}
2931

3032
/**
3133
* Perform the actual substitution.
3234
*/
33-
public static Expression rule(Expression e, SearchStats searchStats) {
35+
private static Expression substituteDateTruncBucketWithRoundTo(Expression e, SearchStats searchStats) {
3436
if (e instanceof SurrogateExpression s && searchStats != null) {
3537
Expression surrogate = s.surrogate(searchStats);
3638
if (surrogate != null) {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,32 @@ public void testSubstituteDateTruncInAggWithRoundTo() {
696696
EsRelation relation = as(eval.child(), EsRelation.class);
697697
}
698698

699+
public void testSubstituteBucketInAggWithRoundTo() {
700+
var plan = plan("""
701+
from test
702+
| stats count(*) by x = bucket(hire_date, 1 day)
703+
""");
704+
// create a SearchStats with min and max millis
705+
Map<String, Object> minValue = Map.of("hire_date", 1697804103360L); // 2023-10-20T12:15:03.360Z
706+
Map<String, Object> maxValue = Map.of("hire_date", 1698069301543L); // 2023-10-23T13:55:01.543Z
707+
SearchStats searchStats = new EsqlTestUtils.TestSearchStatsWithMinMax(minValue, maxValue);
708+
709+
LogicalPlan localPlan = localPlan(plan, searchStats);
710+
Limit limit = as(localPlan, Limit.class);
711+
Aggregate aggregate = as(limit.child(), Aggregate.class);
712+
Eval eval = as(aggregate.child(), Eval.class);
713+
List<Alias> fields = eval.fields();
714+
assertEquals(1, fields.size());
715+
Alias a = fields.get(0);
716+
assertEquals("x", a.name());
717+
RoundTo roundTo = as(a.child(), RoundTo.class);
718+
FieldAttribute fa = as(roundTo.field(), FieldAttribute.class);
719+
assertEquals("hire_date", fa.name());
720+
assertEquals(DATETIME, fa.dataType());
721+
assertEquals(4, roundTo.points().size()); // 4 days
722+
EsRelation relation = as(eval.child(), EsRelation.class);
723+
}
724+
699725
private IsNotNull isNotNull(Expression field) {
700726
return new IsNotNull(EMPTY, field);
701727
}

0 commit comments

Comments
 (0)