Skip to content

Commit fe65361

Browse files
tejasapatil authored and gatorsmile committed
[SPARK-22042][FOLLOW-UP][SQL] ReorderJoinPredicates can break when child's partitioning is not decided
## What changes were proposed in this pull request?

This is a follow-up PR to #19257, addressing a couple of code-style comments that gatorsmile had left on it.

## How was this patch tested?

This change does not alter any functionality. It relies on the build to confirm that no checkstyle rules are violated.

Author: Tejas Patil <[email protected]>

Closes #20041 from tejasapatil/followup_19257.
1 parent 4e107fd commit fe65361

File tree

2 files changed

+44
-42
lines changed

2 files changed

+44
-42
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -252,54 +252,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
252252
operator.withNewChildren(children)
253253
}
254254

255-
/**
256-
* When the physical operators are created for JOIN, the ordering of join keys is based on order
257-
* in which the join keys appear in the user query. That might not match with the output
258-
* partitioning of the join node's children (thus leading to extra sort / shuffle being
259-
* introduced). This rule will change the ordering of the join keys to match with the
260-
* partitioning of the join nodes' children.
261-
*/
262-
def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
263-
def reorderJoinKeys(
264-
leftKeys: Seq[Expression],
265-
rightKeys: Seq[Expression],
266-
leftPartitioning: Partitioning,
267-
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
268-
269-
def reorder(expectedOrderOfKeys: Seq[Expression],
270-
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
271-
val leftKeysBuffer = ArrayBuffer[Expression]()
272-
val rightKeysBuffer = ArrayBuffer[Expression]()
255+
private def reorder(
256+
leftKeys: Seq[Expression],
257+
rightKeys: Seq[Expression],
258+
expectedOrderOfKeys: Seq[Expression],
259+
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
260+
val leftKeysBuffer = ArrayBuffer[Expression]()
261+
val rightKeysBuffer = ArrayBuffer[Expression]()
273262

274-
expectedOrderOfKeys.foreach(expression => {
275-
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
276-
leftKeysBuffer.append(leftKeys(index))
277-
rightKeysBuffer.append(rightKeys(index))
278-
})
279-
(leftKeysBuffer, rightKeysBuffer)
280-
}
263+
expectedOrderOfKeys.foreach(expression => {
264+
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
265+
leftKeysBuffer.append(leftKeys(index))
266+
rightKeysBuffer.append(rightKeys(index))
267+
})
268+
(leftKeysBuffer, rightKeysBuffer)
269+
}
281270

282-
if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
283-
leftPartitioning match {
284-
case HashPartitioning(leftExpressions, _)
285-
if leftExpressions.length == leftKeys.length &&
286-
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
287-
reorder(leftExpressions, leftKeys)
271+
private def reorderJoinKeys(
272+
leftKeys: Seq[Expression],
273+
rightKeys: Seq[Expression],
274+
leftPartitioning: Partitioning,
275+
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
276+
if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
277+
leftPartitioning match {
278+
case HashPartitioning(leftExpressions, _)
279+
if leftExpressions.length == leftKeys.length &&
280+
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
281+
reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
288282

289-
case _ => rightPartitioning match {
290-
case HashPartitioning(rightExpressions, _)
291-
if rightExpressions.length == rightKeys.length &&
292-
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
293-
reorder(rightExpressions, rightKeys)
283+
case _ => rightPartitioning match {
284+
case HashPartitioning(rightExpressions, _)
285+
if rightExpressions.length == rightKeys.length &&
286+
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
287+
reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
294288

295-
case _ => (leftKeys, rightKeys)
296-
}
289+
case _ => (leftKeys, rightKeys)
297290
}
298-
} else {
299-
(leftKeys, rightKeys)
300291
}
292+
} else {
293+
(leftKeys, rightKeys)
301294
}
295+
}
302296

297+
/**
298+
* When the physical operators are created for JOIN, the ordering of join keys is based on order
299+
* in which the join keys appear in the user query. That might not match with the output
300+
* partitioning of the join node's children (thus leading to extra sort / shuffle being
301+
* introduced). This rule will change the ordering of the join keys to match with the
302+
* partitioning of the join nodes' children.
303+
*/
304+
private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
303305
plan.transformUp {
304306
case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
305307
right) =>

sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -620,15 +620,15 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
620620
|) ab
621621
|JOIN table2 c
622622
|ON ab.i = c.i
623-
|""".stripMargin),
623+
""".stripMargin),
624624
sql("""
625625
|SELECT a.i, a.j, a.k, c.i, c.j, c.k
626626
|FROM bucketed_table a
627627
|JOIN table1 b
628628
|ON a.i = b.i
629629
|JOIN table2 c
630630
|ON a.i = c.i
631-
|""".stripMargin))
631+
""".stripMargin))
632632
}
633633
}
634634
}

0 commit comments

Comments
 (0)