Skip to content

Commit 0a81875

Browse files
authored
Fixes esql class cast bug in STATS at planning level (#137511)
1 parent ab55d97 commit 0a81875

File tree

10 files changed

+812
-297
lines changed

10 files changed

+812
-297
lines changed

docs/changelog/137511.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137511
2+
summary: Fixes esql class cast bug in STATS at planning level
3+
area: ES|QL
4+
type: bug
5+
issues: []

x-pack/plugin/esql/qa/testFixtures/src/main/resources/inlinestats.csv-spec

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4392,3 +4392,20 @@ row a = 1
43924392
c:long
43934393
1
43944394
;
4395+
4396+
fixClassCastBugWithSeveralCounts
4397+
required_capability: inline_stats
4398+
required_capability: fix_stats_classcast_exception
4399+
4400+
FROM sample_data, sample_data_str
4401+
| EVAL one_ip = client_ip::ip
4402+
| INLINE STATS count1=count(client_ip::ip), count2=count(one_ip)
4403+
| KEEP count1, count2
4404+
| LIMIT 3
4405+
;
4406+
4407+
count1:long |count2:long
4408+
14 |14
4409+
14 |14
4410+
14 |14
4411+
;

x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3463,3 +3463,72 @@ employees:long | job_positions:keyword
34633463
15 | Tech Lead
34643464
;
34653465

3466+
fixClassCastBugWithCountDistinct
3467+
required_capability: fix_stats_classcast_exception
3468+
3469+
from airports
3470+
| rename scalerank AS x
3471+
| stats a = count(x), b = count(x) + count(x), c = count_distinct(x)
3472+
;
3473+
3474+
a:long | b:long | c:long
3475+
891 | 1782 | 8
3476+
;
3477+
3478+
fixClassCastBugWithValuesFn
3479+
required_capability: fix_stats_classcast_exception
3480+
3481+
ROW x = [1,2,3]
3482+
| STATS a = MV_COUNT(VALUES(x)), b = VALUES(x), c = SUM(x)
3483+
;
3484+
3485+
a:integer | b:integer | c:long
3486+
3 | [1, 2, 3] | 6
3487+
;
3488+
3489+
fixClassCastBugWithSeveralCountDistincts
3490+
required_capability: fix_stats_classcast_exception
3491+
3492+
ROW x = 1
3493+
| STATS a = 2*COUNT_DISTINCT(x), b = COUNT_DISTINCT(x), c = MAX(x)
3494+
;
3495+
3496+
a:long | b:long | c:integer
3497+
2 | 1 | 1
3498+
;
3499+
3500+
fixClassCastBugWithMedianPlusCountDistinct
3501+
required_capability: fix_stats_classcast_exception
3502+
3503+
FROM sample_data_ts_long
3504+
| EVAL sym1 = 0, sym5 = 1
3505+
| STATS sym2 = median(sym5) + 0, sym3 = median(sym5), sym4 = count_distinct(sym1)
3506+
;
3507+
3508+
sym2:double |sym3:double | sym4:long
3509+
1.0 | 1.0 | 1
3510+
;
3511+
3512+
fixClassCastBugWithFoldableLiterals
3513+
required_capability: fix_stats_classcast_exception
3514+
3515+
from airports
3516+
| rename scalerank AS x
3517+
| stats a = count(x), b = count(x) + count(x), c = count_distinct(x, 10), d = count_distinct(x, 10 + 1 - 1)
3518+
;
3519+
3520+
a:long | b:long | c:long | d:long
3521+
891 | 1782 | 8 | 8
3522+
;
3523+
3524+
fixClassCastBugWithSurrogateExpressions
3525+
required_capability: fix_stats_classcast_exception
3526+
3527+
from airports
3528+
| rename scalerank AS x
3529+
| stats a = median(x), b = percentile(x, 50), c = count_distinct(x)
3530+
;
3531+
3532+
a:double | b:double | c:long
3533+
6.0 | 6.0 | 8
3534+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,6 +1608,13 @@ public enum Cap {
16081608
*/
16091609
PUSHING_DOWN_EVAL_WITH_SCORE,
16101610

1611+
/**
1612+
* Fix for ClassCastException in STATS
1613+
* https://github.com/elastic/elasticsearch/issues/133992
1614+
* https://github.com/elastic/elasticsearch/issues/136598
1615+
*/
1616+
FIX_STATS_CLASSCAST_EXCEPTION,
1617+
16111618
/**
16121619
* Fix attribute equality to respect the name id of the attribute.
16131620
*/

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineLimitTopN;
2020
import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineProjections;
2121
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ConstantFolding;
22+
import org.elasticsearch.xpack.esql.optimizer.rules.logical.DeduplicateAggs;
2223
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ExtractAggregateCommonFilter;
2324
import org.elasticsearch.xpack.esql.optimizer.rules.logical.FoldNull;
2425
import org.elasticsearch.xpack.esql.optimizer.rules.logical.HoistRemoteEnrichLimit;
@@ -178,6 +179,14 @@ protected static Batch<LogicalPlan> operators() {
178179
new SplitInWithFoldableValue(),
179180
new PropagateEvalFoldables(),
180181
new ConstantFolding(),
182+
/* Then deduplicate aggregations
183+
This must run after constant folding
184+
because we could have expressions like
185+
count_distinct(_, 9 + 1)
186+
count_distinct(_, 10)
187+
which are semantically identical but only compare equal once folded
188+
*/
189+
new DeduplicateAggs(),
181190
new PartiallyFoldCase(),
182191
// boolean
183192
new BooleanSimplification(),
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
9+
10+
/**
11+
* This rule deduplicates identical aggregate functions so that each one is computed only once. For example:
12+
* stats a = min(x), b = min(x), c = count(*), d = count() by g
13+
* becomes
14+
* stats a = min(x), c = count(*) by g | eval b = a, d = c | keep a, b, c, d, g
15+
*/
16+
public final class DeduplicateAggs extends ReplaceAggregateAggExpressionWithEval implements OptimizerRules.CoordinatorOnly {
17+
18+
public DeduplicateAggs() {
19+
super(false);
20+
}
21+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateAggExpressionWithEval.java

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,17 @@
4343
* becomes
4444
* stats a = min(x), c = count(*) by g | eval b = a, d = c | keep a, b, c, d, g
4545
*/
46-
public final class ReplaceAggregateAggExpressionWithEval extends OptimizerRules.OptimizerRule<Aggregate> {
46+
public class ReplaceAggregateAggExpressionWithEval extends OptimizerRules.OptimizerRule<Aggregate> {
47+
private final boolean replaceNestedExpressions;
48+
49+
public ReplaceAggregateAggExpressionWithEval(boolean replaceNestedExpressions) {
50+
super(OptimizerRules.TransformDirection.UP);
51+
this.replaceNestedExpressions = replaceNestedExpressions;
52+
}
53+
4754
public ReplaceAggregateAggExpressionWithEval() {
4855
super(OptimizerRules.TransformDirection.UP);
56+
this.replaceNestedExpressions = true;
4957
}
5058

5159
@Override
@@ -88,7 +96,7 @@ protected LogicalPlan rule(Aggregate aggregate) {
8896
// common case - handle duplicates
8997
if (child instanceof AggregateFunction af) {
9098
// canonical representation, with resolved aliases
91-
AggregateFunction canonical = (AggregateFunction) af.canonical().transformUp(e -> aliases.resolve(e, e));
99+
AggregateFunction canonical = getCannonical(af, aliases);
92100

93101
Alias found = rootAggs.get(canonical);
94102
// aggregate is new
@@ -106,14 +114,15 @@ protected LogicalPlan rule(Aggregate aggregate) {
106114
}
107115
// nested expression over aggregate function or groups
108116
// replace them with reference and move the expression into a follow-up eval
109-
else {
117+
else if (replaceNestedExpressions) {
110118
changed.set(true);
111119
Expression aggExpression = child.transformUp(AggregateFunction.class, af -> {
112-
AggregateFunction canonical = (AggregateFunction) af.canonical();
120+
// canonical representation, with resolved aliases
121+
AggregateFunction canonical = getCannonical(af, aliases);
113122
Alias alias = rootAggs.get(canonical);
114123
if (alias == null) {
115-
// create synthetic alias ove the found agg function
116-
alias = new Alias(af.source(), syntheticName(canonical, child, counter[0]++), canonical, null, true);
124+
// create synthetic alias over the found agg function
125+
alias = new Alias(af.source(), syntheticName(canonical, child, counter[0]++), af.canonical(), null, true);
117126
// and remember it to remove duplicates
118127
rootAggs.put(canonical, alias);
119128
// add it to the list of aggregates and continue
@@ -132,6 +141,9 @@ protected LogicalPlan rule(Aggregate aggregate) {
132141
Alias alias = as.replaceChild(aggExpression);
133142
newEvals.add(alias);
134143
newProjections.add(alias.toAttribute());
144+
} else {
145+
newAggs.add(agg);
146+
newProjections.add(agg.toAttribute());
135147
}
136148
}
137149
// not an alias (e.g. grouping field)
@@ -155,6 +167,10 @@ protected LogicalPlan rule(Aggregate aggregate) {
155167
return plan;
156168
}
157169

170+
private static AggregateFunction getCannonical(AggregateFunction af, AttributeMap<Expression> aliases) {
171+
return (AggregateFunction) af.canonical().transformUp(e -> aliases.resolve(e, e));
172+
}
173+
158174
private static String syntheticName(Expression expression, Expression af, int counter) {
159175
return TemporaryNameUtils.temporaryName(expression, af, counter);
160176
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,6 +1656,21 @@ public void testLengthInWhereAndEval() {
16561656
/**
16571657
* Pushed LENGTH to the same field in a <strong>ton</strong> of unique and curious ways. All
16581658
* of these pushdowns should be fused to one.
1659+
*
1660+
* <pre>{@code
1661+
* Project[[l{r}#23]]
1662+
* \_Eval[[$$SUM$SUM(LENGTH(last>$0{r$}#37 / $$COUNT$$$AVG$SUM(LENGTH(last>$1$1{r$}#41 AS $$AVG$SUM(LENGTH(last>$1#38, $
1663+
* $SUM$SUM(LENGTH(last>$0{r$}#37 + $$AVG$SUM(LENGTH(last>$1{r$}#38 + $$SUM$SUM(LENGTH(last>$2{r$}#39 AS l#23]]
1664+
* \_Limit[1000[INTEGER],false,false]
1665+
* \_Aggregate[[],[SUM($$LENGTH(last_nam>$SUM$0{r$}#35,true[BOOLEAN],PT0S[TIME_DURATION],compensated[KEYWORD]) AS $$SUM$SUM(LE
1666+
* NGTH(last>$0#37,
1667+
* COUNT(a3{r}#11,true[BOOLEAN],PT0S[TIME_DURATION]) AS $$COUNT$$$AVG$SUM(LENGTH(last>$1$1#41,
1668+
* SUM($$LENGTH(first_na>$SUM$1{r$}#36,true[BOOLEAN],PT0S[TIME_DURATION],compensated[KEYWORD]) AS $$SUM$SUM(LENGTH(last>$2#39]]
1669+
* \_Eval[[$$last_name$LENGTH$920787299{f$}#42 AS a3#11, $$last_name$LENGTH$920787299{f$}#42 AS $$LENGTH(last_nam>$SUM$0
1670+
* #35, $$first_name$LENGTH$920787299{f$}#43 AS $$LENGTH(first_na>$SUM$1#36]]
1671+
* \_Filter[$$last_name$LENGTH$920787299{f$}#42 > 1[INTEGER]]
1672+
* \_EsRelation[test][_meta_field{f}#30, emp_no{f}#24, first_name{f}#25, ..]
1673+
* }</pre>
16591674
*/
16601675
public void testLengthPushdownZoo() {
16611676
assumeTrue("requires push", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled());
@@ -1674,13 +1689,26 @@ public void testLengthPushdownZoo() {
16741689
// Eval - computes final aggregation result (SUM + AVG + SUM)
16751690
var eval1 = as(project.child(), Eval.class);
16761691
assertThat(eval1.fields(), hasSize(2));
1692+
// The avg is computed as the SUM(LENGTH(last_name)) / COUNT(LENGTH(last_name))
1693+
var avg = eval1.fields().get(0);
1694+
var avgDiv = as(avg.child(), Div.class);
1695+
// SUM(LENGTH(last_name))
1696+
var evalSumLastName = as(avgDiv.left(), ReferenceAttribute.class);
1697+
var evalCountLastName = as(avgDiv.right(), ReferenceAttribute.class);
1698+
var finalAgg = as(eval1.fields().get(1).child(), Add.class);
1699+
var leftFinalAgg = as(finalAgg.left(), Add.class);
1700+
assertThat(leftFinalAgg.left(), equalTo(evalSumLastName));
1701+
assertThat(as(leftFinalAgg.right(), ReferenceAttribute.class).id(), equalTo(avg.id()));
1702+
// SUM(LENGTH(first_name))
1703+
var evalSumFirstName = as(finalAgg.right(), ReferenceAttribute.class);
16771704

16781705
// Limit[1000[INTEGER],false,false]
16791706
var limit = as(eval1.child(), Limit.class);
16801707

1681-
// Aggregate with 4 aggregates: SUM for last_name, SUM and COUNT for AVG(a3), SUM for first_name
1708+
// Aggregate with 3 aggregates: SUM for last_name, COUNT for last_name
1709+
// (the AVG uses the sum and the count), SUM for first_name
16821710
var agg = as(limit.child(), Aggregate.class);
1683-
assertThat(agg.aggregates(), hasSize(4));
1711+
assertThat(agg.aggregates(), hasSize(3));
16841712

16851713
// Eval - pushdown fields: a3, LENGTH(last_name) for SUM, and LENGTH(first_name) for SUM
16861714
var evalPushdown = as(agg.child(), Eval.class);
@@ -1694,14 +1722,23 @@ public void testLengthPushdownZoo() {
16941722
Attribute firstNamePushDownAttr = assertLengthPushdown(firstNamePushdownAlias.child(), "first_name");
16951723

16961724
// Verify aggregates reference the pushed down fields
1697-
var sumForLastName = as(as(agg.aggregates().get(0), Alias.class).child(), Sum.class);
1725+
var sumForLastNameAlias = as(agg.aggregates().get(0), Alias.class);
1726+
var sumForLastName = as(sumForLastNameAlias.child(), Sum.class);
16981727
assertThat(as(sumForLastName.field(), ReferenceAttribute.class).id(), equalTo(lastNamePushdownAlias.id()));
1699-
var sumForAvg = as(as(agg.aggregates().get(1), Alias.class).child(), Sum.class);
1700-
assertThat(as(sumForAvg.field(), ReferenceAttribute.class).id(), equalTo(a3Alias.id()));
1701-
var countForAvg = as(as(agg.aggregates().get(2), Alias.class).child(), Count.class);
1728+
// Checks that the SUM(LENGTH(last_name)) in the final EVAL is the aggregate result here
1729+
assertThat(evalSumLastName.id(), equalTo(sumForLastNameAlias.id()));
1730+
1731+
var countForAvgAlias = as(agg.aggregates().get(1), Alias.class);
1732+
var countForAvg = as(countForAvgAlias.child(), Count.class);
17021733
assertThat(as(countForAvg.field(), ReferenceAttribute.class).id(), equalTo(a3Alias.id()));
1703-
var sumForFirstName = as(as(agg.aggregates().get(3), Alias.class).child(), Sum.class);
1734+
// Checks that the COUNT(LENGTH(last_name)) in the final EVAL is the aggregate result here
1735+
assertThat(evalCountLastName.id(), equalTo(countForAvgAlias.id()));
1736+
1737+
var sumForFirstNameAlias = as(agg.aggregates().get(2), Alias.class);
1738+
var sumForFirstName = as(sumForFirstNameAlias.child(), Sum.class);
17041739
assertThat(as(sumForFirstName.field(), ReferenceAttribute.class).id(), equalTo(firstNamePushdownAlias.id()));
1740+
// Checks that the SUM(LENGTH(first_name)) in the final EVAL is the aggregate result here
1741+
assertThat(evalSumFirstName.id(), equalTo(sumForFirstNameAlias.id()));
17051742

17061743
// Filter[LENGTH(last_name) > 1]
17071744
var filter = as(evalPushdown.child(), Filter.class);

0 commit comments

Comments
 (0)