Skip to content

Commit a41045d

Browse files
ncordonjan-elastic
andauthored
[9.1] Propagates filter() to aggregation functions' surrogates (#134461)
--------- Co-authored-by: Jan Kuipers <[email protected]> Co-authored-by: Jan Kuipers <[email protected]>
1 parent fdf7b35 commit a41045d

File tree

11 files changed

+244
-12
lines changed

11 files changed

+244
-12
lines changed

docs/changelog/134461.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 134461
2+
summary: Propagates filter() to aggregation functions' surrogates
3+
area: Aggregations
4+
type: bug
5+
issues:
6+
- 134380

x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3122,3 +3122,182 @@ FROM employees
31223122
m:datetime | x:integer | d:boolean
31233123
1999-04-30T00:00:00.000Z | 2 | true
31243124
;
3125+
3126+
sumWithConditions
3127+
required_capability: stats_with_filtered_surrogate_fixed
3128+
required_capability: aggregate_metric_double_convert_to
3129+
3130+
FROM employees
3131+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(1)
3132+
| STATS sum1 = SUM(1),
3133+
sum2 = SUM(1) WHERE emp_no == 10080,
3134+
sum3 = SUM(1) WHERE emp_no < 10080,
3135+
sum4 = SUM(1) WHERE emp_no >= 10080,
3136+
sum5 = SUM(agg_metric),
3137+
sum6 = SUM(agg_metric) WHERE emp_no == 10080
3138+
;
3139+
3140+
sum1:long | sum2:long | sum3:long | sum4:long | sum5:double | sum6:double
3141+
100 | 1 | 79 | 21 | 100.0 | 1.0
3142+
;
3143+
3144+
weightedAvgWithConditions
3145+
required_capability: stats_with_filtered_surrogate_fixed
3146+
3147+
ROW x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
3148+
| MV_EXPAND x
3149+
| STATS w_avg1 = WEIGHTED_AVG(x, 1) WHERE x == 5,
3150+
w_avg2 = WEIGHTED_AVG(x, x) WHERE x == 5,
3151+
w_avg3 = WEIGHTED_AVG(x, 2) WHERE x <= 5,
3152+
w_avg4 = WEIGHTED_AVG(x, x) WHERE x > 5,
3153+
w_avg5 = WEIGHTED_AVG([1,2,3], 1),
3154+
w_avg6 = WEIGHTED_AVG([1,2,3], 1) WHERE x == 5
3155+
;
3156+
3157+
w_avg1:double | w_avg2:double | w_avg3:double | w_avg4:double | w_avg5:double | w_avg6:double
3158+
5.0 | 5.0 | 3.0 | 8.25 | 2.0 | 2.0
3159+
;
3160+
3161+
maxWithConditions
3162+
required_capability: stats_with_filtered_surrogate_fixed
3163+
required_capability: aggregate_metric_double_convert_to
3164+
3165+
ROW x = [1, 2, 3, 4, 5]
3166+
| MV_EXPAND x
3167+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3168+
| STATS max1 = MAX(agg_metric) WHERE x <= 3,
3169+
max2 = MAX(agg_metric),
3170+
max3 = MAX(x),
3171+
max4 = MAX(x) WHERE x > 3
3172+
;
3173+
3174+
max1:double | max2:double | max3:integer | max4:integer
3175+
3.0 | 5.0 | 5 | 5
3176+
;
3177+
3178+
minWithConditions
3179+
required_capability: stats_with_filtered_surrogate_fixed
3180+
required_capability: aggregate_metric_double_convert_to
3181+
3182+
ROW x = [1, 2, 3, 4, 5]
3183+
| MV_EXPAND x
3184+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3185+
| STATS min1 = MIN(agg_metric) WHERE x <= 3,
3186+
min2 = MIN(agg_metric),
3187+
min3 = MIN(x),
3188+
min4 = MIN(x) WHERE x > 3
3189+
;
3190+
3191+
min1:double | min2:double | min3:integer | min4:integer
3192+
1.0 | 1.0 | 1 | 4
3193+
;
3194+
3195+
countWithConditions
3196+
required_capability: stats_with_filtered_surrogate_fixed
3197+
required_capability: aggregate_metric_double_convert_to
3198+
3199+
ROW x = [1, 2, 3, 4, 5]
3200+
| MV_EXPAND x
3201+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3202+
| STATS count1 = COUNT(x) WHERE x >= 3,
3203+
count2 = COUNT(x),
3204+
count3 = COUNT(agg_metric),
3205+
count4 = COUNT(agg_metric) WHERE x >=3,
3206+
count5 = COUNT(4) WHERE x >= 3,
3207+
count6 = COUNT(*) WHERE x >= 3,
3208+
count7 = COUNT([1,2,3]) WHERE x >= 3,
3209+
count8 = COUNT([1,2,3])
3210+
;
3211+
3212+
count1:long | count2:long | count3:long | count4:long | count5:long | count6:long | count7:long | count8:long
3213+
3 | 5 | 5 | 3 | 3 | 3 | 9 | 15
3214+
;
3215+
3216+
countDistinctWithConditions
3217+
required_capability: stats_with_filtered_surrogate_fixed
3218+
3219+
ROW x = [1, 2, 3, 4, 5]
3220+
| MV_EXPAND x
3221+
| STATS count1 = COUNT_DISTINCT(x) WHERE x <= 3,
3222+
count2 = COUNT_DISTINCT(x),
3223+
count3 = COUNT_DISTINCT(1) WHERE x <= 3,
3224+
count4 = COUNT_DISTINCT(1)
3225+
;
3226+
3227+
count1:long | count2:long | count3:long | count4:long
3228+
3 | 5 | 1 | 1
3229+
;
3230+
3231+
avgWithConditions
3232+
required_capability: stats_with_filtered_surrogate_fixed
3233+
required_capability: aggregate_metric_double_convert_to
3234+
3235+
ROW x = [1, 2, 3, 4, 5]
3236+
| MV_EXPAND x
3237+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3238+
| STATS avg1 = AVG(x) WHERE x <= 3,
3239+
avg2 = AVG(x),
3240+
avg3 = AVG(agg_metric) WHERE x <=3,
3241+
avg4 = AVG(agg_metric)
3242+
;
3243+
3244+
avg1:double | avg2:double | avg3:double | avg4:double
3245+
2.0 | 3.0 | 2.0 | 3.0
3246+
;
3247+
3248+
percentileWithConditions
3249+
required_capability: stats_with_filtered_surrogate_fixed
3250+
3251+
ROW x = [1, 2, 3, 4, 5]
3252+
| MV_EXPAND x
3253+
| STATS percentile1 = PERCENTILE(x, 50) WHERE x <= 3,
3254+
percentile2 = PERCENTILE(x, 50)
3255+
;
3256+
3257+
percentile1:double | percentile2:double
3258+
2.0 | 3.0
3259+
;
3260+
3261+
medianWithConditions
3262+
required_capability: stats_with_filtered_surrogate_fixed
3263+
3264+
ROW x = [1, 2, 3, 4, 5]
3265+
| MV_EXPAND x
3266+
| STATS median1 = MEDIAN(x) WHERE x <= 3,
3267+
median2 = MEDIAN(x),
3268+
median3 = MEDIAN([5,6,7,8,9]) WHERE x <= 3,
3269+
median4 = MEDIAN([5,6,7,8,9])
3270+
;
3271+
3272+
median1:double | median2:double | median3:double | median4:double
3273+
2.0 | 3.0 | 7.0 | 7.0
3274+
;
3275+
3276+
medianAbsoluteDeviationWithConditions
3277+
required_capability: stats_with_filtered_surrogate_fixed
3278+
3279+
ROW x = [1, 3, 4, 7, 11, 18]
3280+
| MV_EXPAND x
3281+
| STATS median_dev1 = MEDIAN_ABSOLUTE_DEVIATION(x) WHERE x <= 3,
3282+
median_dev2 = MEDIAN_ABSOLUTE_DEVIATION(x),
3283+
median_dev3 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25]) WHERE x <= 3,
3284+
median_dev4 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25])
3285+
;
3286+
3287+
median_dev1:double | median_dev2:double | median_dev3:double | median_dev4:double
3288+
1.0 | 3.5 | 5.5 | 5.5
3289+
;
3290+
3291+
topWithConditions
3292+
required_capability: stats_with_filtered_surrogate_fixed
3293+
3294+
FROM employees
3295+
| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
3296+
min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
3297+
max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
3298+
max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
3299+
;
3300+
3301+
min1:integer | min2:integer | max1:integer | max2:integer
3302+
10011 | [10011, 10012] | 10079 | [10079, 10078]
3303+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1267,7 +1267,15 @@ public enum Cap {
12671267
/**
12681268
* Support correct counting of skipped shards.
12691269
*/
1270-
CORRECT_SKIPPED_SHARDS_COUNT;
1270+
CORRECT_SKIPPED_SHARDS_COUNT,
1271+
1272+
/**
1273+
* Bugfix for STATS {{expression}} WHERE {{condition}} when the
1274+
* expression is replaced by something else on planning
1275+
* e.g. STATS SUM(1) WHERE x==3 is replaced by
1276+
* STATS MV_SUM(const)*COUNT(*) WHERE x == 3.
1277+
*/
1278+
STATS_WITH_FILTERED_SURROGATE_FIXED;
12711279

12721280
private final boolean enabled;
12731281

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ public Expression surrogate() {
146146
var s = source();
147147
var field = field();
148148
if (field.dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
149-
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT));
149+
return new Sum(
150+
s,
151+
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT),
152+
filter()
153+
);
150154
}
151155

152156
if (field.foldable()) {
@@ -163,7 +167,7 @@ public Expression surrogate() {
163167
return new Mul(
164168
s,
165169
new Coalesce(s, new MvCount(s, field), List.of(new Literal(s, 0, DataType.INTEGER))),
166-
new Count(s, Literal.keyword(s, StringUtils.WILDCARD))
170+
new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())
167171
);
168172
}
169173

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,11 @@ public final AggregatorFunctionSupplier supplier() {
153153
@Override
154154
public Expression surrogate() {
155155
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
156-
return new Max(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX));
156+
return new Max(
157+
source(),
158+
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX),
159+
filter()
160+
);
157161
}
158162
return field().foldable() ? new MvMax(source(), field()) : null;
159163
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,6 @@ public Expression surrogate() {
117117

118118
return field.foldable()
119119
? new MvMedian(s, new ToDouble(s, field))
120-
: new Percentile(source(), field(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
120+
: new Percentile(source(), field(), filter(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
121121
}
122122
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,11 @@ public final AggregatorFunctionSupplier supplier() {
153153
@Override
154154
public Expression surrogate() {
155155
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
156-
return new Min(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN));
156+
return new Min(
157+
source(),
158+
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN),
159+
filter()
160+
);
157161
}
158162
return field().foldable() ? new MvMin(source(), field()) : null;
159163
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,14 @@ public Expression surrogate() {
139139
var s = source();
140140
var field = field();
141141
if (field.dataType() == AGGREGATE_METRIC_DOUBLE) {
142-
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM));
142+
return new Sum(
143+
s,
144+
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM),
145+
filter()
146+
);
143147
}
144148

145149
// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
146-
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD))) : null;
150+
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())) : null;
147151
}
148152
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,9 @@ public Expression surrogate() {
218218

219219
if (limitValue() == 1) {
220220
if (orderValue()) {
221-
return new Min(s, field());
221+
return new Min(s, field(), filter());
222222
} else {
223-
return new Max(s, field());
223+
return new Max(s, field(), filter());
224224
}
225225
}
226226

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ public Expression surrogate() {
160160
return new MvAvg(s, field);
161161
}
162162
if (weight.foldable()) {
163-
return new Div(s, new Sum(s, field), new Count(s, field), dataType());
163+
return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
164164
} else {
165-
return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
165+
return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
166166
}
167167
}
168168

0 commit comments

Comments
 (0)