Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 21c4450

Browse files
DonnyZonegatorsmile
authored andcommitted
[SPARK-21980][SQL] References in grouping functions should be indexed with semanticEquals
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-21980 This PR fixes the issue in ResolveGroupingAnalytics rule, which indexes the column references in grouping functions without considering case sensitive configurations. The problem can be reproduced by: `val df = spark.createDataFrame(Seq((1, 1), (2, 1), (2, 2))).toDF("a", "b") df.cube("a").agg(grouping("A")).show()` ## How was this patch tested? unit tests Author: donnyzone <[email protected]> Closes apache#19202 from DonnyZone/ResolveGroupingAnalytics.
1 parent b6ef1f5 commit 21c4450

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ class Analyzer(
314314
s"grouping columns (${groupByExprs.mkString(",")})")
315315
}
316316
case e @ Grouping(col: Expression) =>
317-
val idx = groupByExprs.indexOf(col)
317+
val idx = groupByExprs.indexWhere(_.semanticEquals(col))
318318
if (idx >= 0) {
319319
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
320320
Literal(1)), ByteType), toPrettySQL(e))()

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
190190
)
191191
}
192192

193+
test("SPARK-21980: References in grouping functions should be indexed with semanticEquals") {
194+
checkAnswer(
195+
courseSales.cube("course", "year")
196+
.agg(grouping("CouRse"), grouping("year")),
197+
Row("Java", 2012, 0, 0) ::
198+
Row("Java", 2013, 0, 0) ::
199+
Row("Java", null, 0, 1) ::
200+
Row("dotNET", 2012, 0, 0) ::
201+
Row("dotNET", 2013, 0, 0) ::
202+
Row("dotNET", null, 0, 1) ::
203+
Row(null, 2012, 1, 0) ::
204+
Row(null, 2013, 1, 0) ::
205+
Row(null, null, 1, 1) :: Nil
206+
)
207+
}
208+
193209
test("rollup overlapping columns") {
194210
checkAnswer(
195211
testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),

0 commit comments

Comments
 (0)