|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql
|
19 | 19 |
|
20 |
| -import org.apache.spark.sql.catalyst.plans.logical.Join |
| 20 | +import scala.collection.mutable.ArrayBuffer |
| 21 | + |
| 22 | +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression |
| 23 | +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} |
21 | 24 | import org.apache.spark.sql.test.SharedSQLContext
|
22 | 25 |
|
23 | 26 | class SubquerySuite extends QueryTest with SharedSQLContext {
|
@@ -970,4 +973,299 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
|
970 | 973 | Row("3", "b") :: Row("4", "b") :: Nil)
|
971 | 974 | }
|
972 | 975 | }
|
| 976 | + |
| 977 | + private def getNumSortsInQuery(query: String): Int = { |
| 978 | + val plan = sql(query).queryExecution.optimizedPlan |
| 979 | + getNumSorts(plan) + getSubqueryExpressions(plan).map{s => getNumSorts(s.plan)}.sum |
| 980 | + } |
| 981 | + |
| 982 | + private def getSubqueryExpressions(plan: LogicalPlan): Seq[SubqueryExpression] = { |
| 983 | + val subqueryExpressions = ArrayBuffer.empty[SubqueryExpression] |
| 984 | + plan transformAllExpressions { |
| 985 | + case s: SubqueryExpression => |
| 986 | + subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s) |
| 987 | + s |
| 988 | + } |
| 989 | + subqueryExpressions |
| 990 | + } |
| 991 | + |
| 992 | + private def getNumSorts(plan: LogicalPlan): Int = { |
| 993 | + plan.collect { case s: Sort => s }.size |
| 994 | + } |
| 995 | + |
| 996 | + test("SPARK-23957 Remove redundant sort from subquery plan(in subquery)") { |
| 997 | + withTempView("t1", "t2", "t3") { |
| 998 | + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") |
| 999 | + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") |
| 1000 | + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") |
| 1001 | + |
| 1002 | + // Simple order by |
| 1003 | + val query1 = |
| 1004 | + """ |
| 1005 | + |SELECT c1 FROM t1 |
| 1006 | + |WHERE |
| 1007 | + |c1 IN (SELECT c1 FROM t2 ORDER BY c1) |
| 1008 | + """.stripMargin |
| 1009 | + assert(getNumSortsInQuery(query1) == 0) |
| 1010 | + |
| 1011 | + // Nested order bys |
| 1012 | + val query2 = |
| 1013 | + """ |
| 1014 | + |SELECT c1 |
| 1015 | + |FROM t1 |
| 1016 | + |WHERE c1 IN (SELECT c1 |
| 1017 | + | FROM (SELECT * |
| 1018 | + | FROM t2 |
| 1019 | + | ORDER BY c2) |
| 1020 | + | ORDER BY c1) |
| 1021 | + """.stripMargin |
| 1022 | + assert(getNumSortsInQuery(query2) == 0) |
| 1023 | + |
| 1024 | + |
| 1025 | + // nested IN |
| 1026 | + val query3 = |
| 1027 | + """ |
| 1028 | + |SELECT c1 |
| 1029 | + |FROM t1 |
| 1030 | + |WHERE c1 IN (SELECT c1 |
| 1031 | + | FROM t2 |
| 1032 | + | WHERE c1 IN (SELECT c1 |
| 1033 | + | FROM t3 |
| 1034 | + | WHERE c1 = 1 |
| 1035 | + | ORDER BY c3) |
| 1036 | + | ORDER BY c2) |
| 1037 | + """.stripMargin |
| 1038 | + assert(getNumSortsInQuery(query3) == 0) |
| 1039 | + |
| 1040 | + // Complex subplan and multiple sorts |
| 1041 | + val query4 = |
| 1042 | + """ |
| 1043 | + |SELECT c1 |
| 1044 | + |FROM t1 |
| 1045 | + |WHERE c1 IN (SELECT c1 |
| 1046 | + | FROM (SELECT c1, c2, count(*) |
| 1047 | + | FROM t2 |
| 1048 | + | GROUP BY c1, c2 |
| 1049 | + | HAVING count(*) > 0 |
| 1050 | + | ORDER BY c2) |
| 1051 | + | ORDER BY c1) |
| 1052 | + """.stripMargin |
| 1053 | + assert(getNumSortsInQuery(query4) == 0) |
| 1054 | + |
| 1055 | + // Join in subplan |
| 1056 | + val query5 = |
| 1057 | + """ |
| 1058 | + |SELECT c1 FROM t1 |
| 1059 | + |WHERE |
| 1060 | + |c1 IN (SELECT t2.c1 FROM t2, t3 |
| 1061 | + | WHERE t2.c1 = t3.c1 |
| 1062 | + | ORDER BY t2.c1) |
| 1063 | + """.stripMargin |
| 1064 | + assert(getNumSortsInQuery(query5) == 0) |
| 1065 | + |
| 1066 | + val query6 = |
| 1067 | + """ |
| 1068 | + |SELECT c1 |
| 1069 | + |FROM t1 |
| 1070 | + |WHERE (c1, c2) IN (SELECT c1, max(c2) |
| 1071 | + | FROM (SELECT c1, c2, count(*) |
| 1072 | + | FROM t2 |
| 1073 | + | GROUP BY c1, c2 |
| 1074 | + | HAVING count(*) > 0 |
| 1075 | + | ORDER BY c2) |
| 1076 | + | GROUP BY c1 |
| 1077 | + | HAVING max(c2) > 0 |
| 1078 | + | ORDER BY c1) |
| 1079 | + """.stripMargin |
| 1080 | + // The rule to remove redundant sorts is not able to remove the inner sort under |
| 1081 | + // an Aggregate operator. We only remove the top level sort. |
| 1082 | + assert(getNumSortsInQuery(query6) == 1) |
| 1083 | + |
| 1084 | + // Cases when sort is not removed from the plan |
| 1085 | + // Limit on top of sort |
| 1086 | + val query7 = |
| 1087 | + """ |
| 1088 | + |SELECT c1 FROM t1 |
| 1089 | + |WHERE |
| 1090 | + |c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1) |
| 1091 | + """.stripMargin |
| 1092 | + assert(getNumSortsInQuery(query7) == 1) |
| 1093 | + |
| 1094 | + // Sort below a set operations (intersect, union) |
| 1095 | + val query8 = |
| 1096 | + """ |
| 1097 | + |SELECT c1 FROM t1 |
| 1098 | + |WHERE |
| 1099 | + |c1 IN (( |
| 1100 | + | SELECT c1 FROM t2 |
| 1101 | + | ORDER BY c1 |
| 1102 | + | ) |
| 1103 | + | UNION |
| 1104 | + | ( |
| 1105 | + | SELECT c1 FROM t2 |
| 1106 | + | ORDER BY c1 |
| 1107 | + | )) |
| 1108 | + """.stripMargin |
| 1109 | + assert(getNumSortsInQuery(query8) == 2) |
| 1110 | + } |
| 1111 | + } |
| 1112 | + |
| 1113 | + test("SPARK-23957 Remove redundant sort from subquery plan(exists subquery)") { |
| 1114 | + withTempView("t1", "t2", "t3") { |
| 1115 | + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") |
| 1116 | + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") |
| 1117 | + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") |
| 1118 | + |
| 1119 | + // Simple order by exists correlated |
| 1120 | + val query1 = |
| 1121 | + """ |
| 1122 | + |SELECT c1 FROM t1 |
| 1123 | + |WHERE |
| 1124 | + |EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY t2.c1) |
| 1125 | + """.stripMargin |
| 1126 | + assert(getNumSortsInQuery(query1) == 0) |
| 1127 | + |
| 1128 | + // Nested order by and correlated. |
| 1129 | + val query2 = |
| 1130 | + """ |
| 1131 | + |SELECT c1 |
| 1132 | + |FROM t1 |
| 1133 | + |WHERE EXISTS (SELECT c1 |
| 1134 | + | FROM (SELECT * |
| 1135 | + | FROM t2 |
| 1136 | + | WHERE t2.c1 = t1.c1 |
| 1137 | + | ORDER BY t2.c2) t2 |
| 1138 | + | ORDER BY t2.c1) |
| 1139 | + """.stripMargin |
| 1140 | + assert(getNumSortsInQuery(query2) == 0) |
| 1141 | + |
| 1142 | + // nested EXISTS |
| 1143 | + val query3 = |
| 1144 | + """ |
| 1145 | + |SELECT c1 |
| 1146 | + |FROM t1 |
| 1147 | + |WHERE EXISTS (SELECT c1 |
| 1148 | + | FROM t2 |
| 1149 | + | WHERE EXISTS (SELECT c1 |
| 1150 | + | FROM t3 |
| 1151 | + | WHERE t3.c1 = t2.c1 |
| 1152 | + | ORDER BY c3) |
| 1153 | + | AND t2.c1 = t1.c1 |
| 1154 | + | ORDER BY c2) |
| 1155 | + """.stripMargin |
| 1156 | + assert(getNumSortsInQuery(query3) == 0) |
| 1157 | + |
| 1158 | + // Cases when sort is not removed from the plan |
| 1159 | + // Limit on top of sort |
| 1160 | + val query4 = |
| 1161 | + """ |
| 1162 | + |SELECT c1 FROM t1 |
| 1163 | + |WHERE |
| 1164 | + |EXISTS (SELECT t2.c1 FROM t2 WHERE t2.c1 = 1 ORDER BY t2.c1 limit 1) |
| 1165 | + """.stripMargin |
| 1166 | + assert(getNumSortsInQuery(query4) == 1) |
| 1167 | + |
| 1168 | + // Sort below a set operations (intersect, union) |
| 1169 | + val query5 = |
| 1170 | + """ |
| 1171 | + |SELECT c1 FROM t1 |
| 1172 | + |WHERE |
| 1173 | + |EXISTS (( |
| 1174 | + | SELECT c1 FROM t2 |
| 1175 | + | WHERE t2.c1 = 1 |
| 1176 | + | ORDER BY t2.c1 |
| 1177 | + | ) |
| 1178 | + | UNION |
| 1179 | + | ( |
| 1180 | + | SELECT c1 FROM t2 |
| 1181 | + | WHERE t2.c1 = 2 |
| 1182 | + | ORDER BY t2.c1 |
| 1183 | + | )) |
| 1184 | + """.stripMargin |
| 1185 | + assert(getNumSortsInQuery(query5) == 2) |
| 1186 | + } |
| 1187 | + } |
| 1188 | + |
| 1189 | + test("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") { |
| 1190 | + withTempView("t1", "t2", "t3") { |
| 1191 | + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") |
| 1192 | + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") |
| 1193 | + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") |
| 1194 | + |
| 1195 | + // Two scalar subqueries in OR |
| 1196 | + val query1 = |
| 1197 | + """ |
| 1198 | + |SELECT * FROM t1 |
| 1199 | + |WHERE c1 = (SELECT max(t2.c1) |
| 1200 | + | FROM t2 |
| 1201 | + | ORDER BY max(t2.c1)) |
| 1202 | + |OR c2 = (SELECT min(t3.c2) |
| 1203 | + | FROM t3 |
| 1204 | + | WHERE t3.c1 = 1 |
| 1205 | + | ORDER BY min(t3.c2)) |
| 1206 | + """.stripMargin |
| 1207 | + assert(getNumSortsInQuery(query1) == 0) |
| 1208 | + |
| 1209 | + // scalar subquery - groupby and having |
| 1210 | + val query2 = |
| 1211 | + """ |
| 1212 | + |SELECT * |
| 1213 | + |FROM t1 |
| 1214 | + |WHERE c1 = (SELECT max(t2.c1) |
| 1215 | + | FROM t2 |
| 1216 | + | GROUP BY t2.c1 |
| 1217 | + | HAVING count(*) >= 1 |
| 1218 | + | ORDER BY max(t2.c1)) |
| 1219 | + """.stripMargin |
| 1220 | + assert(getNumSortsInQuery(query2) == 0) |
| 1221 | + |
| 1222 | + // nested scalar subquery |
| 1223 | + val query3 = |
| 1224 | + """ |
| 1225 | + |SELECT * |
| 1226 | + |FROM t1 |
| 1227 | + |WHERE c1 = (SELECT max(t2.c1) |
| 1228 | + | FROM t2 |
| 1229 | + | WHERE c1 = (SELECT max(t3.c1) |
| 1230 | + | FROM t3 |
| 1231 | + | WHERE t3.c1 = 1 |
| 1232 | + | GROUP BY t3.c1 |
| 1233 | + | ORDER BY max(t3.c1) |
| 1234 | + | ) |
| 1235 | + | GROUP BY t2.c1 |
| 1236 | + | HAVING count(*) >= 1 |
| 1237 | + | ORDER BY max(t2.c1)) |
| 1238 | + """.stripMargin |
| 1239 | + assert(getNumSortsInQuery(query3) == 0) |
| 1240 | + |
| 1241 | + // Scalar subquery in projection |
| 1242 | + val query4 = |
| 1243 | + """ |
| 1244 | + |SELECT (SELECT min(c1) from t1 group by c1 order by c1) |
| 1245 | + |FROM t1 |
| 1246 | + |WHERE t1.c1 = 1 |
| 1247 | + """.stripMargin |
| 1248 | + assert(getNumSortsInQuery(query4) == 0) |
| 1249 | + |
| 1250 | + // Limit on top of sort prevents it from being pruned. |
| 1251 | + val query5 = |
| 1252 | + """ |
| 1253 | + |SELECT * |
| 1254 | + |FROM t1 |
| 1255 | + |WHERE c1 = (SELECT max(t2.c1) |
| 1256 | + | FROM t2 |
| 1257 | + | WHERE c1 = (SELECT max(t3.c1) |
| 1258 | + | FROM t3 |
| 1259 | + | WHERE t3.c1 = 1 |
| 1260 | + | GROUP BY t3.c1 |
| 1261 | + | ORDER BY max(t3.c1) |
| 1262 | + | ) |
| 1263 | + | GROUP BY t2.c1 |
| 1264 | + | HAVING count(*) >= 1 |
| 1265 | + | ORDER BY max(t2.c1) |
| 1266 | + | LIMIT 1) |
| 1267 | + """.stripMargin |
| 1268 | + assert(getNumSortsInQuery(query5) == 1) |
| 1269 | + } |
| 1270 | + } |
973 | 1271 | }
|
0 commit comments