Skip to content

Commit afb0627

Browse files
dilipbiswalgatorsmile
authored andcommitted
[SPARK-23957][SQL] Sorts in subqueries are redundant and can be removed
## What changes were proposed in this pull request? Thanks to henryr for the original idea at apache#21049 Description from the original PR : Subqueries (at least in SQL) have 'bag of tuples' semantics. Ordering them is therefore redundant (unless combined with a limit). This patch removes the top sort operators from the subquery plans. This closes apache#21049. ## How was this patch tested? Added test cases in SubquerySuite to cover in, exists and scalar subqueries. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dilip Biswal <[email protected]> Closes apache#21853 from dilipbiswal/SPARK-23957.
1 parent d4c3415 commit afb0627

File tree

2 files changed

+310
-2
lines changed

2 files changed

+310
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,20 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
180180
* Optimize all the subqueries inside expression.
181181
*/
182182
object OptimizeSubqueries extends Rule[LogicalPlan] {
183+
private def removeTopLevelSort(plan: LogicalPlan): LogicalPlan = {
184+
plan match {
185+
case Sort(_, _, child) => child
186+
case Project(fields, child) => Project(fields, removeTopLevelSort(child))
187+
case other => other
188+
}
189+
}
183190
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
184191
case s: SubqueryExpression =>
185192
val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan))
186-
s.withNewPlan(newPlan)
193+
// At this point we have an optimized subquery plan that we are going to attach
194+
// to this subquery expression. Here we can safely remove any top level sort
195+
// in the plan as tuples produced by a subquery are un-ordered.
196+
s.withNewPlan(removeTopLevelSort(newPlan))
187197
}
188198
}
189199

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 299 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717

1818
package org.apache.spark.sql
1919

20-
import org.apache.spark.sql.catalyst.plans.logical.Join
20+
import scala.collection.mutable.ArrayBuffer
21+
22+
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
23+
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
2124
import org.apache.spark.sql.test.SharedSQLContext
2225

2326
class SubquerySuite extends QueryTest with SharedSQLContext {
@@ -970,4 +973,299 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
970973
Row("3", "b") :: Row("4", "b") :: Nil)
971974
}
972975
}
976+
977+
private def getNumSortsInQuery(query: String): Int = {
978+
val plan = sql(query).queryExecution.optimizedPlan
979+
getNumSorts(plan) + getSubqueryExpressions(plan).map{s => getNumSorts(s.plan)}.sum
980+
}
981+
982+
private def getSubqueryExpressions(plan: LogicalPlan): Seq[SubqueryExpression] = {
983+
val subqueryExpressions = ArrayBuffer.empty[SubqueryExpression]
984+
plan transformAllExpressions {
985+
case s: SubqueryExpression =>
986+
subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s)
987+
s
988+
}
989+
subqueryExpressions
990+
}
991+
992+
private def getNumSorts(plan: LogicalPlan): Int = {
993+
plan.collect { case s: Sort => s }.size
994+
}
995+
996+
test("SPARK-23957 Remove redundant sort from subquery plan(in subquery)") {
997+
withTempView("t1", "t2", "t3") {
998+
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
999+
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
1000+
Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3")
1001+
1002+
// Simple order by
1003+
val query1 =
1004+
"""
1005+
|SELECT c1 FROM t1
1006+
|WHERE
1007+
|c1 IN (SELECT c1 FROM t2 ORDER BY c1)
1008+
""".stripMargin
1009+
assert(getNumSortsInQuery(query1) == 0)
1010+
1011+
// Nested order bys
1012+
val query2 =
1013+
"""
1014+
|SELECT c1
1015+
|FROM t1
1016+
|WHERE c1 IN (SELECT c1
1017+
| FROM (SELECT *
1018+
| FROM t2
1019+
| ORDER BY c2)
1020+
| ORDER BY c1)
1021+
""".stripMargin
1022+
assert(getNumSortsInQuery(query2) == 0)
1023+
1024+
1025+
// nested IN
1026+
val query3 =
1027+
"""
1028+
|SELECT c1
1029+
|FROM t1
1030+
|WHERE c1 IN (SELECT c1
1031+
| FROM t2
1032+
| WHERE c1 IN (SELECT c1
1033+
| FROM t3
1034+
| WHERE c1 = 1
1035+
| ORDER BY c3)
1036+
| ORDER BY c2)
1037+
""".stripMargin
1038+
assert(getNumSortsInQuery(query3) == 0)
1039+
1040+
// Complex subplan and multiple sorts
1041+
val query4 =
1042+
"""
1043+
|SELECT c1
1044+
|FROM t1
1045+
|WHERE c1 IN (SELECT c1
1046+
| FROM (SELECT c1, c2, count(*)
1047+
| FROM t2
1048+
| GROUP BY c1, c2
1049+
| HAVING count(*) > 0
1050+
| ORDER BY c2)
1051+
| ORDER BY c1)
1052+
""".stripMargin
1053+
assert(getNumSortsInQuery(query4) == 0)
1054+
1055+
// Join in subplan
1056+
val query5 =
1057+
"""
1058+
|SELECT c1 FROM t1
1059+
|WHERE
1060+
|c1 IN (SELECT t2.c1 FROM t2, t3
1061+
| WHERE t2.c1 = t3.c1
1062+
| ORDER BY t2.c1)
1063+
""".stripMargin
1064+
assert(getNumSortsInQuery(query5) == 0)
1065+
1066+
val query6 =
1067+
"""
1068+
|SELECT c1
1069+
|FROM t1
1070+
|WHERE (c1, c2) IN (SELECT c1, max(c2)
1071+
| FROM (SELECT c1, c2, count(*)
1072+
| FROM t2
1073+
| GROUP BY c1, c2
1074+
| HAVING count(*) > 0
1075+
| ORDER BY c2)
1076+
| GROUP BY c1
1077+
| HAVING max(c2) > 0
1078+
| ORDER BY c1)
1079+
""".stripMargin
1080+
// The rule to remove redundant sorts is not able to remove the inner sort under
1081+
// an Aggregate operator. We only remove the top level sort.
1082+
assert(getNumSortsInQuery(query6) == 1)
1083+
1084+
// Cases when sort is not removed from the plan
1085+
// Limit on top of sort
1086+
val query7 =
1087+
"""
1088+
|SELECT c1 FROM t1
1089+
|WHERE
1090+
|c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1)
1091+
""".stripMargin
1092+
assert(getNumSortsInQuery(query7) == 1)
1093+
1094+
// Sort below a set operations (intersect, union)
1095+
val query8 =
1096+
"""
1097+
|SELECT c1 FROM t1
1098+
|WHERE
1099+
|c1 IN ((
1100+
| SELECT c1 FROM t2
1101+
| ORDER BY c1
1102+
| )
1103+
| UNION
1104+
| (
1105+
| SELECT c1 FROM t2
1106+
| ORDER BY c1
1107+
| ))
1108+
""".stripMargin
1109+
assert(getNumSortsInQuery(query8) == 2)
1110+
}
1111+
}
1112+
1113+
test("SPARK-23957 Remove redundant sort from subquery plan(exists subquery)") {
1114+
withTempView("t1", "t2", "t3") {
1115+
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
1116+
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
1117+
Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3")
1118+
1119+
// Simple order by exists correlated
1120+
val query1 =
1121+
"""
1122+
|SELECT c1 FROM t1
1123+
|WHERE
1124+
|EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY t2.c1)
1125+
""".stripMargin
1126+
assert(getNumSortsInQuery(query1) == 0)
1127+
1128+
// Nested order by and correlated.
1129+
val query2 =
1130+
"""
1131+
|SELECT c1
1132+
|FROM t1
1133+
|WHERE EXISTS (SELECT c1
1134+
| FROM (SELECT *
1135+
| FROM t2
1136+
| WHERE t2.c1 = t1.c1
1137+
| ORDER BY t2.c2) t2
1138+
| ORDER BY t2.c1)
1139+
""".stripMargin
1140+
assert(getNumSortsInQuery(query2) == 0)
1141+
1142+
// nested EXISTS
1143+
val query3 =
1144+
"""
1145+
|SELECT c1
1146+
|FROM t1
1147+
|WHERE EXISTS (SELECT c1
1148+
| FROM t2
1149+
| WHERE EXISTS (SELECT c1
1150+
| FROM t3
1151+
| WHERE t3.c1 = t2.c1
1152+
| ORDER BY c3)
1153+
| AND t2.c1 = t1.c1
1154+
| ORDER BY c2)
1155+
""".stripMargin
1156+
assert(getNumSortsInQuery(query3) == 0)
1157+
1158+
// Cases when sort is not removed from the plan
1159+
// Limit on top of sort
1160+
val query4 =
1161+
"""
1162+
|SELECT c1 FROM t1
1163+
|WHERE
1164+
|EXISTS (SELECT t2.c1 FROM t2 WHERE t2.c1 = 1 ORDER BY t2.c1 limit 1)
1165+
""".stripMargin
1166+
assert(getNumSortsInQuery(query4) == 1)
1167+
1168+
// Sort below a set operations (intersect, union)
1169+
val query5 =
1170+
"""
1171+
|SELECT c1 FROM t1
1172+
|WHERE
1173+
|EXISTS ((
1174+
| SELECT c1 FROM t2
1175+
| WHERE t2.c1 = 1
1176+
| ORDER BY t2.c1
1177+
| )
1178+
| UNION
1179+
| (
1180+
| SELECT c1 FROM t2
1181+
| WHERE t2.c1 = 2
1182+
| ORDER BY t2.c1
1183+
| ))
1184+
""".stripMargin
1185+
assert(getNumSortsInQuery(query5) == 2)
1186+
}
1187+
}
1188+
1189+
test("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") {
1190+
withTempView("t1", "t2", "t3") {
1191+
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
1192+
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
1193+
Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3")
1194+
1195+
// Two scalar subqueries in OR
1196+
val query1 =
1197+
"""
1198+
|SELECT * FROM t1
1199+
|WHERE c1 = (SELECT max(t2.c1)
1200+
| FROM t2
1201+
| ORDER BY max(t2.c1))
1202+
|OR c2 = (SELECT min(t3.c2)
1203+
| FROM t3
1204+
| WHERE t3.c1 = 1
1205+
| ORDER BY min(t3.c2))
1206+
""".stripMargin
1207+
assert(getNumSortsInQuery(query1) == 0)
1208+
1209+
// scalar subquery - groupby and having
1210+
val query2 =
1211+
"""
1212+
|SELECT *
1213+
|FROM t1
1214+
|WHERE c1 = (SELECT max(t2.c1)
1215+
| FROM t2
1216+
| GROUP BY t2.c1
1217+
| HAVING count(*) >= 1
1218+
| ORDER BY max(t2.c1))
1219+
""".stripMargin
1220+
assert(getNumSortsInQuery(query2) == 0)
1221+
1222+
// nested scalar subquery
1223+
val query3 =
1224+
"""
1225+
|SELECT *
1226+
|FROM t1
1227+
|WHERE c1 = (SELECT max(t2.c1)
1228+
| FROM t2
1229+
| WHERE c1 = (SELECT max(t3.c1)
1230+
| FROM t3
1231+
| WHERE t3.c1 = 1
1232+
| GROUP BY t3.c1
1233+
| ORDER BY max(t3.c1)
1234+
| )
1235+
| GROUP BY t2.c1
1236+
| HAVING count(*) >= 1
1237+
| ORDER BY max(t2.c1))
1238+
""".stripMargin
1239+
assert(getNumSortsInQuery(query3) == 0)
1240+
1241+
// Scalar subquery in projection
1242+
val query4 =
1243+
"""
1244+
|SELECT (SELECT min(c1) from t1 group by c1 order by c1)
1245+
|FROM t1
1246+
|WHERE t1.c1 = 1
1247+
""".stripMargin
1248+
assert(getNumSortsInQuery(query4) == 0)
1249+
1250+
// Limit on top of sort prevents it from being pruned.
1251+
val query5 =
1252+
"""
1253+
|SELECT *
1254+
|FROM t1
1255+
|WHERE c1 = (SELECT max(t2.c1)
1256+
| FROM t2
1257+
| WHERE c1 = (SELECT max(t3.c1)
1258+
| FROM t3
1259+
| WHERE t3.c1 = 1
1260+
| GROUP BY t3.c1
1261+
| ORDER BY max(t3.c1)
1262+
| )
1263+
| GROUP BY t2.c1
1264+
| HAVING count(*) >= 1
1265+
| ORDER BY max(t2.c1)
1266+
| LIMIT 1)
1267+
""".stripMargin
1268+
assert(getNumSortsInQuery(query5) == 1)
1269+
}
1270+
}
9731271
}

0 commit comments

Comments
 (0)