Skip to content

Commit 5596ce8

Browse files
aokolnychyi authored and gatorsmile committed
[MINOR][SQL] Additional test case for CheckCartesianProducts rule
## What changes were proposed in this pull request?

While discovering optimization rules and their test coverage, I did not find any tests for `CheckCartesianProducts` in the Catalyst folder, so I decided to create a new test suite. Once I finished, I found a test in `JoinSuite` for this functionality, so feel free to discard this change if it does not make much sense. The proposed test suite covers a few additional use cases.

Author: aokolnychyi <[email protected]>

Closes apache#18909 from aokolnychyi/check-cartesian-join-tests.
1 parent c0e333d commit 5596ce8

File tree

4 files changed

+136
-2
lines changed

4 files changed

+136
-2
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.scalatest.Matchers._
21+
22+
import org.apache.spark.sql.AnalysisException
23+
import org.apache.spark.sql.catalyst.dsl.expressions._
24+
import org.apache.spark.sql.catalyst.dsl.plans._
25+
import org.apache.spark.sql.catalyst.expressions.Expression
26+
import org.apache.spark.sql.catalyst.plans._
27+
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
28+
import org.apache.spark.sql.catalyst.rules.RuleExecutor
29+
import org.apache.spark.sql.internal.SQLConf.CROSS_JOINS_ENABLED
30+
31+
/**
 * Test suite for the `CheckCartesianProducts` optimizer rule.
 *
 * Each test joins two small local relations with a given join type (and an optional join
 * condition), runs only `CheckCartesianProducts` over the analyzed plan, and checks whether
 * the rule throws, depending on the value of the `CROSS_JOINS_ENABLED` configuration.
 */
class CheckCartesianProductsSuite extends PlanTest {

  /** Rule executor that applies nothing but `CheckCartesianProducts`, in a single pass. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Check Cartesian Products", Once, CheckCartesianProducts) :: Nil
  }

  // The two relations share no column names, so a condition such as 'a === 'd
  // refers to one column from each side of the join.
  val testRelation1 = LocalRelation('a.int, 'b.int)
  val testRelation2 = LocalRelation('c.int, 'd.int)

  // Join types for which the tests below expect an AnalysisException when no join
  // condition is given and cross joins are disabled.
  val joinTypesWithRequiredCondition = Seq(Inner, LeftOuter, RightOuter, FullOuter)
  // Join types for which the tests below expect no exception even without a join
  // condition and with cross joins disabled.
  val joinTypesWithoutRequiredCondition = Seq(LeftSemi, LeftAnti, ExistenceJoin('exists))

  test("CheckCartesianProducts doesn't throw an exception if cross joins are enabled") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "true") {
      noException should be thrownBy {
        // With cross joins enabled, every join type must pass the check without a condition.
        for (joinType <- joinTypesWithRequiredCondition ++ joinTypesWithoutRequiredCondition) {
          performCartesianProductCheck(joinType)
        }
      }
    }
  }

  test("CheckCartesianProducts throws an exception for join types that require a join condition") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      for (joinType <- joinTypesWithRequiredCondition) {
        val thrownException = the [AnalysisException] thrownBy {
          performCartesianProductCheck(joinType)
        }
        // Checking the message ensures the failure comes from cartesian product detection.
        assert(thrownException.message.contains("Detected cartesian product"))
      }
    }
  }

  test("CheckCartesianProducts doesn't throw an exception if a join condition is present") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      for (joinType <- joinTypesWithRequiredCondition) {
        noException should be thrownBy {
          performCartesianProductCheck(joinType, Some('a === 'd))
        }
      }
    }
  }

  test("CheckCartesianProducts doesn't throw an exception if join types don't require conditions") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      for (joinType <- joinTypesWithoutRequiredCondition) {
        noException should be thrownBy {
          performCartesianProductCheck(joinType)
        }
      }
    }
  }

  /**
   * Joins `testRelation1` with `testRelation2`, analyzes the plan, runs the
   * `CheckCartesianProducts` batch over it, and compares the plan before and after.
   *
   * @param joinType  the join type to use between the two test relations
   * @param condition the optional join condition; `None` builds a join without a condition
   */
  private def performCartesianProductCheck(
      joinType: JoinType,
      condition: Option[Expression] = None): Unit = {
    val analyzedPlan = testRelation1.join(testRelation2, joinType, condition).analyze
    val optimizedPlan = Optimize.execute(analyzedPlan)
    // The output of the batch is compared against its input using PlanTest's comparePlans.
    comparePlans(analyzedPlan, optimizedPlan)
  }
}

sql/core/src/test/resources/sql-tests/inputs/cross-join.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,5 @@ create temporary view D(d, vd) as select * from nt1;
3232

3333
-- Allowed since cross join with C is explicit
3434
select * from ((A join B on (a = b)) cross join C) join D on (a = d);
35-
35+
-- Cross joins with non-equal predicates
36+
SELECT * FROM nt1 CROSS JOIN nt2 ON (nt1.k > nt2.k);

sql/core/src/test/resources/sql-tests/results/cross-join.sql.out

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 12
2+
-- Number of queries: 13
33

44

55
-- !query 0
@@ -127,3 +127,13 @@ three 3 three 3 two 2 three 3
127127
two 2 two 2 one 1 two 2
128128
two 2 two 2 three 3 two 2
129129
two 2 two 2 two 2 two 2
130+
131+
-- !query 12
132+
SELECT * FROM nt1 CROSS JOIN nt2 ON (nt1.k > nt2.k)
133+
-- !query 12 schema
134+
struct<k:string,v1:int,k:string,v2:int>
135+
-- !query 12 output
136+
three 3 one 1
137+
three 3 one 5
138+
two 2 one 1
139+
two 2 one 5

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ class JoinSuite extends QueryTest with SharedSQLContext {
216216
Row(1, null, 2, 2) ::
217217
Row(2, 2, 1, null) ::
218218
Row(2, 2, 2, 2) :: Nil)
219+
checkAnswer(
220+
testData3.as("x").join(testData3.as("y"), $"x.a" > $"y.a"),
221+
Row(2, 2, 1, null) :: Nil)
219222
}
220223
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
221224
val e = intercept[Exception] {
@@ -604,6 +607,35 @@ class JoinSuite extends QueryTest with SharedSQLContext {
604607
}
605608

606609
cartesianQueries.foreach(checkCartesianDetection)
610+
611+
// Check that left_semi, left_anti, existence joins without conditions do not throw
612+
// an exception if cross joins are disabled
613+
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
614+
checkAnswer(
615+
sql("SELECT * FROM testData3 LEFT SEMI JOIN testData2"),
616+
Row(1, null) :: Row (2, 2) :: Nil)
617+
checkAnswer(
618+
sql("SELECT * FROM testData3 LEFT ANTI JOIN testData2"),
619+
Nil)
620+
checkAnswer(
621+
sql(
622+
"""
623+
|SELECT a FROM testData3
624+
|WHERE
625+
| EXISTS (SELECT * FROM testData)
626+
|OR
627+
| EXISTS (SELECT * FROM testData2)""".stripMargin),
628+
Row(1) :: Row(2) :: Nil)
629+
checkAnswer(
630+
sql(
631+
"""
632+
|SELECT key FROM testData
633+
|WHERE
634+
| key IN (SELECT a FROM testData2)
635+
|OR
636+
| key IN (SELECT a FROM testData3)""".stripMargin),
637+
Row(1) :: Row(2) :: Row(3) :: Nil)
638+
}
607639
}
608640

609641
test("test SortMergeJoin (without spill)") {

0 commit comments

Comments
 (0)