Skip to content

Commit d4a3244

Browse files
committed
Fixes filter part on aggregation functions' surrogates
1 parent b0a771b commit d4a3244

File tree

9 files changed

+175
-24
lines changed

9 files changed

+175
-24
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3145,3 +3145,162 @@ FROM employees
31453145
m:datetime | x:integer | d:boolean
31463146
1999-04-30T00:00:00.000Z | 2 | true
31473147
;
3148+
3149+
sumWithConditions
3150+
required_capability: stats_with_filtered_surrogate_fixed
3151+
3152+
FROM employees
3153+
| STATS sum1 = SUM(1),
3154+
sum2 = SUM(1) WHERE emp_no == 10080,
3155+
sum3 = SUM(1) WHERE emp_no < 10080,
3156+
sum4 = SUM(1) WHERE emp_no >= 10080
3157+
;
3158+
3159+
sum1:long | sum2:long | sum3:long | sum4:long
3160+
100 | 1 | 79 | 21
3161+
;
3162+
3163+
weightedAvgWithConditions
3164+
required_capability: stats_with_filtered_surrogate_fixed
3165+
3166+
ROW x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
3167+
| MV_EXPAND x
3168+
| STATS w_avg1 = WEIGHTED_AVG(x, 1) WHERE x == 5,
3169+
w_avg2 = WEIGHTED_AVG(x, x) WHERE x == 5,
3170+
w_avg3 = WEIGHTED_AVG(x, 2) WHERE x <= 5,
3171+
w_avg4 = WEIGHTED_AVG(x, x) WHERE x > 5
3172+
;
3173+
3174+
w_avg1:double | w_avg2:double | w_avg3:double | w_avg4:double
3175+
5.0 | 5.0 | 3.0 | 8.25
3176+
;
3177+
3178+
maxWithConditions
3179+
required_capability: stats_with_filtered_surrogate_fixed
3180+
required_capability: aggregate_metric_double_convert_to
3181+
3182+
ROW x = [1, 2, 3, 4, 5]
3183+
| MV_EXPAND x
3184+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3185+
| STATS max = MAX(agg_metric) WHERE x <= 3
3186+
;
3187+
3188+
max:double
3189+
3.0
3190+
;
3191+
3192+
minWithConditions
3193+
required_capability: stats_with_filtered_surrogate_fixed
3194+
required_capability: aggregate_metric_double_convert_to
3195+
3196+
ROW x = [1, 2, 3, 4, 5]
3197+
| MV_EXPAND x
3198+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3199+
| STATS min = MIN(agg_metric) WHERE x >= 3
3200+
;
3201+
3202+
min:double
3203+
3.0
3204+
;
3205+
3206+
countWithConditions
3207+
required_capability: stats_with_filtered_surrogate_fixed
3208+
required_capability: aggregate_metric_double_convert_to
3209+
3210+
ROW x = [1, 2, 3, 4, 5]
3211+
| MV_EXPAND x
3212+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3213+
| STATS count1 = COUNT(x) WHERE x >= 3,
3214+
count2 = COUNT(agg_metric) WHERE x >=3,
3215+
count3 = COUNT(4) WHERE x >= 3,
3216+
count4 = COUNT(*) WHERE x >= 3,
3217+
count5 = COUNT([1,2,3]) WHERE x >= 3
3218+
;
3219+
3220+
count1:long | count2:long | count3:long | count4:long | count5:long
3221+
3 | 3 | 3 | 3 | 9
3222+
;
3223+
3224+
countDistinctWithConditions
3225+
required_capability: stats_with_filtered_surrogate_fixed
3226+
3227+
ROW x = [1, 2, 3, 4, 5]
3228+
| MV_EXPAND x
3229+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3230+
| STATS count1 = COUNT_DISTINCT(x) WHERE x <= 3,
3231+
count2 = COUNT_DISTINCT(1) WHERE x <= 3
3232+
;
3233+
3234+
count1:long | count2:long
3235+
3 | 1
3236+
;
3237+
3238+
avgWithConditions
3239+
required_capability: stats_with_filtered_surrogate_fixed
3240+
required_capability: aggregate_metric_double_convert_to
3241+
3242+
ROW x = [1, 2, 3, 4, 5]
3243+
| MV_EXPAND x
3244+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3245+
| STATS avg1 = AVG(x) WHERE x <= 3,
3246+
avg2 = AVG(agg_metric) WHERE x <=3
3247+
;
3248+
3249+
avg1:double | avg2:double
3250+
2.0 | 2.0
3251+
;
3252+
3253+
percentileWithConditions
3254+
required_capability: stats_with_filtered_surrogate_fixed
3255+
3256+
ROW x = [1, 2, 3, 4, 5]
3257+
| MV_EXPAND x
3258+
| STATS percentile1 = PERCENTILE(x, 100) WHERE x <= 3,
3259+
percentile2 = PERCENTILE(x, 100)
3260+
;
3261+
3262+
percentile1:double | percentile2:double
3263+
3.0 | 5.0
3264+
;
3265+
3266+
medianWithConditions
3267+
required_capability: stats_with_filtered_surrogate_fixed
3268+
3269+
ROW x = [1, 2, 3, 4, 5]
3270+
| MV_EXPAND x
3271+
| STATS median1 = MEDIAN(x) WHERE x <= 3,
3272+
median2 = MEDIAN(x),
3273+
median3 = MEDIAN([5,6,7,8,9]) WHERE x <= 3
3274+
;
3275+
3276+
median1:double | median2:double | median3:double
3277+
2.0 | 3.0 | 7.0
3278+
;
3279+
3280+
medianAbsoluteDeviationWithConditions
3281+
required_capability: stats_with_filtered_surrogate_fixed
3282+
3283+
ROW x = [1, 2, 3, 4, 5]
3284+
| MV_EXPAND x
3285+
| STATS median_deviation1 = MEDIAN_ABSOLUTE_DEVIATION(x) WHERE x <= 3,
3286+
median_deviation2 = MEDIAN_ABSOLUTE_DEVIATION(x),
3287+
median_deviation3 = MEDIAN_ABSOLUTE_DEVIATION([5,6,7,8,9]) WHERE x <= 3
3288+
;
3289+
3290+
median_deviation1:double | median_deviation2:double | median_deviation3:double
3291+
1.0 | 1.0 | 1.0
3292+
;
3293+
3294+
topWithConditions
3295+
required_capability: stats_with_filtered_surrogate_fixed
3296+
3297+
FROM employees
3298+
| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
3299+
min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
3300+
max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
3301+
max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
3302+
;
3303+
3304+
min1:integer | min2:integer | max1:integer | max2:integer
3305+
10011 | [10011, 10012] | 10079 | [10079, 10078]
3306+
;

x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -300,17 +300,3 @@ FROM books
300300
The Lord of the Rings Poster Collection: Six Paintings by Alan Lee (No. 1) | he Lo | [J. R. R. Tolkien, Alan Lee] | Alan Lee
301301
A Gentle Creature and Other Stories: White Nights, A Gentle Creature, and The Dream of a Ridiculous Man (The World's Classics) | Gent | [W. J. Leatherbarrow, Fyodor Dostoevsky, Alan Myers] | Alan Myers
302302
;
303-
304-
topWithConditions
305-
required_capability: stats_top_1_with_condition_fixed
306-
307-
FROM employees
308-
| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
309-
min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
310-
max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
311-
max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
312-
;
313-
314-
min1:integer | min2:integer | max1:integer | max2:integer
315-
10011 | [10011, 10012] | 10079 | [10079, 10078]
316-
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,9 +1465,11 @@ public enum Cap {
14651465
FN_PRESENT,
14661466

14671467
/**
1468-
* Bugfix for STATS TOP(field, 1, order) WHERE condition.
1468+
* Bugfix for STATS {{expression}} WHERE {{condition}} when the expression
1469+
* is replaced by a surrogate expression during planning; the filter must be carried over to the surrogate,
1470+
* e.g. STATS SUM(1) WHERE x == 3 is replaced by MV_SUM(1) * COUNT(*) WHERE x == 3.
14691471
*/
1470-
STATS_TOP_1_WITH_CONDITION_FIXED;
1472+
STATS_WITH_FILTERED_SURROGATE_FIXED;
14711473

14721474
private final boolean enabled;
14731475

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public Expression surrogate() {
152152
var s = source();
153153
var field = field();
154154
if (field.dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
155-
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT));
155+
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT), filter());
156156
}
157157

158158
if (field.foldable()) {
@@ -169,7 +169,7 @@ public Expression surrogate() {
169169
return new Mul(
170170
s,
171171
new Coalesce(s, new MvCount(s, field), List.of(new Literal(s, 0, DataType.INTEGER))),
172-
new Count(s, Literal.keyword(s, StringUtils.WILDCARD))
172+
new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())
173173
);
174174
}
175175

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ public final AggregatorFunctionSupplier supplier() {
160160
@Override
161161
public Expression surrogate() {
162162
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
163-
return new Max(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX));
163+
return new Max(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX), filter());
164164
}
165165
return field().foldable() ? new MvMax(source(), field()) : null;
166166
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,6 @@ public Expression surrogate() {
117117

118118
return field.foldable()
119119
? new MvMedian(s, new ToDouble(s, field))
120-
: new Percentile(source(), field(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
120+
: new Percentile(source(), field(), filter(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
121121
}
122122
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ public final AggregatorFunctionSupplier supplier() {
160160
@Override
161161
public Expression surrogate() {
162162
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
163-
return new Min(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN));
163+
return new Min(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN), filter());
164164
}
165165
return field().foldable() ? new MvMin(source(), field()) : null;
166166
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ public Sum(Source source, @Param(name = "number", type = { "aggregate_metric_dou
6666
this(source, field, Literal.TRUE, SummationMode.COMPENSATED_LITERAL);
6767
}
6868

69+
public Sum(Source source, @Param(name = "number", type = { "aggregate_metric_double", "double", "integer", "long" }) Expression field, Expression filter) {
70+
this(source, field, filter, SummationMode.COMPENSATED_LITERAL);
71+
}
72+
6973
public Sum(Source source, Expression field, Expression filter, Expression summationMode) {
7074
super(source, field, filter, List.of(summationMode));
7175
this.summationMode = summationMode;
@@ -163,6 +167,6 @@ public Expression surrogate() {
163167
}
164168

165169
// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
166-
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD))) : null;
170+
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())) : null;
167171
}
168172
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ public Expression surrogate() {
160160
return new MvAvg(s, field);
161161
}
162162
if (weight.foldable()) {
163-
return new Div(s, new Sum(s, field), new Count(s, field), dataType());
163+
return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
164164
} else {
165-
return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
165+
return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
166166
}
167167
}
168168

0 commit comments

Comments
 (0)