Commit 895666b

[SPARK-55551][SQL] Improve BroadcastHashJoinExec output partitioning
### What changes were proposed in this pull request?

This is a minor refactor of `BroadcastHashJoinExec.outputPartitioning` to:
- simplify the logic, and
- make it future proof by using `Partitioning with Expression` instead of `HashPartitioningLike`.

### Why are the changes needed?

Code cleanup, and support for future partitionings that implement `Expression` but not `HashPartitioningLike` (such as `KeyedPartitioning` in #54330).

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #54335 from peter-toth/SPARK-55551-improve-broadcasthashjoinexec-output-partitioning.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Peter Toth <peter.toth@gmail.com>
Parent: d2be5d2

2 files changed: +31 −32 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala

Lines changed: 27 additions & 27 deletions
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, PartitioningCollection, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics

@@ -72,10 +72,14 @@ case class BroadcastHashJoinExec private(
   override lazy val outputPartitioning: Partitioning = {
     joinType match {
       case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
-        streamedPlan.outputPartitioning match {
-          case h: HashPartitioningLike => expandOutputPartitioning(h)
-          case c: PartitioningCollection => expandOutputPartitioning(c)
-          case other => other
+        val expandedPartitioning = expandOutputPartitioning(streamedPlan.outputPartitioning)
+        expandedPartitioning match {
+          // We don't need to handle the empty case, since it could only occur if
+          // `streamedPlan.outputPartitioning` were an empty `PartitioningCollection`, but its
+          // constructor prevents that.
+
+          case p :: Nil => p
+          case ps => PartitioningCollection(ps)
         }
       case _ => streamedPlan.outputPartitioning
     }
@@ -95,29 +99,25 @@
     mapping.toMap
   }

-  // Expands the given partitioning collection recursively.
-  private def expandOutputPartitioning(
-      partitioning: PartitioningCollection): PartitioningCollection = {
-    PartitioningCollection(partitioning.partitionings.flatMap {
-      case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
-      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+  // Expands the given partitioning recursively.
+  private def expandOutputPartitioning(partitioning: Partitioning): Seq[Partitioning] = {
+    partitioning match {
+      case c: PartitioningCollection => c.partitionings.flatMap(expandOutputPartitioning)
+      case p: Partitioning with Expression =>
+        // Expands the given partitioning, that is also an expression, by substituting streamed keys
+        // with build keys.
+        // For example, if the expressions for the given partitioning are Seq("a", "b", "c") where
+        // the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"), the expanded
+        // partitioning will have the following expressions:
+        // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+        // The expanded expressions are returned as `Seq[Partitioning]`.
+        p.multiTransformDown {
+          case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
+            e +: streamedKeyToBuildKeyMapping(e.canonicalized)
+        }.asInstanceOf[LazyList[Partitioning]]
+          .take(conf.broadcastHashJoinOutputPartitioningExpandLimit)
       case other => Seq(other)
-    })
-  }
-
-  // Expands the given hash partitioning by substituting streamed keys with build keys.
-  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
-  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
-  // the expanded partitioning will have the following expressions:
-  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
-  // The expanded expressions are returned as PartitioningCollection.
-  private def expandOutputPartitioning(
-      partitioning: HashPartitioningLike): PartitioningCollection = {
-    PartitioningCollection(partitioning.multiTransformDown {
-      case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
-        e +: streamedKeyToBuildKeyMapping(e.canonicalized)
-    }.asInstanceOf[LazyList[HashPartitioningLike]]
-      .take(conf.broadcastHashJoinOutputPartitioningExpandLimit))
+    }
   }

   protected override def doExecute(): RDD[InternalRow] = {
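The key-substitution idea behind the new `expandOutputPartitioning` can be sketched without Spark internals. The snippet below is a minimal, self-contained illustration only: plain strings stand in for catalyst expressions, and `expandKeys`, `streamedToBuild`, and `limit` are names invented for this sketch rather than anything in the patch. Each partition key that has a build-side counterpart may appear either as itself or as a mapped build key, and the lazily generated combinations are capped by a limit, mirroring `conf.broadcastHashJoinOutputPartitioningExpandLimit` above.

```scala
// A minimal sketch (not Spark code) of the key-substitution expansion.
object ExpandSketch {
  def expandKeys(
      keys: Seq[String],
      streamedToBuild: Map[String, Seq[String]],
      limit: Int): Seq[Seq[String]] = {
    // Alternatives per key: the key itself plus any mapped build keys.
    val alternatives: Seq[Seq[String]] =
      keys.map(k => k +: streamedToBuild.getOrElse(k, Seq.empty))
    // Lazy cross product of the alternatives, so take(limit) stops enumerating early.
    val combinations: LazyList[Seq[String]] =
      alternatives.foldRight(LazyList(Seq.empty[String])) { (alts, acc) =>
        for (a <- alts.to(LazyList); rest <- acc) yield a +: rest
      }
    combinations.take(limit)
  }

  def main(args: Array[String]): Unit = {
    // Matches the example in the code comment: streamed keys b and c map to build keys x and y.
    val expanded = expandKeys(Seq("a", "b", "c"), Map("b" -> Seq("x"), "c" -> Seq("y")), limit = 8)
    // Prints List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y).
    expanded.foreach(println)
  }
}
```

In the real code the same effect is obtained with `multiTransformDown`, which enumerates the alternative expression trees as a lazy sequence so that `take(limit)` bounds the amount of work.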

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala

Lines changed: 4 additions & 5 deletions
@@ -592,11 +592,10 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
           HashPartitioning(Seq(l3), 1)))),
       right = DummySparkPlan())
     expected = PartitioningCollection(Seq(
-      PartitioningCollection(Seq(
-        HashPartitioning(Seq(l1), 1),
-        HashPartitioning(Seq(r1), 1),
-        HashPartitioning(Seq(l2), 1),
-        HashPartitioning(Seq(r2), 1))),
+      HashPartitioning(Seq(l1), 1),
+      HashPartitioning(Seq(r1), 1),
+      HashPartitioning(Seq(l2), 1),
+      HashPartitioning(Seq(r2), 1),
       HashPartitioning(Seq(l3), 1),
       HashPartitioning(Seq(r3), 1)))
     assert(bhj.outputPartitioning === expected)
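The flatter expectation follows from `expandOutputPartitioning` now returning a `Seq[Partitioning]` and recursing into nested `PartitioningCollection`s, so `outputPartitioning` wraps the results in a single collection. A rough sketch of that flattening shape, using a hypothetical stand-in hierarchy (`P`, `Hash`, and `Collection` are invented for this example, not Spark's classes):

```scala
object FlattenSketch {
  // Hypothetical stand-in for the Partitioning hierarchy, only to show the flattening shape.
  sealed trait P
  final case class Hash(keys: Seq[String]) extends P
  final case class Collection(parts: Seq[P]) extends P

  // Recursively flattens nested collections into one flat sequence of leaf partitionings.
  def flatten(p: P): Seq[P] = p match {
    case Collection(parts) => parts.flatMap(flatten)
    case leaf => Seq(leaf)
  }

  def main(args: Array[String]): Unit = {
    // A nested collection like the old expectation above...
    val nested = Collection(Seq(
      Collection(Seq(Hash(Seq("l1")), Hash(Seq("r1")), Hash(Seq("l2")), Hash(Seq("r2")))),
      Hash(Seq("l3")),
      Hash(Seq("r3"))))
    // ...flattens into the single-level sequence that the new expectation wraps exactly once.
    assert(flatten(nested) == Seq(
      Hash(Seq("l1")), Hash(Seq("r1")), Hash(Seq("l2")), Hash(Seq("r2")),
      Hash(Seq("l3")), Hash(Seq("r3"))))
  }
}
```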
