Skip to content

Commit 7406110

Browse files
ncordon and jan-elastic authored
[8.19] Propagates filter() to aggregation functions' surrogates (elastic#134461)
--------- Co-authored-by: Jan Kuipers <[email protected]> Co-authored-by: Jan Kuipers <[email protected]>
1 parent cb2f88d commit 7406110

File tree

11 files changed

+243
-12
lines changed

11 files changed

+243
-12
lines changed

docs/changelog/134461.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 134461
2+
summary: Propagates filter() to aggregation functions' surrogates
3+
area: Aggregations
4+
type: bug
5+
issues:
6+
- 134380

x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3121,3 +3121,182 @@ FROM employees
31213121
m:datetime | x:integer | d:boolean
31223122
1999-04-30T00:00:00.000Z | 2 | true
31233123
;
3124+
3125+
sumWithConditions
3126+
required_capability: stats_with_filtered_surrogate_fixed
3127+
required_capability: aggregate_metric_double_convert_to
3128+
3129+
FROM employees
3130+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(1)
3131+
| STATS sum1 = SUM(1),
3132+
sum2 = SUM(1) WHERE emp_no == 10080,
3133+
sum3 = SUM(1) WHERE emp_no < 10080,
3134+
sum4 = SUM(1) WHERE emp_no >= 10080,
3135+
sum5 = SUM(agg_metric),
3136+
sum6 = SUM(agg_metric) WHERE emp_no == 10080
3137+
;
3138+
3139+
sum1:long | sum2:long | sum3:long | sum4:long | sum5:double | sum6:double
3140+
100 | 1 | 79 | 21 | 100.0 | 1.0
3141+
;
3142+
3143+
weightedAvgWithConditions
3144+
required_capability: stats_with_filtered_surrogate_fixed
3145+
3146+
ROW x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
3147+
| MV_EXPAND x
3148+
| STATS w_avg1 = WEIGHTED_AVG(x, 1) WHERE x == 5,
3149+
w_avg2 = WEIGHTED_AVG(x, x) WHERE x == 5,
3150+
w_avg3 = WEIGHTED_AVG(x, 2) WHERE x <= 5,
3151+
w_avg4 = WEIGHTED_AVG(x, x) WHERE x > 5,
3152+
w_avg5 = WEIGHTED_AVG([1,2,3], 1),
3153+
w_avg6 = WEIGHTED_AVG([1,2,3], 1) WHERE x == 5
3154+
;
3155+
3156+
w_avg1:double | w_avg2:double | w_avg3:double | w_avg4:double | w_avg5:double | w_avg6:double
3157+
5.0 | 5.0 | 3.0 | 8.25 | 2.0 | 2.0
3158+
;
3159+
3160+
maxWithConditions
3161+
required_capability: stats_with_filtered_surrogate_fixed
3162+
required_capability: aggregate_metric_double_convert_to
3163+
3164+
ROW x = [1, 2, 3, 4, 5]
3165+
| MV_EXPAND x
3166+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3167+
| STATS max1 = MAX(agg_metric) WHERE x <= 3,
3168+
max2 = MAX(agg_metric),
3169+
max3 = MAX(x),
3170+
max4 = MAX(x) WHERE x > 3
3171+
;
3172+
3173+
max1:double | max2:double | max3:integer | max4:integer
3174+
3.0 | 5.0 | 5 | 5
3175+
;
3176+
3177+
minWithConditions
3178+
required_capability: stats_with_filtered_surrogate_fixed
3179+
required_capability: aggregate_metric_double_convert_to
3180+
3181+
ROW x = [1, 2, 3, 4, 5]
3182+
| MV_EXPAND x
3183+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3184+
| STATS min1 = MIN(agg_metric) WHERE x <= 3,
3185+
min2 = MIN(agg_metric),
3186+
min3 = MIN(x),
3187+
min4 = MIN(x) WHERE x > 3
3188+
;
3189+
3190+
min1:double | min2:double | min3:integer | min4:integer
3191+
1.0 | 1.0 | 1 | 4
3192+
;
3193+
3194+
countWithConditions
3195+
required_capability: stats_with_filtered_surrogate_fixed
3196+
required_capability: aggregate_metric_double_convert_to
3197+
3198+
ROW x = [1, 2, 3, 4, 5]
3199+
| MV_EXPAND x
3200+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3201+
| STATS count1 = COUNT(x) WHERE x >= 3,
3202+
count2 = COUNT(x),
3203+
count3 = COUNT(agg_metric),
3204+
count4 = COUNT(agg_metric) WHERE x >=3,
3205+
count5 = COUNT(4) WHERE x >= 3,
3206+
count6 = COUNT(*) WHERE x >= 3,
3207+
count7 = COUNT([1,2,3]) WHERE x >= 3,
3208+
count8 = COUNT([1,2,3])
3209+
;
3210+
3211+
count1:long | count2:long | count3:long | count4:long | count5:long | count6:long | count7:long | count8:long
3212+
3 | 5 | 5 | 3 | 3 | 3 | 9 | 15
3213+
;
3214+
3215+
countDistinctWithConditions
3216+
required_capability: stats_with_filtered_surrogate_fixed
3217+
3218+
ROW x = [1, 2, 3, 4, 5]
3219+
| MV_EXPAND x
3220+
| STATS count1 = COUNT_DISTINCT(x) WHERE x <= 3,
3221+
count2 = COUNT_DISTINCT(x),
3222+
count3 = COUNT_DISTINCT(1) WHERE x <= 3,
3223+
count4 = COUNT_DISTINCT(1)
3224+
;
3225+
3226+
count1:long | count2:long | count3:long | count4:long
3227+
3 | 5 | 1 | 1
3228+
;
3229+
3230+
avgWithConditions
3231+
required_capability: stats_with_filtered_surrogate_fixed
3232+
required_capability: aggregate_metric_double_convert_to
3233+
3234+
ROW x = [1, 2, 3, 4, 5]
3235+
| MV_EXPAND x
3236+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3237+
| STATS avg1 = AVG(x) WHERE x <= 3,
3238+
avg2 = AVG(x),
3239+
avg3 = AVG(agg_metric) WHERE x <=3,
3240+
avg4 = AVG(agg_metric)
3241+
;
3242+
3243+
avg1:double | avg2:double | avg3:double | avg4:double
3244+
2.0 | 3.0 | 2.0 | 3.0
3245+
;
3246+
3247+
percentileWithConditions
3248+
required_capability: stats_with_filtered_surrogate_fixed
3249+
3250+
ROW x = [1, 2, 3, 4, 5]
3251+
| MV_EXPAND x
3252+
| STATS percentile1 = PERCENTILE(x, 50) WHERE x <= 3,
3253+
percentile2 = PERCENTILE(x, 50)
3254+
;
3255+
3256+
percentile1:double | percentile2:double
3257+
2.0 | 3.0
3258+
;
3259+
3260+
medianWithConditions
3261+
required_capability: stats_with_filtered_surrogate_fixed
3262+
3263+
ROW x = [1, 2, 3, 4, 5]
3264+
| MV_EXPAND x
3265+
| STATS median1 = MEDIAN(x) WHERE x <= 3,
3266+
median2 = MEDIAN(x),
3267+
median3 = MEDIAN([5,6,7,8,9]) WHERE x <= 3,
3268+
median4 = MEDIAN([5,6,7,8,9])
3269+
;
3270+
3271+
median1:double | median2:double | median3:double | median4:double
3272+
2.0 | 3.0 | 7.0 | 7.0
3273+
;
3274+
3275+
medianAbsoluteDeviationWithConditions
3276+
required_capability: stats_with_filtered_surrogate_fixed
3277+
3278+
ROW x = [1, 3, 4, 7, 11, 18]
3279+
| MV_EXPAND x
3280+
| STATS median_dev1 = MEDIAN_ABSOLUTE_DEVIATION(x) WHERE x <= 3,
3281+
median_dev2 = MEDIAN_ABSOLUTE_DEVIATION(x),
3282+
median_dev3 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25]) WHERE x <= 3,
3283+
median_dev4 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25])
3284+
;
3285+
3286+
median_dev1:double | median_dev2:double | median_dev3:double | median_dev4:double
3287+
1.0 | 3.5 | 5.5 | 5.5
3288+
;
3289+
3290+
topWithConditions
3291+
required_capability: stats_with_filtered_surrogate_fixed
3292+
3293+
FROM employees
3294+
| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
3295+
min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
3296+
max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
3297+
max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
3298+
;
3299+
3300+
min1:integer | min2:integer | max1:integer | max2:integer
3301+
10011 | [10011, 10012] | 10079 | [10079, 10078]
3302+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,14 @@ public enum Cap {
10521052
/**
10531053
* Support correct counting of skipped shards.
10541054
*/
1055-
CORRECT_SKIPPED_SHARDS_COUNT;
1055+
CORRECT_SKIPPED_SHARDS_COUNT,
1056+
/**
1057+
* Bugfix for STATS {{expression}} WHERE {{condition}} when the
1058+
* expression is replaced by something else on planning
1059+
* e.g. STATS SUM(1) WHERE x==3 is replaced by
1060+
* STATS MV_SUM(const)*COUNT(*) WHERE x == 3.
1061+
*/
1062+
STATS_WITH_FILTERED_SURROGATE_FIXED;
10561063

10571064
private final boolean enabled;
10581065

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,11 @@ public Expression surrogate() {
145145
var s = source();
146146
var field = field();
147147
if (field.dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
148-
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT));
148+
return new Sum(
149+
s,
150+
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT),
151+
filter()
152+
);
149153
}
150154

151155
if (field.foldable()) {
@@ -162,7 +166,7 @@ public Expression surrogate() {
162166
return new Mul(
163167
s,
164168
new Coalesce(s, new MvCount(s, field), List.of(new Literal(s, 0, DataType.INTEGER))),
165-
new Count(s, Literal.keyword(s, StringUtils.WILDCARD))
169+
new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())
166170
);
167171
}
168172

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ public final AggregatorFunctionSupplier supplier() {
152152
@Override
153153
public Expression surrogate() {
154154
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
155-
return new Max(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX));
155+
return new Max(
156+
source(),
157+
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX),
158+
filter()
159+
);
156160
}
157161
return field().foldable() ? new MvMax(source(), field()) : null;
158162
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,6 @@ public Expression surrogate() {
110110

111111
return field.foldable()
112112
? new MvMedian(s, new ToDouble(s, field))
113-
: new Percentile(source(), field(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
113+
: new Percentile(source(), field(), filter(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
114114
}
115115
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ public final AggregatorFunctionSupplier supplier() {
152152
@Override
153153
public Expression surrogate() {
154154
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
155-
return new Min(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN));
155+
return new Min(
156+
source(),
157+
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN),
158+
filter()
159+
);
156160
}
157161
return field().foldable() ? new MvMin(source(), field()) : null;
158162
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,14 @@ public Expression surrogate() {
138138
var s = source();
139139
var field = field();
140140
if (field.dataType() == AGGREGATE_METRIC_DOUBLE) {
141-
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM));
141+
return new Sum(
142+
s,
143+
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.SUM),
144+
filter()
145+
);
142146
}
143147

144148
// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
145-
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD))) : null;
149+
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())) : null;
146150
}
147151
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,9 @@ public Expression surrogate() {
217217

218218
if (limitValue() == 1) {
219219
if (orderValue()) {
220-
return new Min(s, field());
220+
return new Min(s, field(), filter());
221221
} else {
222-
return new Max(s, field());
222+
return new Max(s, field(), filter());
223223
}
224224
}
225225

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ public Expression surrogate() {
159159
return new MvAvg(s, field);
160160
}
161161
if (weight.foldable()) {
162-
return new Div(s, new Sum(s, field), new Count(s, field), dataType());
162+
return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
163163
} else {
164-
return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
164+
return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
165165
}
166166
}
167167

0 commit comments

Comments
 (0)