
Commit 7f16c69

tejasapatil authored and gatorsmile committed
[SPARK-19122][SQL] Unnecessary shuffle+sort added if join predicates' ordering differs from bucketing and sorting order
## What changes were proposed in this pull request?

Jira: https://issues.apache.org/jira/browse/SPARK-19122

`leftKeys` and `rightKeys` in `SortMergeJoinExec` are altered based on the ordering of join keys in the child's `outputPartitioning`. This is done every time `requiredChildDistribution` is invoked during query planning.

## How was this patch tested?

- Added a new test case
- Existing tests

Author: Tejas Patil <[email protected]>

Closes apache#16985 from tejasapatil/SPARK-19122_join_order_shuffle.
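For illustration, a minimal repro sketch of the scenario this patch addresses, written for `spark-shell`: two tables bucketed and sorted by `(i, j)` but joined with the predicates written in the opposite order. Table and column names here are assumptions for the example, not part of the patch:

```scala
// Repro sketch (assumed setup): both tables are bucketed and sorted by (i, j),
// but the join predicates are written as (j, i).
import spark.implicits._

// Force a sort-merge join instead of a broadcast join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

val df = (0 until 100).map(n => (n % 8, n % 13, n.toString)).toDF("i", "j", "k")
df.write.bucketBy(8, "i", "j").sortBy("i", "j").saveAsTable("t1")
df.write.bucketBy(8, "i", "j").sortBy("i", "j").saveAsTable("t2")

// Join keys written in the opposite order of the bucketing columns:
spark.sql("SELECT * FROM t1 JOIN t2 ON t1.j = t2.j AND t1.i = t2.i")
  .explain()
// Before this patch: an extra Exchange + Sort above each bucketed scan.
// With this patch: the join keys are reordered to (i, j), so neither is needed.
```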
1 parent 9443999 commit 7f16c69

File tree

3 files changed, +155 -0 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
+import org.apache.spark.sql.execution.joins.ReorderJoinPredicates
 import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
 import org.apache.spark.util.Utils

@@ -103,6 +104,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
     python.ExtractPythonUDFs,
     PlanSubqueries(sparkSession),
+    new ReorderJoinPredicates,
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
     ReuseExchange(sparkSession.sessionState.conf),
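Note that the new rule is inserted before `EnsureRequirements`: the join keys must already be reordered to match the children's output partitioning by the time `EnsureRequirements` decides whether an exchange or sort needs to be added.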
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * When the physical operators are created for JOIN, the ordering of join keys is based on order
+ * in which the join keys appear in the user query. That might not match with the output
+ * partitioning of the join node's children (thus leading to extra sort / shuffle being
+ * introduced). This rule will change the ordering of the join keys to match with the
+ * partitioning of the join nodes' children.
+ */
+class ReorderJoinPredicates extends Rule[SparkPlan] {
+  private def reorderJoinKeys(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      leftPartitioning: Partitioning,
+      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+
+    def reorder(
+        expectedOrderOfKeys: Seq[Expression],
+        currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+      val leftKeysBuffer = ArrayBuffer[Expression]()
+      val rightKeysBuffer = ArrayBuffer[Expression]()
+
+      expectedOrderOfKeys.foreach(expression => {
+        val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+        leftKeysBuffer.append(leftKeys(index))
+        rightKeysBuffer.append(rightKeys(index))
+      })
+      (leftKeysBuffer, rightKeysBuffer)
+    }
+
+    if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
+      leftPartitioning match {
+        case HashPartitioning(leftExpressions, _)
+          if leftExpressions.length == leftKeys.length &&
+            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
+          reorder(leftExpressions, leftKeys)
+
+        case _ => rightPartitioning match {
+          case HashPartitioning(rightExpressions, _)
+            if rightExpressions.length == rightKeys.length &&
+              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
+            reorder(rightExpressions, rightKeys)
+
+          case _ => (leftKeys, rightKeys)
+        }
+      }
+    } else {
+      (leftKeys, rightKeys)
+    }
+  }
+
+  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+    case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+      val (reorderedLeftKeys, reorderedRightKeys) =
+        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+      BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+        left, right)
+
+    case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+      val (reorderedLeftKeys, reorderedRightKeys) =
+        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+      ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+        left, right)
+
+    case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+      val (reorderedLeftKeys, reorderedRightKeys) =
+        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+      SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+  }
+}
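To make the key-matching step concrete, here is a minimal standalone sketch of what `reorder` computes, using plain `String`s and `==` in place of Catalyst `Expression`s and `semanticEquals` (names and values are illustrative only, not part of the patch):

```scala
import scala.collection.mutable.ArrayBuffer

object ReorderSketch {
  def reorder(
      expectedOrderOfKeys: Seq[String],  // the child's HashPartitioning order
      currentOrderOfKeys: Seq[String],   // the order written in the query
      leftKeys: Seq[String],
      rightKeys: Seq[String]): (Seq[String], Seq[String]) = {
    val leftBuffer = ArrayBuffer[String]()
    val rightBuffer = ArrayBuffer[String]()
    // For each key in the partitioning's expected order, find its position in
    // the query's order and pick the corresponding left/right join keys.
    expectedOrderOfKeys.foreach { expected =>
      val index = currentOrderOfKeys.indexWhere(_ == expected)
      leftBuffer.append(leftKeys(index))
      rightBuffer.append(rightKeys(index))
    }
    (leftBuffer.toSeq, rightBuffer.toSeq)
  }

  def main(args: Array[String]): Unit = {
    // Tables bucketed by (i, j, k), but the query joins on (k, i, j):
    val (l, r) = reorder(
      expectedOrderOfKeys = Seq("l.i", "l.j", "l.k"),
      currentOrderOfKeys = Seq("l.k", "l.i", "l.j"),
      leftKeys = Seq("l.k", "l.i", "l.j"),
      rightKeys = Seq("r.k", "r.i", "r.j"))
    println(l)  // keys reordered to (l.i, l.j, l.k)
    println(r)  // keys reordered to (r.i, r.j, r.k)
  }
}
```

Note that the real rule applies this only when every join key is deterministic and the `HashPartitioning` expressions have the same length as, and semantically match, the join keys; in every other case the keys are returned unchanged.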

sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala

Lines changed: 59 additions & 0 deletions
@@ -543,6 +543,65 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     )
   }

+  test("SPARK-19122 Re-order join predicates if they match with the child's output partitioning") {
+    val bucketedTableTestSpec = BucketedTableTestSpec(
+      Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))),
+      numPartitions = 1,
+      expectedShuffle = false,
+      expectedSort = false)
+
+    // If the set of join columns is equal to the set of bucketed + sort columns, then
+    // the order of the join keys in the query should not matter and there should not be
+    // any shuffle or sort added in the query plan
+    Seq(
+      Seq("i", "j", "k"),
+      Seq("i", "k", "j"),
+      Seq("j", "k", "i"),
+      Seq("j", "i", "k"),
+      Seq("k", "j", "i"),
+      Seq("k", "i", "j")
+    ).foreach(joinKeys => {
+      testBucketing(
+        bucketedTableTestSpecLeft = bucketedTableTestSpec,
+        bucketedTableTestSpecRight = bucketedTableTestSpec,
+        joinCondition = joinCondition(joinKeys)
+      )
+    })
+  }
+
+  test("SPARK-19122 No re-ordering should happen if set of join columns != set of child's " +
+    "partitioning columns") {
+
+    // the join predicates are a superset of the child's partitioning columns
+    val bucketedTableTestSpec1 =
+      BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1)
+    testBucketing(
+      bucketedTableTestSpecLeft = bucketedTableTestSpec1,
+      bucketedTableTestSpecRight = bucketedTableTestSpec1,
+      joinCondition = joinCondition(Seq("i", "j", "k"))
+    )
+
+    // the child's partitioning columns are a superset of the join predicates
+    val bucketedTableTestSpec2 =
+      BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))),
+        numPartitions = 1)
+    testBucketing(
+      bucketedTableTestSpecLeft = bucketedTableTestSpec2,
+      bucketedTableTestSpecRight = bucketedTableTestSpec2,
+      joinCondition = joinCondition(Seq("i", "j"))
+    )
+
+    // set of child's partitioning columns != set of join predicates (even though both
+    // sets have the same size)
+    val bucketedTableTestSpec3 =
+      BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1)
+    testBucketing(
+      bucketedTableTestSpecLeft = bucketedTableTestSpec3,
+      bucketedTableTestSpecRight = bucketedTableTestSpec3,
+      joinCondition = joinCondition(Seq("j", "k"))
+    )
+  }
+
   test("error if there exists any malformed bucket files") {
     withTable("bucketed_table") {
       df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
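The `joinCondition` helper used above is defined elsewhere in `BucketedReadSuite` and is not part of this diff; a plausible sketch of what such a helper does (an assumption for readability, not the actual definition):

```scala
// Assumed sketch: builds a conjunction such as
// left("i") === right("i") && left("j") === right("j") from shared column names.
import org.apache.spark.sql.{Column, DataFrame}

def joinCondition(joinCols: Seq[String])(left: DataFrame, right: DataFrame): Column =
  joinCols.map(col => left(col) === right(col)).reduce(_ && _)
```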
