Skip to content

Commit 6471c58

Browse files
committed
Propagates filter() to aggregation functions' surrogates (#134461)
--------- Co-authored-by: Jan Kuipers <[email protected]> Co-authored-by: Jan Kuipers <[email protected]> (cherry picked from commit 381fc8e) # Conflicts: # x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec # x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java # x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java # x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java
1 parent c2116d6 commit 6471c58

File tree

11 files changed

+243
-10
lines changed

11 files changed

+243
-10
lines changed

docs/changelog/134461.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 134461
2+
summary: Propagates filter() to aggregation functions' surrogates
3+
area: Aggregations
4+
type: bug
5+
issues:
6+
- 134380

x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3097,3 +3097,183 @@ ROW a = [1,2,3], b = 5
30973097
STD_DEV(a):double | STD_DEV(b):double
30983098
0.816496580927726 | 0.0
30993099
;
3100+
3101+
sumWithConditions
3102+
required_capability: stats_with_filtered_surrogate_fixed
3103+
required_capability: aggregate_metric_double_convert_to
3104+
3105+
FROM employees
3106+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(1)
3107+
| STATS sum1 = SUM(1),
3108+
sum2 = SUM(1) WHERE emp_no == 10080,
3109+
sum3 = SUM(1) WHERE emp_no < 10080,
3110+
sum4 = SUM(1) WHERE emp_no >= 10080,
3111+
sum5 = SUM(agg_metric),
3112+
sum6 = SUM(agg_metric) WHERE emp_no == 10080
3113+
;
3114+
3115+
sum1:long | sum2:long | sum3:long | sum4:long | sum5:double | sum6:double
3116+
100 | 1 | 79 | 21 | 100.0 | 1.0
3117+
;
3118+
3119+
weightedAvgWithConditions
3120+
required_capability: stats_with_filtered_surrogate_fixed
3121+
3122+
ROW x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
3123+
| MV_EXPAND x
3124+
| STATS w_avg1 = WEIGHTED_AVG(x, 1) WHERE x == 5,
3125+
w_avg2 = WEIGHTED_AVG(x, x) WHERE x == 5,
3126+
w_avg3 = WEIGHTED_AVG(x, 2) WHERE x <= 5,
3127+
w_avg4 = WEIGHTED_AVG(x, x) WHERE x > 5,
3128+
w_avg5 = WEIGHTED_AVG([1,2,3], 1),
3129+
w_avg6 = WEIGHTED_AVG([1,2,3], 1) WHERE x == 5
3130+
;
3131+
3132+
w_avg1:double | w_avg2:double | w_avg3:double | w_avg4:double | w_avg5:double | w_avg6:double
3133+
5.0 | 5.0 | 3.0 | 8.25 | 2.0 | 2.0
3134+
;
3135+
3136+
maxWithConditions
3137+
required_capability: stats_with_filtered_surrogate_fixed
3138+
required_capability: aggregate_metric_double_convert_to
3139+
3140+
ROW x = [1, 2, 3, 4, 5]
3141+
| MV_EXPAND x
3142+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3143+
| STATS max1 = MAX(agg_metric) WHERE x <= 3,
3144+
max2 = MAX(agg_metric),
3145+
max3 = MAX(x),
3146+
max4 = MAX(x) WHERE x > 3
3147+
;
3148+
3149+
max1:double | max2:double | max3:integer | max4:integer
3150+
3.0 | 5.0 | 5 | 5
3151+
;
3152+
3153+
minWithConditions
3154+
required_capability: stats_with_filtered_surrogate_fixed
3155+
required_capability: aggregate_metric_double_convert_to
3156+
3157+
ROW x = [1, 2, 3, 4, 5]
3158+
| MV_EXPAND x
3159+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3160+
| STATS min1 = MIN(agg_metric) WHERE x <= 3,
3161+
min2 = MIN(agg_metric),
3162+
min3 = MIN(x),
3163+
min4 = MIN(x) WHERE x > 3
3164+
;
3165+
3166+
min1:double | min2:double | min3:integer | min4:integer
3167+
1.0 | 1.0 | 1 | 4
3168+
;
3169+
3170+
countWithConditions
3171+
required_capability: stats_with_filtered_surrogate_fixed
3172+
required_capability: aggregate_metric_double_convert_to
3173+
3174+
ROW x = [1, 2, 3, 4, 5]
3175+
| MV_EXPAND x
3176+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3177+
| STATS count1 = COUNT(x) WHERE x >= 3,
3178+
count2 = COUNT(x),
3179+
count3 = COUNT(agg_metric),
3180+
count4 = COUNT(agg_metric) WHERE x >=3,
3181+
count5 = COUNT(4) WHERE x >= 3,
3182+
count6 = COUNT(*) WHERE x >= 3,
3183+
count7 = COUNT([1,2,3]) WHERE x >= 3,
3184+
count8 = COUNT([1,2,3])
3185+
;
3186+
3187+
count1:long | count2:long | count3:long | count4:long | count5:long | count6:long | count7:long | count8:long
3188+
3 | 5 | 5 | 3 | 3 | 3 | 9 | 15
3189+
;
3190+
3191+
countDistinctWithConditions
3192+
required_capability: stats_with_filtered_surrogate_fixed
3193+
3194+
ROW x = [1, 2, 3, 4, 5]
3195+
| MV_EXPAND x
3196+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3197+
| STATS count1 = COUNT_DISTINCT(x) WHERE x <= 3,
3198+
count2 = COUNT_DISTINCT(x),
3199+
count3 = COUNT_DISTINCT(1) WHERE x <= 3,
3200+
count4 = COUNT_DISTINCT(1)
3201+
;
3202+
3203+
count1:long | count2:long | count3:long | count4:long
3204+
3 | 5 | 1 | 1
3205+
;
3206+
3207+
avgWithConditions
3208+
required_capability: stats_with_filtered_surrogate_fixed
3209+
required_capability: aggregate_metric_double_convert_to
3210+
3211+
ROW x = [1, 2, 3, 4, 5]
3212+
| MV_EXPAND x
3213+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3214+
| STATS avg1 = AVG(x) WHERE x <= 3,
3215+
avg2 = AVG(x),
3216+
avg3 = AVG(agg_metric) WHERE x <=3,
3217+
avg4 = AVG(agg_metric)
3218+
;
3219+
3220+
avg1:double | avg2:double | avg3:double | avg4:double
3221+
2.0 | 3.0 | 2.0 | 3.0
3222+
;
3223+
3224+
percentileWithConditions
3225+
required_capability: stats_with_filtered_surrogate_fixed
3226+
3227+
ROW x = [1, 2, 3, 4, 5]
3228+
| MV_EXPAND x
3229+
| STATS percentile1 = PERCENTILE(x, 50) WHERE x <= 3,
3230+
percentile2 = PERCENTILE(x, 50)
3231+
;
3232+
3233+
percentile1:double | percentile2:double
3234+
2.0 | 3.0
3235+
;
3236+
3237+
medianWithConditions
3238+
required_capability: stats_with_filtered_surrogate_fixed
3239+
3240+
ROW x = [1, 2, 3, 4, 5]
3241+
| MV_EXPAND x
3242+
| STATS median1 = MEDIAN(x) WHERE x <= 3,
3243+
median2 = MEDIAN(x),
3244+
median3 = MEDIAN([5,6,7,8,9]) WHERE x <= 3,
3245+
median4 = MEDIAN([5,6,7,8,9])
3246+
;
3247+
3248+
median1:double | median2:double | median3:double | median4:double
3249+
2.0 | 3.0 | 7.0 | 7.0
3250+
;
3251+
3252+
medianAbsoluteDeviationWithConditions
3253+
required_capability: stats_with_filtered_surrogate_fixed
3254+
3255+
ROW x = [1, 3, 4, 7, 11, 18]
3256+
| MV_EXPAND x
3257+
| STATS median_dev1 = MEDIAN_ABSOLUTE_DEVIATION(x) WHERE x <= 3,
3258+
median_dev2 = MEDIAN_ABSOLUTE_DEVIATION(x),
3259+
median_dev3 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25]) WHERE x <= 3,
3260+
median_dev4 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25])
3261+
;
3262+
3263+
median_dev1:double | median_dev2:double | median_dev3:double | median_dev4:double
3264+
1.0 | 3.5 | 5.5 | 5.5
3265+
;
3266+
3267+
topWithConditions
3268+
required_capability: stats_with_filtered_surrogate_fixed
3269+
3270+
FROM employees
3271+
| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
3272+
min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
3273+
max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
3274+
max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
3275+
;
3276+
3277+
min1:integer | min2:integer | max1:integer | max2:integer
3278+
10011 | [10011, 10012] | 10079 | [10079, 10078]
3279+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,15 @@ public enum Cap {
766766
* Support for the mv_expand target attribute should be retained in its original position.
767767
* see <a href="https://github.com/elastic/elasticsearch/issues/129000"> ES|QL: inconsistent column order #129000 </a>
768768
*/
769-
FIX_MV_EXPAND_INCONSISTENT_COLUMN_ORDER;
769+
FIX_MV_EXPAND_INCONSISTENT_COLUMN_ORDER,
770+
771+
/**
772+
* Bugfix for STATS {{expression}} WHERE {{condition}} when the
773+
* expression is replaced by something else on planning
774+
* e.g. STATS SUM(1) WHERE x==3 is replaced by
775+
* STATS MV_SUM(const)*COUNT(*) WHERE x == 3.
776+
*/
777+
STATS_WITH_FILTERED_SURROGATE_FIXED;
770778

771779
private final boolean enabled;
772780

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,11 @@ public Expression surrogate() {
145145
var s = source();
146146
var field = field();
147147
if (field.dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
148-
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT));
148+
return new Sum(
149+
s,
150+
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT),
151+
filter()
152+
);
149153
}
150154

151155
if (field.foldable()) {
@@ -162,7 +166,7 @@ public Expression surrogate() {
162166
return new Mul(
163167
s,
164168
new Coalesce(s, new MvCount(s, field), List.of(new Literal(s, 0, DataType.INTEGER))),
165-
new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD))
169+
new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD), filter())
166170
);
167171
}
168172

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ public final AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
152152
@Override
153153
public Expression surrogate() {
154154
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
155-
return new Max(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX));
155+
return new Max(
156+
source(),
157+
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX),
158+
filter()
159+
);
156160
}
157161
return field().foldable() ? new MvMax(source(), field()) : null;
158162
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,6 @@ public Expression surrogate() {
110110

111111
return field.foldable()
112112
? new MvMedian(s, new ToDouble(s, field))
113-
: new Percentile(source(), field(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
113+
: new Percentile(source(), field(), filter(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
114114
}
115115
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ public final AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
152152
@Override
153153
public Expression surrogate() {
154154
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
155-
return new Min(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN));
155+
return new Min(
156+
source(),
157+
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN),
158+
filter()
159+
);
156160
}
157161
return field().foldable() ? new MvMin(source(), field()) : null;
158162
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,12 @@ public Expression surrogate() {
142142
}
143143

144144
// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
145+
<<<<<<< HEAD
145146
return field.foldable()
146147
? new Mul(s, new MvSum(s, field), new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD)))
147148
: null;
149+
=======
150+
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())) : null;
151+
>>>>>>> 381fc8e7e730 (Propagates filter() to aggregation functions' surrogates (#134461))
148152
}
149153
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,9 @@ public Expression surrogate() {
217217

218218
if (limitValue() == 1) {
219219
if (orderValue()) {
220-
return new Min(s, field());
220+
return new Min(s, field(), filter());
221221
} else {
222-
return new Max(s, field());
222+
return new Max(s, field(), filter());
223223
}
224224
}
225225

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ public Expression surrogate() {
159159
return new MvAvg(s, field);
160160
}
161161
if (weight.foldable()) {
162-
return new Div(s, new Sum(s, field), new Count(s, field), dataType());
162+
return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
163163
} else {
164-
return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
164+
return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
165165
}
166166
}
167167

0 commit comments

Comments
 (0)