Skip to content

Commit 3286bff

Browse files
dilipbiswal and cloud-fan
authored and committed
[SPARK-27255][SQL] Report error when illegal expressions are hosted by a plan operator.
## What changes were proposed in this pull request? In this PR, we raise an AnalysisException when we detect the presence of aggregate expressions in the WHERE clause. Here is the problem description from the JIRA. Aggregate functions should not be allowed in the WHERE clause. But Spark SQL throws an exception when generating code. It is supposed to throw an exception during parsing or analysis. Here is an example: ``` val df = spark.sql("select * from t where sum(ta) > 0") df.explain(true) df.show() ``` Resulting exception: ``` Exception in thread "main" java.lang.UnsupportedOperationException: Cannot generate code for expression: sum(cast(input[0, int, false] as bigint)) at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode(Expression.scala:291) at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode$(Expression.scala:290) at org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression.doGenCode(interfaces.scala:87) at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:138) at scala.Option.getOrElse(Option.scala:138) ``` We checked the behaviour of other databases, and all of them return an error: **PostgreSQL** ``` select * from foo where max(c1) > 0; Error ERROR: aggregate functions are not allowed in WHERE Position: 25 ``` **DB2** ``` db2 => select * from foo where max(c1) > 0; SQL0120N Invalid use of an aggregate function or OLAP function. ``` **Oracle** ``` select * from foo where max(c1) > 0; ORA-00934: group function is not allowed here ``` **MySQL** ``` select * from foo where max(c1) > 0; Invalid use of group function ``` **Update** This PR has been enhanced to report an error when expressions such as Aggregate, Window, or Generate are hosted by operators where they are invalid. ## How was this patch tested? Added tests in AnalysisErrorSuite and group-by.sql. Closes apache#24209 from dilipbiswal/SPARK-27255. Authored-by: Dilip Biswal <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 1d20d13 commit 3286bff

File tree

8 files changed

+165
-38
lines changed

8 files changed

+165
-38
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ trait CheckAnalysis extends PredicateHelper {
178178
s"of type ${condition.dataType.catalogString} is not a boolean.")
179179

180180
case Aggregate(groupingExprs, aggregateExprs, child) =>
181-
def isAggregateExpression(expr: Expression) = {
181+
def isAggregateExpression(expr: Expression): Boolean = {
182182
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
183183
}
184184

@@ -376,6 +376,25 @@ trait CheckAnalysis extends PredicateHelper {
376376
throw new IllegalStateException(
377377
"Internal error: logical hint operator should have been removed during analysis")
378378

379+
case f @ Filter(condition, _)
380+
if PlanHelper.specialExpressionsInUnsupportedOperator(f).nonEmpty =>
381+
val invalidExprSqls = PlanHelper.specialExpressionsInUnsupportedOperator(f).map(_.sql)
382+
failAnalysis(
383+
s"""
384+
|Aggregate/Window/Generate expressions are not valid in where clause of the query.
385+
|Expression in where clause: [${condition.sql}]
386+
|Invalid expressions: [${invalidExprSqls.mkString(", ")}]""".stripMargin)
387+
388+
case other if PlanHelper.specialExpressionsInUnsupportedOperator(other).nonEmpty =>
389+
val invalidExprSqls =
390+
PlanHelper.specialExpressionsInUnsupportedOperator(other).map(_.sql)
391+
failAnalysis(
392+
s"""
393+
|The query operator `${other.nodeName}` contains one or more unsupported
394+
|expression types Aggregate, Window or Generate.
395+
|Invalid expressions: [${invalidExprSqls.mkString(", ")}]""".stripMargin
396+
)
397+
379398
case _ => // Analysis successful!
380399
}
381400
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -43,38 +43,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
4343
// - is still resolved
4444
// - only host special expressions in supported operators
4545
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
46-
!Utils.isTesting || (plan.resolved && checkSpecialExpressionIntegrity(plan))
47-
}
48-
49-
/**
50-
* Check if all operators in this plan hold structural integrity with regards to hosting special
51-
* expressions.
52-
* Returns true when all operators are integral.
53-
*/
54-
private def checkSpecialExpressionIntegrity(plan: LogicalPlan): Boolean = {
55-
plan.find(specialExpressionInUnsupportedOperator).isEmpty
56-
}
57-
58-
/**
59-
* Check if there's any expression in this query plan operator that is
60-
* - A WindowExpression but the plan is not Window
61-
* - An AggregateExpresion but the plan is not Aggregate or Window
62-
* - A Generator but the plan is not Generate
63-
* Returns true when this operator breaks structural integrity with one of the cases above.
64-
*/
65-
private def specialExpressionInUnsupportedOperator(plan: LogicalPlan): Boolean = {
66-
val exprs = plan.expressions
67-
exprs.flatMap { root =>
68-
root.find {
69-
case e: WindowExpression
70-
if !plan.isInstanceOf[Window] => true
71-
case e: AggregateExpression
72-
if !(plan.isInstanceOf[Aggregate] || plan.isInstanceOf[Window]) => true
73-
case e: Generator
74-
if !plan.isInstanceOf[Generate] => true
75-
case _ => false
76-
}
77-
}.nonEmpty
46+
!Utils.isTesting || (plan.resolved &&
47+
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty)
7848
}
7949

8050
protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.plans.logical
19+
20+
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, WindowExpression}
21+
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
22+
23+
/**
24+
* [[PlanHelper]] contains utility methods that can be used by Analyzer and Optimizer.
25+
* It can also be a container of methods that are common across multiple rules in Analyzer
26+
* and Optimizer.
27+
*/
28+
object PlanHelper {
29+
/**
30+
* Check if there's any expression in this query plan operator that is
31+
* - A WindowExpression but the plan is not Window
32+
* - An AggregateExpression but the plan is not Aggregate or Window
33+
* - A Generator but the plan is not Generate
34+
* Returns the list of invalid expressions that this operator hosts. This can happen when
35+
* 1. The input query from users contains invalid expressions.
36+
* Example : SELECT * FROM tab WHERE max(c1) > 0
37+
* 2. Query rewrites inadvertently produce plans that are invalid.
38+
*/
39+
def specialExpressionsInUnsupportedOperator(plan: LogicalPlan): Seq[Expression] = {
40+
val exprs = plan.expressions
41+
val invalidExpressions = exprs.flatMap { root =>
42+
root.collect {
43+
case e: WindowExpression
44+
if !plan.isInstanceOf[Window] => e
45+
case e: AggregateExpression
46+
if !(plan.isInstanceOf[Aggregate] || plan.isInstanceOf[Window]) => e
47+
case e: Generator
48+
if !plan.isInstanceOf[Generate] => e
49+
}
50+
}
51+
invalidExpressions
52+
}
53+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,4 +599,12 @@ class AnalysisErrorSuite extends AnalysisTest {
599599
assertAnalysisError(plan5,
600600
"Accessing outer query column is not allowed in" :: Nil)
601601
}
602+
603+
test("Error on filter condition containing aggregate expressions") {
604+
val a = AttributeReference("a", IntegerType)()
605+
val b = AttributeReference("b", IntegerType)()
606+
val plan = Filter('a === UnresolvedFunction("max", Seq(b), true), LocalRelation(a, b))
607+
assertAnalysisError(plan,
608+
"Aggregate/Window/Generate expressions are not valid in where clause of the query" :: Nil)
609+
}
602610
}

sql/core/src/test/resources/sql-tests/inputs/group-by.sql

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,16 @@ SELECT every("true");
141141
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
142142
SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
143143
SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
144+
145+
-- Having referencing aggregate expressions is ok.
146+
SELECT count(*) FROM test_agg HAVING count(*) > 1L;
147+
SELECT k, max(v) FROM test_agg GROUP BY k HAVING max(v) = true;
148+
149+
-- Aggregate expressions can be referenced through an alias
150+
SELECT * FROM (SELECT COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L;
151+
152+
-- Error when aggregate expressions are in where clause directly
153+
SELECT count(*) FROM test_agg WHERE count(*) > 1L;
154+
SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L;
155+
SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1;
156+

sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ WHERE t1a IN (SELECT min(t2a)
4646
SELECT t1a
4747
FROM t1
4848
GROUP BY 1
49-
HAVING EXISTS (SELECT 1
49+
HAVING EXISTS (SELECT t2a
5050
FROM t2
51-
WHERE t2a < min(t1a + t2a));
51+
GROUP BY 1
52+
HAVING t2a < min(t1a + t2a));
5253

5354
-- TC 01.04
5455
-- Invalid due to mixture of outer and local references under an AggregateExpression

sql/core/src/test/resources/sql-tests/results/group-by.sql.out

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 46
2+
-- Number of queries: 52
33

44

55
-- !query 0
@@ -459,3 +459,65 @@ struct<k:int,v:boolean,any(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RA
459459
5 NULL NULL
460460
5 false false
461461
5 true true
462+
463+
464+
-- !query 46
465+
SELECT count(*) FROM test_agg HAVING count(*) > 1L
466+
-- !query 46 schema
467+
struct<count(1):bigint>
468+
-- !query 46 output
469+
10
470+
471+
472+
-- !query 47
473+
SELECT k, max(v) FROM test_agg GROUP BY k HAVING max(v) = true
474+
-- !query 47 schema
475+
struct<k:int,max(v):boolean>
476+
-- !query 47 output
477+
1 true
478+
2 true
479+
5 true
480+
481+
482+
-- !query 48
483+
SELECT * FROM (SELECT COUNT(*) AS cnt FROM test_agg) WHERE cnt > 1L
484+
-- !query 48 schema
485+
struct<cnt:bigint>
486+
-- !query 48 output
487+
10
488+
489+
490+
-- !query 49
491+
SELECT count(*) FROM test_agg WHERE count(*) > 1L
492+
-- !query 49 schema
493+
struct<>
494+
-- !query 49 output
495+
org.apache.spark.sql.AnalysisException
496+
497+
Aggregate/Window/Generate expressions are not valid in where clause of the query.
498+
Expression in where clause: [(count(1) > 1L)]
499+
Invalid expressions: [count(1)];
500+
501+
502+
-- !query 50
503+
SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L
504+
-- !query 50 schema
505+
struct<>
506+
-- !query 50 output
507+
org.apache.spark.sql.AnalysisException
508+
509+
Aggregate/Window/Generate expressions are not valid in where clause of the query.
510+
Expression in where clause: [((count(1) + 1L) > 1L)]
511+
Invalid expressions: [count(1)];
512+
513+
514+
-- !query 51
515+
SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1
516+
-- !query 51 schema
517+
struct<>
518+
-- !query 51 output
519+
org.apache.spark.sql.AnalysisException
520+
521+
Aggregate/Window/Generate expressions are not valid in where clause of the query.
522+
Expression in where clause: [(((test_agg.`k` = 1) OR (test_agg.`k` = 2)) OR (((count(1) + 1L) > 1L) OR (max(test_agg.`k`) > 1)))]
523+
Invalid expressions: [count(1), max(test_agg.`k`)];

sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,10 @@ Resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2
7070
SELECT t1a
7171
FROM t1
7272
GROUP BY 1
73-
HAVING EXISTS (SELECT 1
73+
HAVING EXISTS (SELECT t2a
7474
FROM t2
75-
WHERE t2a < min(t1a + t2a))
75+
GROUP BY 1
76+
HAVING t2a < min(t1a + t2a))
7677
-- !query 5 schema
7778
struct<>
7879
-- !query 5 output

0 commit comments

Comments
 (0)