Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/134461.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 134461
summary: Propagates filter() to aggregation functions' surrogates
area: Aggregations
type: bug
issues:
- 134380
179 changes: 179 additions & 0 deletions x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
Original file line number Diff line number Diff line change
Expand Up @@ -3121,3 +3121,182 @@ FROM employees
m:datetime | x:integer | d:boolean
1999-04-30T00:00:00.000Z | 2 | true
;

sumWithConditions
required_capability: stats_with_filtered_surrogate_fixed
required_capability: aggregate_metric_double_convert_to

FROM employees
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(1)
| STATS sum1 = SUM(1),
sum2 = SUM(1) WHERE emp_no == 10080,
sum3 = SUM(1) WHERE emp_no < 10080,
sum4 = SUM(1) WHERE emp_no >= 10080,
sum5 = SUM(agg_metric),
sum6 = SUM(agg_metric) WHERE emp_no == 10080
;

sum1:long | sum2:long | sum3:long | sum4:long | sum5:double | sum6:double
100 | 1 | 79 | 21 | 100.0 | 1.0
;

weightedAvgWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
| MV_EXPAND x
| STATS w_avg1 = WEIGHTED_AVG(x, 1) WHERE x == 5,
w_avg2 = WEIGHTED_AVG(x, x) WHERE x == 5,
w_avg3 = WEIGHTED_AVG(x, 2) WHERE x <= 5,
w_avg4 = WEIGHTED_AVG(x, x) WHERE x > 5,
w_avg5 = WEIGHTED_AVG([1,2,3], 1),
w_avg6 = WEIGHTED_AVG([1,2,3], 1) WHERE x == 5
;

w_avg1:double | w_avg2:double | w_avg3:double | w_avg4:double | w_avg5:double | w_avg6:double
5.0 | 5.0 | 3.0 | 8.25 | 2.0 | 2.0
;

maxWithConditions
required_capability: stats_with_filtered_surrogate_fixed
required_capability: aggregate_metric_double_convert_to

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
| STATS max1 = MAX(agg_metric) WHERE x <= 3,
max2 = MAX(agg_metric),
max3 = MAX(x),
max4 = MAX(x) WHERE x > 3
;

max1:double | max2:double | max3:integer | max4:integer
3.0 | 5.0 | 5 | 5
;

minWithConditions
required_capability: stats_with_filtered_surrogate_fixed
required_capability: aggregate_metric_double_convert_to

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
| STATS min1 = MIN(agg_metric) WHERE x <= 3,
min2 = MIN(agg_metric),
min3 = MIN(x),
min4 = MIN(x) WHERE x > 3
;

min1:double | min2:double | min3:integer | min4:integer
1.0 | 1.0 | 1 | 4
;

countWithConditions
required_capability: stats_with_filtered_surrogate_fixed
required_capability: aggregate_metric_double_convert_to

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
| STATS count1 = COUNT(x) WHERE x >= 3,
count2 = COUNT(x),
count3 = COUNT(agg_metric),
count4 = COUNT(agg_metric) WHERE x >=3,
count5 = COUNT(4) WHERE x >= 3,
count6 = COUNT(*) WHERE x >= 3,
count7 = COUNT([1,2,3]) WHERE x >= 3,
count8 = COUNT([1,2,3])
;

count1:long | count2:long | count3:long | count4:long | count5:long | count6:long | count7:long | count8:long
3 | 5 | 5 | 3 | 3 | 3 | 9 | 15
;

countDistinctWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS count1 = COUNT_DISTINCT(x) WHERE x <= 3,
count2 = COUNT_DISTINCT(x),
count3 = COUNT_DISTINCT(1) WHERE x <= 3,
count4 = COUNT_DISTINCT(1)
;

count1:long | count2:long | count3:long | count4:long
3 | 5 | 1 | 1
;

avgWithConditions
required_capability: stats_with_filtered_surrogate_fixed
required_capability: aggregate_metric_double_convert_to

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
| STATS avg1 = AVG(x) WHERE x <= 3,
avg2 = AVG(x),
avg3 = AVG(agg_metric) WHERE x <=3,
avg4 = AVG(agg_metric)
;

avg1:double | avg2:double | avg3:double | avg4:double
2.0 | 3.0 | 2.0 | 3.0
;

percentileWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS percentile1 = PERCENTILE(x, 50) WHERE x <= 3,
percentile2 = PERCENTILE(x, 50)
;

percentile1:double | percentile2:double
2.0 | 3.0
;

medianWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS median1 = MEDIAN(x) WHERE x <= 3,
median2 = MEDIAN(x),
median3 = MEDIAN([5,6,7,8,9]) WHERE x <= 3,
median4 = MEDIAN([5,6,7,8,9])
;

median1:double | median2:double | median3:double | median4:double
2.0 | 3.0 | 7.0 | 7.0
;

medianAbsoluteDeviationWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 3, 4, 7, 11, 18]
| MV_EXPAND x
| STATS median_dev1 = MEDIAN_ABSOLUTE_DEVIATION(x) WHERE x <= 3,
median_dev2 = MEDIAN_ABSOLUTE_DEVIATION(x),
median_dev3 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25]) WHERE x <= 3,
median_dev4 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25])
;

median_dev1:double | median_dev2:double | median_dev3:double | median_dev4:double
1.0 | 3.5 | 5.5 | 5.5
;

topWithConditions
required_capability: stats_with_filtered_surrogate_fixed

FROM employees
| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
;

min1:integer | min2:integer | max1:integer | max2:integer
10011 | [10011, 10012] | 10079 | [10079, 10078]
;
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,14 @@ public enum Cap {
/**
* Support correct counting of skipped shards.
*/
CORRECT_SKIPPED_SHARDS_COUNT;
CORRECT_SKIPPED_SHARDS_COUNT,
/**
* Bugfix for STATS {{expression}} WHERE {{condition}} when the
* expression is replaced by something else on planning
* e.g. STATS SUM(1) WHERE x==3 is replaced by
* STATS MV_SUM(const)*COUNT(*) WHERE x == 3.
*/
STATS_WITH_FILTERED_SURROGATE_FIXED;

private final boolean enabled;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,11 @@ public Expression surrogate() {
var s = source();
var field = field();
if (field.dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT));
return new Sum(
s,
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT),
filter()
);
}

if (field.foldable()) {
Expand All @@ -162,7 +166,7 @@ public Expression surrogate() {
return new Mul(
s,
new Coalesce(s, new MvCount(s, field), List.of(new Literal(s, 0, DataType.INTEGER))),
new Count(s, Literal.keyword(s, StringUtils.WILDCARD))
new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ public final AggregatorFunctionSupplier supplier() {
@Override
public Expression surrogate() {
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
return new Max(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX));
return new Max(
source(),
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX),
filter()
);
}
return field().foldable() ? new MvMax(source(), field()) : null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,6 @@ public Expression surrogate() {

return field.foldable()
? new MvMedian(s, new ToDouble(s, field))
: new Percentile(source(), field(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
: new Percentile(source(), field(), filter(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ public final AggregatorFunctionSupplier supplier() {
@Override
public Expression surrogate() {
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
return new Min(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN));
return new Min(
source(),
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN),
filter()
);
}
return field().foldable() ? new MvMin(source(), field()) : null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,14 @@ public Expression surrogate() {
var s = source();
var field = field();
if (field.dataType() == AGGREGATE_METRIC_DOUBLE) {
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM));
return new Sum(
s,
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM),
filter()
);
}

// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD))) : null;
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())) : null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ public Expression surrogate() {

if (limitValue() == 1) {
if (orderValue()) {
return new Min(s, field());
return new Min(s, field(), filter());
} else {
return new Max(s, field());
return new Max(s, field(), filter());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ public Expression surrogate() {
return new MvAvg(s, field);
}
if (weight.foldable()) {
return new Div(s, new Sum(s, field), new Count(s, field), dataType());
return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
} else {
return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,29 @@ public void testFold() {
}, this::evaluate);
}

public void testSurrogateHasFilter() {
Expression expression = randomFrom(
buildLiteralExpression(testCase),
buildDeepCopyOfFieldExpression(testCase),
buildFieldExpression(testCase)
);

assumeTrue("expression should have no type errors", expression.typeResolved().resolved());

if (expression instanceof AggregateFunction && expression instanceof SurrogateExpression) {
var filter = ((AggregateFunction) expression).filter();

var surrogate = ((SurrogateExpression) expression).surrogate();

if (surrogate != null) {
surrogate.forEachDown(AggregateFunction.class, child -> {
var surrogateFilter = child.filter();
assertEquals(filter, surrogateFilter);
});
}
}
}

private void aggregateSingleMode(Expression expression) {
Object result;
try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) {
Expand Down