Skip to content

Commit 08675b1

Browse files
dtenedorcloud-fan
authored andcommitted
[SPARK-50630][SQL] Fix GROUP BY ordinal support for pipe SQL AGGREGATE operators
### What changes were proposed in this pull request? This PR fixes GROUP BY ordinal support for pipe SQL AGGREGATE operators. It adds a new `UnresolvedPipeAggregateOrdinal` expression to represent these ordinals. In this context, the ordinal refers to the one-based position of the column in the input relation. Note that this behavior is different from GROUP BY ordinals in regular SQL, wherein the ordinal refers to the one-based position of the column in the SELECT clause instead. For example: ``` select 3 as x, 4 as y, 5 as z |> aggregate sum(y) group by 2, 3 > 4, 5, 4 select 3 as x, 4 as y, 5 as z |> aggregate sum(y) group by 1, 2, 3 > 3, 4, 5, 4 ``` This PR also makes a small fix for `|> UNION` (and other set operations) to prefer future pipe operators to apply on the result of the entire union, rather than binding to the right leg of the union only (to allay reported confusion during testing). For example, `values (0, 1) s(x, y) |> union all values (2, 3) t(x, y) |> drop x` will succeed rather than report an error that the number of columns does not match. ### Why are the changes needed? The current implementation has a bug where the ordinals are sometimes mistakenly retained as literal integers. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds new golden file based test coverage. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49248 from dtenedor/group-by-ordinals-pipe-aggregate. Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent ab95c4e commit 08675b1

File tree

7 files changed

+496
-46
lines changed

7 files changed

+496
-46
lines changed

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1523,7 +1523,7 @@ operatorPipeRightSide
15231523
| unpivotClause pivotClause?
15241524
| sample
15251525
| joinRelation
1526-
| operator=(UNION | EXCEPT | SETMINUS | INTERSECT) setQuantifier? right=queryTerm
1526+
| operator=(UNION | EXCEPT | SETMINUS | INTERSECT) setQuantifier? right=queryPrimary
15271527
| queryOrganization
15281528
| AGGREGATE namedExpressionSeq? aggregationClause?
15291529
;

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,10 +1887,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
18871887

18881888
// Replace the index with the corresponding expression in aggregateExpressions. The index is
18891889
// a 1-base position of aggregateExpressions, which is output columns (select expression)
1890-
case Aggregate(groups, aggs, child, hint) if aggs.forall(_.resolved) &&
1890+
case Aggregate(groups, aggs, child, hint)
1891+
if aggs
1892+
.filter(!containUnresolvedPipeAggregateOrdinal(_))
1893+
.forall(_.resolved) &&
18911894
groups.exists(containUnresolvedOrdinal) =>
1892-
val newGroups = groups.map(resolveGroupByExpressionOrdinal(_, aggs))
1893-
Aggregate(newGroups, aggs, child, hint)
1895+
val newAggs = aggs.map(resolvePipeAggregateExpressionOrdinal(_, child.output))
1896+
val newGroups = groups.map(resolveGroupByExpressionOrdinal(_, newAggs))
1897+
Aggregate(newGroups, newAggs, child, hint)
18941898
}
18951899

18961900
private def containUnresolvedOrdinal(e: Expression): Boolean = e match {
@@ -1899,6 +1903,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
18991903
case _ => false
19001904
}
19011905

1906+
private def containUnresolvedPipeAggregateOrdinal(e: Expression): Boolean = e match {
1907+
case UnresolvedAlias(_: UnresolvedPipeAggregateOrdinal, _) => true
1908+
case _ => false
1909+
}
1910+
19021911
private def resolveGroupByExpressionOrdinal(
19031912
expr: Expression,
19041913
aggs: Seq[Expression]): Expression = expr match {
@@ -1934,6 +1943,17 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
19341943
}
19351944
}
19361945

1946+
private def resolvePipeAggregateExpressionOrdinal(
1947+
expr: NamedExpression,
1948+
inputs: Seq[Attribute]): NamedExpression = expr match {
1949+
case UnresolvedAlias(UnresolvedPipeAggregateOrdinal(index), _) =>
1950+
// In this case, the user applied the SQL pipe aggregate operator ("|> AGGREGATE") and used
1951+
// ordinals in its GROUP BY clause. This expression then refers to the i-th attribute of the
1952+
// child operator (one-based). Here we resolve the ordinal to the corresponding attribute.
1953+
inputs(index - 1)
1954+
case other =>
1955+
other
1956+
}
19371957

19381958
/**
19391959
* Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,28 @@ case class UnresolvedOrdinal(ordinal: Int)
956956
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ORDINAL)
957957
}
958958

959+
/**
960+
* Represents an unresolved ordinal used in the GROUP BY clause of a SQL pipe aggregate operator
961+
* ("|> AGGREGATE").
962+
*
963+
* In this context, the ordinal refers to the one-based position of the column in the input
964+
* relation. Note that this behavior is different from GROUP BY ordinals in regular SQL, wherein the
965+
* ordinal refers to the one-based position of the column in the SELECT clause.
966+
*
967+
* For example:
968+
* {{{
969+
* values ('abc', 'def') tab(x, y)
970+
* |> aggregate sum(x) group by 2
971+
* }}}
972+
* @param ordinal ordinal starts from 1, instead of 0
973+
*/
974+
case class UnresolvedPipeAggregateOrdinal(ordinal: Int)
975+
extends LeafExpression with Unevaluable with NonSQLExpression {
976+
override def dataType: DataType = throw new UnresolvedException("dataType")
977+
override def nullable: Boolean = throw new UnresolvedException("nullable")
978+
override lazy val resolved = false
979+
}
980+
959981
/**
960982
* Represents unresolved having clause, the child for it can be Aggregate, GroupingSets, Rollup
961983
* and Cube. It is turned by the analyzer into a Filter.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6016,7 +6016,8 @@ class AstBuilder extends DataTypeAstBuilder
60166016
// analyzer behave as if we had added the corresponding SQL clause after a table subquery
60176017
// containing the input plan.
60186018
def withSubqueryAlias(): LogicalPlan = left match {
6019-
case _: SubqueryAlias | _: UnresolvedRelation | _: Join | _: Filter =>
6019+
case _: SubqueryAlias | _: UnresolvedRelation | _: Join | _: Filter |
6020+
_: GlobalLimit | _: LocalLimit | _: Offset | _: Sort =>
60206021
left
60216022
case _ =>
60226023
SubqueryAlias(SubqueryAlias.generateSubqueryName(), left)
@@ -6137,7 +6138,7 @@ class AstBuilder extends DataTypeAstBuilder
61376138
"The AGGREGATE clause requires a list of aggregate expressions " +
61386139
"or a list of grouping expressions, or both", ctx)
61396140
}
6140-
// Visit each aggregate expression, and add a PipeAggregate expression on top of it to generate
6141+
// Visit each aggregate expression, and add a [[PipeExpression]] on top of it to generate
61416142
// clear error messages if the expression does not contain at least one aggregate function.
61426143
val aggregateExpressions: Seq[NamedExpression] =
61436144
Option(ctx.namedExpressionSeq()).map { n: NamedExpressionSeqContext =>
@@ -6183,12 +6184,28 @@ class AstBuilder extends DataTypeAstBuilder
61836184
a.aggregateExpressions.foreach(visit)
61846185
// Prepend grouping keys to the list of aggregate functions, since operator pipe AGGREGATE
61856186
// clause returns the GROUP BY expressions followed by the list of aggregate functions.
6186-
val namedGroupingExpressions: Seq[NamedExpression] =
6187-
a.groupingExpressions.map {
6188-
case n: NamedExpression => n
6189-
case e: Expression => UnresolvedAlias(e, None)
6190-
}
6191-
a.copy(aggregateExpressions = namedGroupingExpressions ++ a.aggregateExpressions)
6187+
val newGroupingExpressions = ArrayBuffer.empty[Expression]
6188+
val newAggregateExpressions = ArrayBuffer.empty[NamedExpression]
6189+
a.groupingExpressions.foreach {
6190+
case n: NamedExpression =>
6191+
newGroupingExpressions += n
6192+
newAggregateExpressions += n
6193+
// If the grouping expression is an integer literal, create [[UnresolvedOrdinal]] and
6194+
// [[UnresolvedPipeAggregateOrdinal]] expressions to represent it in the final grouping
6195+
// and aggregate expressions, respectively. This will let the
6196+
// [[ResolveOrdinalInOrderByAndGroupBy]] rule detect the ordinal in the aggregate list
6197+
// and replace it with the corresponding attribute from the child operator.
6198+
case Literal(v: Int, IntegerType) if conf.groupByOrdinal =>
6199+
newGroupingExpressions += UnresolvedOrdinal(newAggregateExpressions.length + 1)
6200+
newAggregateExpressions += UnresolvedAlias(UnresolvedPipeAggregateOrdinal(v), None)
6201+
case e: Expression =>
6202+
newGroupingExpressions += e
6203+
newAggregateExpressions += UnresolvedAlias(e, None)
6204+
}
6205+
newAggregateExpressions.appendAll(a.aggregateExpressions)
6206+
a.copy(
6207+
groupingExpressions = newGroupingExpressions.toSeq,
6208+
aggregateExpressions = newAggregateExpressions.toSeq)
61926209
}
61936210
}.getOrElse {
61946211
// This is a table aggregation with no grouping expressions.

sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out

Lines changed: 187 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,6 +1539,78 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
15391539
}
15401540

15411541

1542+
-- !query
1543+
table t
1544+
|> select x, length(y) as z
1545+
|> limit 1000
1546+
|> where x + length(y) < 4
1547+
-- !query analysis
1548+
org.apache.spark.sql.catalyst.ExtendedAnalysisException
1549+
{
1550+
"errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
1551+
"sqlState" : "42703",
1552+
"messageParameters" : {
1553+
"objectName" : "`y`",
1554+
"proposal" : "`x`, `z`"
1555+
},
1556+
"queryContext" : [ {
1557+
"objectType" : "",
1558+
"objectName" : "",
1559+
"startIndex" : 71,
1560+
"stopIndex" : 71,
1561+
"fragment" : "y"
1562+
} ]
1563+
}
1564+
1565+
1566+
-- !query
1567+
table t
1568+
|> select x, length(y) as z
1569+
|> limit 1000 offset 1
1570+
|> where x + length(y) < 4
1571+
-- !query analysis
1572+
org.apache.spark.sql.catalyst.ExtendedAnalysisException
1573+
{
1574+
"errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
1575+
"sqlState" : "42703",
1576+
"messageParameters" : {
1577+
"objectName" : "`y`",
1578+
"proposal" : "`x`, `z`"
1579+
},
1580+
"queryContext" : [ {
1581+
"objectType" : "",
1582+
"objectName" : "",
1583+
"startIndex" : 80,
1584+
"stopIndex" : 80,
1585+
"fragment" : "y"
1586+
} ]
1587+
}
1588+
1589+
1590+
-- !query
1591+
table t
1592+
|> select x, length(y) as z
1593+
|> order by x, y
1594+
|> where x + length(y) < 4
1595+
-- !query analysis
1596+
org.apache.spark.sql.catalyst.ExtendedAnalysisException
1597+
{
1598+
"errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
1599+
"sqlState" : "42703",
1600+
"messageParameters" : {
1601+
"objectName" : "`y`",
1602+
"proposal" : "`x`, `z`"
1603+
},
1604+
"queryContext" : [ {
1605+
"objectType" : "",
1606+
"objectName" : "",
1607+
"startIndex" : 52,
1608+
"stopIndex" : 52,
1609+
"fragment" : "y"
1610+
} ]
1611+
}
1612+
1613+
15421614
-- !query
15431615
(select x, sum(length(y)) as sum_len from t group by x)
15441616
|> where sum(length(y)) = 3
@@ -2697,21 +2769,34 @@ Union false, false
26972769

26982770

26992771
-- !query
2700-
values (0, 1) tab(x, y)
2772+
values (2, 'xyz') tab(x, y)
27012773
|> union table t
27022774
|> where x = 0
27032775
-- !query analysis
2704-
Distinct
2705-
+- Union false, false
2706-
:- Project [x#x, cast(y#x as bigint) AS y#xL]
2707-
: +- SubqueryAlias tab
2708-
: +- LocalRelation [x#x, y#x]
2709-
+- Project [x#x, cast(y#x as bigint) AS y#xL]
2710-
+- Filter (x#x = 0)
2776+
Filter (x#x = 0)
2777+
+- SubqueryAlias __auto_generated_subquery_name
2778+
+- Distinct
2779+
+- Union false, false
2780+
:- SubqueryAlias tab
2781+
: +- LocalRelation [x#x, y#x]
27112782
+- SubqueryAlias spark_catalog.default.t
27122783
+- Relation spark_catalog.default.t[x#x,y#x] csv
27132784

27142785

2786+
-- !query
2787+
values (2, 'xyz') tab(x, y)
2788+
|> union table t
2789+
|> drop x
2790+
-- !query analysis
2791+
Project [y#x]
2792+
+- Distinct
2793+
+- Union false, false
2794+
:- SubqueryAlias tab
2795+
: +- LocalRelation [x#x, y#x]
2796+
+- SubqueryAlias spark_catalog.default.t
2797+
+- Relation spark_catalog.default.t[x#x,y#x] csv
2798+
2799+
27152800
-- !query
27162801
(select * from t)
27172802
|> union all (select * from t)
@@ -2878,10 +2963,9 @@ table t
28782963
-- !query analysis
28792964
GlobalLimit 1
28802965
+- LocalLimit 1
2881-
+- SubqueryAlias __auto_generated_subquery_name
2882-
+- Sort [x#x ASC NULLS FIRST], true
2883-
+- SubqueryAlias spark_catalog.default.t
2884-
+- Relation spark_catalog.default.t[x#x,y#x] csv
2966+
+- Sort [x#x ASC NULLS FIRST], true
2967+
+- SubqueryAlias spark_catalog.default.t
2968+
+- Relation spark_catalog.default.t[x#x,y#x] csv
28852969

28862970

28872971
-- !query
@@ -3109,11 +3193,101 @@ Aggregate [x#x, y#x], [x#x, y#x]
31093193
select 3 as x, 4 as y
31103194
|> aggregate group by 1, 2
31113195
-- !query analysis
3112-
Aggregate [1, 2], [1 AS 1#x, 2 AS 2#x]
3196+
Aggregate [x#x, y#x], [x#x, y#x]
31133197
+- Project [3 AS x#x, 4 AS y#x]
31143198
+- OneRowRelation
31153199

31163200

3201+
-- !query
3202+
values (3, 4) as tab(x, y)
3203+
|> aggregate sum(y) group by 1
3204+
-- !query analysis
3205+
Aggregate [x#x], [x#x, pipeexpression(sum(y#x), true, AGGREGATE) AS pipeexpression(sum(y))#xL]
3206+
+- SubqueryAlias tab
3207+
+- LocalRelation [x#x, y#x]
3208+
3209+
3210+
-- !query
3211+
values (3, 4), (5, 4) as tab(x, y)
3212+
|> aggregate sum(y) group by 1
3213+
-- !query analysis
3214+
Aggregate [x#x], [x#x, pipeexpression(sum(y#x), true, AGGREGATE) AS pipeexpression(sum(y))#xL]
3215+
+- SubqueryAlias tab
3216+
+- LocalRelation [x#x, y#x]
3217+
3218+
3219+
-- !query
3220+
select 3 as x, 4 as y
3221+
|> aggregate sum(y) group by 1, 1
3222+
-- !query analysis
3223+
Aggregate [x#x, x#x], [x#x, x#x, pipeexpression(sum(y#x), true, AGGREGATE) AS pipeexpression(sum(y))#xL]
3224+
+- Project [3 AS x#x, 4 AS y#x]
3225+
+- OneRowRelation
3226+
3227+
3228+
-- !query
3229+
select 1 as `1`, 2 as `2`
3230+
|> aggregate sum(`2`) group by `1`
3231+
-- !query analysis
3232+
Aggregate [1#x], [1#x, pipeexpression(sum(2#x), true, AGGREGATE) AS pipeexpression(sum(2))#xL]
3233+
+- Project [1 AS 1#x, 2 AS 2#x]
3234+
+- OneRowRelation
3235+
3236+
3237+
-- !query
3238+
select 3 as x, 4 as y
3239+
|> aggregate sum(y) group by 2
3240+
-- !query analysis
3241+
Aggregate [y#x], [y#x, pipeexpression(sum(y#x), true, AGGREGATE) AS pipeexpression(sum(y))#xL]
3242+
+- Project [3 AS x#x, 4 AS y#x]
3243+
+- OneRowRelation
3244+
3245+
3246+
-- !query
3247+
select 3 as x, 4 as y, 5 as z
3248+
|> aggregate sum(y) group by 2
3249+
-- !query analysis
3250+
Aggregate [y#x], [y#x, pipeexpression(sum(y#x), true, AGGREGATE) AS pipeexpression(sum(y))#xL]
3251+
+- Project [3 AS x#x, 4 AS y#x, 5 AS z#x]
3252+
+- OneRowRelation
3253+
3254+
3255+
-- !query
3256+
select 3 as x, 4 as y, 5 as z
3257+
|> aggregate sum(y) group by 3
3258+
-- !query analysis
3259+
Aggregate [z#x], [z#x, pipeexpression(sum(y#x), true, AGGREGATE) AS pipeexpression(sum(y))#xL]
3260+
+- Project [3 AS x#x, 4 AS y#x, 5 AS z#x]
3261+
+- OneRowRelation
3262+
3263+
3264+
-- !query
3265+
select 3 as x, 4 as y, 5 as z
3266+
|> aggregate sum(y) group by 2, 3
3267+
-- !query analysis
3268+
Aggregate [y#x, z#x], [y#x, z#x, pipeexpression(sum(y#x), true, AGGREGATE) AS pipeexpression(sum(y))#xL]
3269+
+- Project [3 AS x#x, 4 AS y#x, 5 AS z#x]
3270+
+- OneRowRelation
3271+
3272+
3273+
-- !query
3274+
select 3 as x, 4 as y, 5 as z
3275+
|> aggregate sum(y) group by 1, 2, 3
3276+
-- !query analysis
3277+
Aggregate [x#x, y#x, z#x], [x#x, y#x, z#x, pipeexpression(sum(y#x), true, AGGREGATE) AS pipeexpression(sum(y))#xL]
3278+
+- Project [3 AS x#x, 4 AS y#x, 5 AS z#x]
3279+
+- OneRowRelation
3280+
3281+
3282+
-- !query
3283+
select 3 as x, 4 as y, 5 as z
3284+
|> aggregate sum(y) group by x, 2, 3
3285+
-- !query analysis
3286+
Aggregate [x#x, y#x, z#x], [x#x, y#x, z#x, pipeexpression(sum(y#x), true, AGGREGATE) AS pipeexpression(sum(y))#xL]
3287+
+- Project [3 AS x#x, 4 AS y#x, 5 AS z#x]
3288+
+- OneRowRelation
3289+
3290+
31173291
-- !query
31183292
table t
31193293
|> aggregate sum(x)

0 commit comments

Comments
 (0)