Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit f999312

Browse files
nsyca authored and hvanhovell committed
[SPARK-18814][SQL] CheckAnalysis rejects TPCDS query 32
## What changes were proposed in this pull request?

Move the checking of GROUP BY columns in correlated scalar subqueries from CheckAnalysis to Analysis to fix a regression caused by SPARK-18504.

This problem can now be reproduced with a simple script:

    Seq((1,1)).toDF("pk","pv").createOrReplaceTempView("p")
    Seq((1,1)).toDF("ck","cv").createOrReplaceTempView("c")
    sql("select * from p,c where p.pk=c.ck and c.cv = (select avg(c1.cv) from c c1 where c1.ck = p.pk)").show

The requirements are:

1. We need to reference the same table twice, in both the parent and the subquery. Here it is the table c.
2. We need to have a correlated predicate, but to a different table. Here it is from c (as c1) in the subquery to p in the parent.
3. We will then "deduplicate" c1.ck in the subquery to `ck#<n1>#<n2>` at the `Project` above the `Aggregate` of `avg`. Then, when we compare `ck#<n1>#<n2>` and the original group by column `ck#<n1>` by their canonicalized forms, they do not match (#<n2> != #<n1>). That's how we trigger the exception added in SPARK-18504.

## How was this patch tested?

SubquerySuite and a simplified version of TPCDS-Q32.

Author: Nattavut Sutyanyong <[email protected]>

Closes apache#16246 from nsyca/18814.

(cherry picked from commit cccd643)
Signed-off-by: Herman van Hovell <[email protected]>
1 parent 8ef0059 commit f999312

File tree

4 files changed

+90
-9
lines changed

4 files changed

+90
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 23 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,10 @@ trait CheckAnalysis extends PredicateHelper {
124124
s"Scalar subquery must return only one column, but got ${query.output.size}")
125125

126126
case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty =>
127+
128+
// Collect the columns from the subquery for further checking.
129+
var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)
130+
127131
def checkAggregate(agg: Aggregate): Unit = {
128132
// Make sure correlated scalar subqueries contain one row for every outer row by
129133
// enforcing that they are aggregates which contain exactly one aggregate expressions.
@@ -136,24 +140,35 @@ trait CheckAnalysis extends PredicateHelper {
136140
failAnalysis("The output of a correlated scalar subquery must be aggregated")
137141
}
138142

139-
// SPARK-18504: block cases where GROUP BY columns
140-
// are not part of the correlated columns
141-
val groupByCols = ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references))
142-
val predicateCols = ExpressionSet.apply(conditions.flatMap(_.references))
143-
val invalidCols = groupByCols.diff(predicateCols)
143+
// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
144+
// are not part of the correlated columns.
145+
val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
146+
val correlatedCols = AttributeSet(subqueryColumns)
147+
val invalidCols = groupByCols -- correlatedCols
144148
// GROUP BY columns must be a subset of columns in the predicates
145149
if (invalidCols.nonEmpty) {
146150
failAnalysis(
147-
"a GROUP BY clause in a scalar correlated subquery " +
151+
"A GROUP BY clause in a scalar correlated subquery " +
148152
"cannot contain non-correlated columns: " +
149153
invalidCols.mkString(","))
150154
}
151155
}
152156

153-
// Skip projects and subquery aliases added by the Analyzer and the SQLBuilder.
157+
// Skip subquery aliases added by the Analyzer and the SQLBuilder.
158+
// For projects, do the necessary mapping and skip to its child.
154159
def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
155160
case s: SubqueryAlias => cleanQuery(s.child)
156-
case p: Project => cleanQuery(p.child)
161+
case p: Project =>
162+
// SPARK-18814: Map any aliases to their AttributeReference children
163+
// for the checking in the Aggregate operators below this Project.
164+
subqueryColumns = subqueryColumns.map {
165+
xs => p.projectList.collectFirst {
166+
case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
167+
child
168+
}.getOrElse(xs)
169+
}
170+
171+
cleanQuery(p.child)
157172
case child => child
158173
}
159174

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,20 @@
1+
CREATE OR REPLACE TEMPORARY VIEW p AS VALUES (1, 1) AS T(pk, pv);
2+
CREATE OR REPLACE TEMPORARY VIEW c AS VALUES (1, 1) AS T(ck, cv);
3+
4+
-- SPARK-18814.1: Simplified version of TPCDS-Q32
5+
SELECT pk, cv
6+
FROM p, c
7+
WHERE p.pk = c.ck
8+
AND c.cv = (SELECT avg(c1.cv)
9+
FROM c c1
10+
WHERE c1.ck = p.pk);
11+
12+
-- SPARK-18814.2: Adding stack of aggregates
13+
SELECT pk, cv
14+
FROM p, c
15+
WHERE p.pk = c.ck
16+
AND c.cv = (SELECT max(avg)
17+
FROM (SELECT c1.cv, avg(c1.cv) avg
18+
FROM c c1
19+
WHERE c1.ck = p.pk
20+
GROUP BY c1.cv));
Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,46 @@
1+
-- Automatically generated by SQLQueryTestSuite
2+
-- Number of queries: 4
3+
4+
5+
-- !query 0
6+
CREATE OR REPLACE TEMPORARY VIEW p AS VALUES (1, 1) AS T(pk, pv)
7+
-- !query 0 schema
8+
struct<>
9+
-- !query 0 output
10+
11+
12+
13+
-- !query 1
14+
CREATE OR REPLACE TEMPORARY VIEW c AS VALUES (1, 1) AS T(ck, cv)
15+
-- !query 1 schema
16+
struct<>
17+
-- !query 1 output
18+
19+
20+
21+
-- !query 2
22+
SELECT pk, cv
23+
FROM p, c
24+
WHERE p.pk = c.ck
25+
AND c.cv = (SELECT avg(c1.cv)
26+
FROM c c1
27+
WHERE c1.ck = p.pk)
28+
-- !query 2 schema
29+
struct<pk:int,cv:int>
30+
-- !query 2 output
31+
1 1
32+
33+
34+
-- !query 3
35+
SELECT pk, cv
36+
FROM p, c
37+
WHERE p.pk = c.ck
38+
AND c.cv = (SELECT max(avg)
39+
FROM (SELECT c1.cv, avg(c1.cv) avg
40+
FROM c c1
41+
WHERE c1.ck = p.pk
42+
GROUP BY c1.cv))
43+
-- !query 3 schema
44+
struct<pk:int,cv:int>
45+
-- !query 3 output
46+
1 1

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -491,7 +491,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
491491
sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c2) sum from t t1")
492492
}
493493
assert(errMsg.getMessage.contains(
494-
"a GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:"))
494+
"A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:"))
495495
}
496496
}
497497

0 commit comments

Comments
 (0)