diff --git a/docs/changelog/134461.yaml b/docs/changelog/134461.yaml
new file mode 100644
index 0000000000000..0086e5e0a5050
--- /dev/null
+++ b/docs/changelog/134461.yaml
@@ -0,0 +1,6 @@
+pr: 134461
+summary: Propagates filter() to aggregation functions' surrogates
+area: Aggregations
+type: bug
+issues:
+ - 134380
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
index 7dcff7b2bdd98..c2a4db52b21d7 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
@@ -3097,3 +3097,162 @@ ROW a = [1,2,3], b = 5
STD_DEV(a):double | STD_DEV(b):double
0.816496580927726 | 0.0
;
+
+sumWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+FROM employees
+| STATS sum1 = SUM(1),
+ sum2 = SUM(1) WHERE emp_no == 10080,
+ sum3 = SUM(1) WHERE emp_no < 10080,
+ sum4 = SUM(1) WHERE emp_no >= 10080
+;
+
+sum1:long | sum2:long | sum3:long | sum4:long
+100 | 1 | 79 | 21
+;
+
+weightedAvgWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+| MV_EXPAND x
+| STATS w_avg1 = WEIGHTED_AVG(x, 1) WHERE x == 5,
+ w_avg2 = WEIGHTED_AVG(x, x) WHERE x == 5,
+ w_avg3 = WEIGHTED_AVG(x, 2) WHERE x <= 5,
+ w_avg4 = WEIGHTED_AVG(x, x) WHERE x > 5,
+ w_avg5 = WEIGHTED_AVG([1,2,3], 1),
+ w_avg6 = WEIGHTED_AVG([1,2,3], 1) WHERE x == 5
+;
+
+w_avg1:double | w_avg2:double | w_avg3:double | w_avg4:double | w_avg5:double | w_avg6:double
+5.0 | 5.0 | 3.0 | 8.25 | 2.0 | 2.0
+;
+
+maxWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| STATS max1 = MAX(x),
+ max2 = MAX(x) WHERE x > 3
+;
+
+max1:integer | max2:integer
+5 | 5
+;
+
+minWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| STATS min1 = MIN(x),
+ min2 = MIN(x) WHERE x > 3
+;
+
+min1:integer | min2:integer
+1 | 4
+;
+
+countWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| STATS count1 = COUNT(x) WHERE x >= 3,
+ count2 = COUNT(x),
+ count3 = COUNT(4) WHERE x >= 3,
+ count4 = COUNT(*) WHERE x >= 3,
+ count5 = COUNT([1,2,3]) WHERE x >= 3,
+ count6 = COUNT([1,2,3])
+;
+
+count1:long | count2:long | count3:long | count4:long | count5:long | count6:long
+3 | 5 | 3 | 3 | 9 | 15
+;
+
+countDistinctWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| STATS count1 = COUNT_DISTINCT(x) WHERE x <= 3,
+ count2 = COUNT_DISTINCT(x),
+ count3 = COUNT_DISTINCT(1) WHERE x <= 3,
+ count4 = COUNT_DISTINCT(1)
+;
+
+count1:long | count2:long | count3:long | count4:long
+3 | 5 | 1 | 1
+;
+
+avgWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| STATS avg1 = AVG(x) WHERE x <= 3,
+ avg2 = AVG(x)
+;
+
+avg1:double | avg2:double
+2.0 | 3.0
+;
+
+percentileWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| STATS percentile1 = PERCENTILE(x, 50) WHERE x <= 3,
+ percentile2 = PERCENTILE(x, 50)
+;
+
+percentile1:double | percentile2:double
+2.0 | 3.0
+;
+
+medianWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 2, 3, 4, 5]
+| MV_EXPAND x
+| STATS median1 = MEDIAN(x) WHERE x <= 3,
+ median2 = MEDIAN(x),
+ median3 = MEDIAN([5,6,7,8,9]) WHERE x <= 3,
+ median4 = MEDIAN([5,6,7,8,9])
+;
+
+median1:double | median2:double | median3:double | median4:double
+2.0 | 3.0 | 7.0 | 7.0
+;
+
+medianAbsoluteDeviationWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+ROW x = [1, 3, 4, 7, 11, 18]
+| MV_EXPAND x
+| STATS median_dev1 = MEDIAN_ABSOLUTE_DEVIATION(x) WHERE x <= 3,
+ median_dev2 = MEDIAN_ABSOLUTE_DEVIATION(x),
+ median_dev3 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25]) WHERE x <= 3,
+ median_dev4 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25])
+;
+
+median_dev1:double | median_dev2:double | median_dev3:double | median_dev4:double
+1.0 | 3.5 | 5.5 | 5.5
+;
+
+topWithConditions
+required_capability: stats_with_filtered_surrogate_fixed
+
+FROM employees
+| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
+ min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
+ max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
+ max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
+;
+
+min1:integer | min2:integer | max1:integer | max2:integer
+10011 | [10011, 10012] | 10079 | [10079, 10078]
+;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index a70a221b1a8a3..c89cc4d2f8b55 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -766,7 +766,15 @@ public enum Cap {
* Support for the mv_expand target attribute should be retained in its original position.
* see ES|QL: inconsistent column order #129000
*/
- FIX_MV_EXPAND_INCONSISTENT_COLUMN_ORDER;
+ FIX_MV_EXPAND_INCONSISTENT_COLUMN_ORDER,
+
+ /**
+ * Bugfix for STATS {{expression}} WHERE {{condition}} when the
+ * expression is replaced by something else on planning
+ * e.g. STATS SUM(1) WHERE x==3 is replaced by
+ * STATS MV_SUM(const)*COUNT(*) WHERE x == 3.
+ */
+ STATS_WITH_FILTERED_SURROGATE_FIXED;
private final boolean enabled;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java
index 5ce43c7b3872d..42bdf0ce559b9 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java
@@ -145,7 +145,11 @@ public Expression surrogate() {
var s = source();
var field = field();
if (field.dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
- return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT));
+ return new Sum(
+ s,
+ FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT),
+ filter()
+ );
}
if (field.foldable()) {
@@ -162,7 +166,7 @@ public Expression surrogate() {
return new Mul(
s,
new Coalesce(s, new MvCount(s, field), List.of(new Literal(s, 0, DataType.INTEGER))),
- new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD))
+ new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD), filter())
);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java
index f0d0d8604f479..60279d9968711 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java
@@ -152,7 +152,11 @@ public final AggregatorFunctionSupplier supplier(List inputChannels) {
@Override
public Expression surrogate() {
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
- return new Max(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX));
+ return new Max(
+ source(),
+ FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX),
+ filter()
+ );
}
return field().foldable() ? new MvMax(source(), field()) : null;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java
index c47fa612c1c49..4c2ff6392bfe4 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java
@@ -110,6 +110,6 @@ public Expression surrogate() {
return field.foldable()
? new MvMedian(s, new ToDouble(s, field))
- : new Percentile(source(), field(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
+ : new Percentile(source(), field(), filter(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java
index 6436623ddaa63..71f22509b1d10 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java
@@ -152,7 +152,11 @@ public final AggregatorFunctionSupplier supplier(List inputChannels) {
@Override
public Expression surrogate() {
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
- return new Min(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN));
+ return new Min(
+ source(),
+ FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN),
+ filter()
+ );
}
return field().foldable() ? new MvMin(source(), field()) : null;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java
index 1c69edb9f0da9..0bdce345c4e94 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java
@@ -143,7 +143,7 @@ public Expression surrogate() {
// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
return field.foldable()
- ? new Mul(s, new MvSum(s, field), new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD)))
+ ? new Mul(s, new MvSum(s, field), new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD), filter()))
: null;
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java
index 9be8c94266ee8..49b7d55c3e3a7 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java
@@ -217,9 +217,9 @@ public Expression surrogate() {
if (limitValue() == 1) {
if (orderValue()) {
- return new Min(s, field());
+ return new Min(s, field(), filter());
} else {
- return new Max(s, field());
+ return new Max(s, field(), filter());
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java
index bab65653ba576..a28fd5b146cd0 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java
@@ -159,9 +159,9 @@ public Expression surrogate() {
return new MvAvg(s, field);
}
if (weight.foldable()) {
- return new Div(s, new Sum(s, field), new Count(s, field), dataType());
+ return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
} else {
- return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
+ return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java
index 87ea6315d4f3b..d0d2b063c52eb 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java
@@ -150,6 +150,29 @@ public void testFold() {
}, this::evaluate);
}
+ public void testSurrogateHasFilter() {
+ Expression expression = randomFrom(
+ buildLiteralExpression(testCase),
+ buildDeepCopyOfFieldExpression(testCase),
+ buildFieldExpression(testCase)
+ );
+
+ assumeTrue("expression should have no type errors", expression.typeResolved().resolved());
+
+ if (expression instanceof AggregateFunction && expression instanceof SurrogateExpression) {
+ var filter = ((AggregateFunction) expression).filter();
+
+ var surrogate = ((SurrogateExpression) expression).surrogate();
+
+ if (surrogate != null) {
+ surrogate.forEachDown(AggregateFunction.class, child -> {
+ var surrogateFilter = child.filter();
+ assertEquals(filter, surrogateFilter);
+ });
+ }
+ }
+ }
+
private void aggregateSingleMode(Expression expression) {
Object result;
try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) {