Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/134461.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 134461
summary: Propagates filter() to aggregation functions' surrogates
area: Aggregations
type: bug
issues:
- 134380
159 changes: 159 additions & 0 deletions x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
Original file line number Diff line number Diff line change
Expand Up @@ -3097,3 +3097,162 @@ ROW a = [1,2,3], b = 5
STD_DEV(a):double | STD_DEV(b):double
0.816496580927726 | 0.0
;

sumWithConditions
required_capability: stats_with_filtered_surrogate_fixed

FROM employees
| STATS sum1 = SUM(1),
sum2 = SUM(1) WHERE emp_no == 10080,
sum3 = SUM(1) WHERE emp_no < 10080,
sum4 = SUM(1) WHERE emp_no >= 10080
;

sum1:long | sum2:long | sum3:long | sum4:long
100 | 1 | 79 | 21
;

weightedAvgWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
| MV_EXPAND x
| STATS w_avg1 = WEIGHTED_AVG(x, 1) WHERE x == 5,
w_avg2 = WEIGHTED_AVG(x, x) WHERE x == 5,
w_avg3 = WEIGHTED_AVG(x, 2) WHERE x <= 5,
w_avg4 = WEIGHTED_AVG(x, x) WHERE x > 5,
w_avg5 = WEIGHTED_AVG([1,2,3], 1),
w_avg6 = WEIGHTED_AVG([1,2,3], 1) WHERE x == 5
;

w_avg1:double | w_avg2:double | w_avg3:double | w_avg4:double | w_avg5:double | w_avg6:double
5.0 | 5.0 | 3.0 | 8.25 | 2.0 | 2.0
;

maxWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS max1 = MAX(x),
max2 = MAX(x) WHERE x > 3
;

max1:integer | max2:integer
5 | 5
;

minWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS min1 = MIN(x),
min2 = MIN(x) WHERE x > 3
;

min1:integer | min2:integer
1 | 4
;

countWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS count1 = COUNT(x) WHERE x >= 3,
count2 = COUNT(x),
count3 = COUNT(4) WHERE x >= 3,
count4 = COUNT(*) WHERE x >= 3,
count5 = COUNT([1,2,3]) WHERE x >= 3,
count6 = COUNT([1,2,3])
;

count1:long | count2:long | count3:long | count4:long | count5:long | count6:long
3 | 5 | 3 | 3 | 9 | 15
;

countDistinctWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS count1 = COUNT_DISTINCT(x) WHERE x <= 3,
count2 = COUNT_DISTINCT(x),
count3 = COUNT_DISTINCT(1) WHERE x <= 3,
count4 = COUNT_DISTINCT(1)
;

count1:long | count2:long | count3:long | count4:long
3 | 5 | 1 | 1
;

avgWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS avg1 = AVG(x) WHERE x <= 3,
avg2 = AVG(x)
;

avg1:double | avg2:double
2.0 | 3.0
;

percentileWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS percentile1 = PERCENTILE(x, 50) WHERE x <= 3,
percentile2 = PERCENTILE(x, 50)
;

percentile1:double | percentile2:double
2.0 | 3.0
;

medianWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 2, 3, 4, 5]
| MV_EXPAND x
| STATS median1 = MEDIAN(x) WHERE x <= 3,
median2 = MEDIAN(x),
median3 = MEDIAN([5,6,7,8,9]) WHERE x <= 3,
median4 = MEDIAN([5,6,7,8,9])
;

median1:double | median2:double | median3:double | median4:double
2.0 | 3.0 | 7.0 | 7.0
;

medianAbsoluteDeviationWithConditions
required_capability: stats_with_filtered_surrogate_fixed

ROW x = [1, 3, 4, 7, 11, 18]
| MV_EXPAND x
| STATS median_dev1 = MEDIAN_ABSOLUTE_DEVIATION(x) WHERE x <= 3,
median_dev2 = MEDIAN_ABSOLUTE_DEVIATION(x),
median_dev3 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25]) WHERE x <= 3,
median_dev4 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25])
;

median_dev1:double | median_dev2:double | median_dev3:double | median_dev4:double
1.0 | 3.5 | 5.5 | 5.5
;

topWithConditions
required_capability: stats_with_filtered_surrogate_fixed

FROM employees
| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
;

min1:integer | min2:integer | max1:integer | max2:integer
10011 | [10011, 10012] | 10079 | [10079, 10078]
;
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,15 @@ public enum Cap {
* Support for the mv_expand target attribute should be retained in its original position.
* see <a href="https://github.com/elastic/elasticsearch/issues/129000"> ES|QL: inconsistent column order #129000 </a>
*/
FIX_MV_EXPAND_INCONSISTENT_COLUMN_ORDER;
FIX_MV_EXPAND_INCONSISTENT_COLUMN_ORDER,

/**
* Bugfix for STATS {{expression}} WHERE {{condition}} when the
* expression is replaced by a surrogate expression during planning,
* e.g. STATS SUM(1) WHERE x == 3 is replaced by
* STATS MV_SUM(const) * COUNT(*) WHERE x == 3.
*/
STATS_WITH_FILTERED_SURROGATE_FIXED;

private final boolean enabled;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,11 @@ public Expression surrogate() {
var s = source();
var field = field();
if (field.dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT));
return new Sum(
s,
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT),
filter()
);
}

if (field.foldable()) {
Expand All @@ -163,7 +167,7 @@ public Expression surrogate() {
return new Mul(
s,
new Coalesce(s, new MvCount(s, field), List.of(new Literal(s, 0, DataType.INTEGER))),
new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD))
new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD), filter())
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ public final AggregatorFunctionSupplier supplier() {
/**
 * Returns the expression that should be used in place of this aggregation,
 * or {@code null} when no substitution applies.
 * <ul>
 *   <li>For an {@code AGGREGATE_METRIC_DOUBLE} field, a {@code Max} over the field's
 *       {@code MAX} metric, carrying this aggregation's {@link #filter()}.</li>
 *   <li>For a foldable field, an {@code MvMax} over that field.</li>
 *   <li>Otherwise {@code null}.</li>
 * </ul>
 */
@Override
public Expression surrogate() {
    if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
        // The replacement aggregation must keep the original filter(); otherwise a
        // "STATS ... WHERE ..." condition would be dropped once the surrogate is substituted.
        return new Max(
            source(),
            FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX),
            filter()
        );
    }
    return field().foldable() ? new MvMax(source(), field()) : null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,6 @@ public Expression surrogate() {

return field.foldable()
? new MvMedian(s, new ToDouble(s, field))
: new Percentile(source(), field(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
: new Percentile(source(), field(), filter(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ public final AggregatorFunctionSupplier supplier() {
/**
 * Returns the expression that should be used in place of this aggregation,
 * or {@code null} when no substitution applies.
 * <ul>
 *   <li>For an {@code AGGREGATE_METRIC_DOUBLE} field, a {@code Min} over the field's
 *       {@code MIN} metric, carrying this aggregation's {@link #filter()}.</li>
 *   <li>For a foldable field, an {@code MvMin} over that field.</li>
 *   <li>Otherwise {@code null}.</li>
 * </ul>
 */
@Override
public Expression surrogate() {
    if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
        // The replacement aggregation must keep the original filter(); otherwise a
        // "STATS ... WHERE ..." condition would be dropped once the surrogate is substituted.
        return new Min(
            source(),
            FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN),
            filter()
        );
    }
    return field().foldable() ? new MvMin(source(), field()) : null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,16 @@ public Expression surrogate() {
var s = source();
var field = field();
if (field.dataType() == AGGREGATE_METRIC_DOUBLE) {
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM));
return new Sum(
s,
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM),
filter()
);
}

// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
return field.foldable()
? new Mul(s, new MvSum(s, field), new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD)))
? new Mul(s, new MvSum(s, field), new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD), filter()))
: null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ public Expression surrogate() {

if (limitValue() == 1) {
if (orderValue()) {
return new Min(s, field());
return new Min(s, field(), filter());
} else {
return new Max(s, field());
return new Max(s, field(), filter());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ public Expression surrogate() {
return new MvAvg(s, field);
}
if (weight.foldable()) {
return new Div(s, new Sum(s, field), new Count(s, field), dataType());
return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
} else {
return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,29 @@ public void testFold() {
}, this::evaluate);
}

/**
 * Checks that a surrogate expression keeps the filter of the aggregation it replaces.
 * <p>
 * One of the literal, deep-copied-field or field expressions built from the test case
 * is picked at random. The test is skipped when that expression does not type-resolve.
 * When the expression is both an {@code AggregateFunction} and a {@code SurrogateExpression}
 * and yields a non-null surrogate, every {@code AggregateFunction} visited by
 * {@code forEachDown} on the surrogate must report a filter equal to the original's.
 */
public void testSurrogateHasFilter() {
    Expression candidate = randomFrom(
        buildLiteralExpression(testCase),
        buildDeepCopyOfFieldExpression(testCase),
        buildFieldExpression(testCase)
    );
    assumeTrue("expression should have no type errors", candidate.typeResolved().resolved());

    // Only aggregations that can be swapped for a surrogate are of interest here.
    boolean surrogatedAggregate = candidate instanceof AggregateFunction && candidate instanceof SurrogateExpression;
    if (surrogatedAggregate == false) {
        return;
    }

    var expectedFilter = ((AggregateFunction) candidate).filter();
    var replacement = ((SurrogateExpression) candidate).surrogate();
    if (replacement == null) {
        // No substitution for this input, so there is nothing to compare.
        return;
    }

    replacement.forEachDown(AggregateFunction.class, aggregate -> assertEquals(expectedFilter, aggregate.filter()));
}

private void aggregateSingleMode(Expression expression) {
Object result;
try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) {
Expand Down