@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, PartitioningCollection, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -72,10 +72,14 @@ case class BroadcastHashJoinExec private(
   override lazy val outputPartitioning: Partitioning = {
     joinType match {
       case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
-        streamedPlan.outputPartitioning match {
-          case h: HashPartitioningLike => expandOutputPartitioning(h)
-          case c: PartitioningCollection => expandOutputPartitioning(c)
-          case other => other
+        val expandedPartitioning = expandOutputPartitioning(streamedPlan.outputPartitioning)
+        expandedPartitioning match {
+          // We don't need to handle the empty case, since it could only occur if
+          // `streamedPlan.outputPartitioning` were an empty `PartitioningCollection`, but its
+          // constructor prevents that.
+
+          case p :: Nil => p
+          case ps => PartitioningCollection(ps)
         }
       case _ => streamedPlan.outputPartitioning
     }
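
The wrap-or-unwrap step at the end of the new `outputPartitioning` is worth spelling out. Below is a minimal, self-contained sketch of the same pattern (not Spark code; `Part`, `Keys`, `Alternatives`, and `normalize` are hypothetical names standing in for the Catalyst types):

sealed trait Part
case class Keys(names: Seq[String]) extends Part
// Mirrors PartitioningCollection, whose constructor rejects an empty list.
case class Alternatives(parts: Seq[Part]) extends Part {
  require(parts.nonEmpty, "at least one partitioning is required")
}

def normalize(expanded: List[Part]): Part = expanded match {
  case p :: Nil => p          // a single result needs no wrapper
  case ps => Alternatives(ps) // several equivalent alternatives
}

As in the diff, the empty list never reaches the wrapper here, because the expansion always yields at least one element for a non-empty input.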
@@ -95,29 +99,25 @@ case class BroadcastHashJoinExec private(
     mapping.toMap
   }
 
-  // Expands the given partitioning collection recursively.
-  private def expandOutputPartitioning(
-      partitioning: PartitioningCollection): PartitioningCollection = {
-    PartitioningCollection(partitioning.partitionings.flatMap {
-      case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
-      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+  // Expands the given partitioning recursively.
+  private def expandOutputPartitioning(partitioning: Partitioning): Seq[Partitioning] = {
+    partitioning match {
+      case c: PartitioningCollection => c.partitionings.flatMap(expandOutputPartitioning)
+      case p: Partitioning with Expression =>
+        // Expands the given partitioning, which is also an expression, by substituting
+        // streamed keys with build keys.
+        // For example, if the expressions for the given partitioning are Seq("a", "b", "c"),
+        // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"), the
+        // expanded partitioning will have the following expressions:
+        // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+        // The expanded expressions are returned as `Seq[Partitioning]`.
+        p.multiTransformDown {
+          case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
+            e +: streamedKeyToBuildKeyMapping(e.canonicalized)
+        }.asInstanceOf[LazyList[Partitioning]]
+          .take(conf.broadcastHashJoinOutputPartitioningExpandLimit)
       case other => Seq(other)
-    })
-  }
-
-  // Expands the given hash partitioning by substituting streamed keys with build keys.
-  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
-  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
-  // the expanded partitioning will have the following expressions:
-  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
-  // The expanded expressions are returned as PartitioningCollection.
-  private def expandOutputPartitioning(
-      partitioning: HashPartitioningLike): PartitioningCollection = {
-    PartitioningCollection(partitioning.multiTransformDown {
-      case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
-        e +: streamedKeyToBuildKeyMapping(e.canonicalized)
-    }.asInstanceOf[LazyList[HashPartitioningLike]]
-      .take(conf.broadcastHashJoinOutputPartitioningExpandLimit))
+    }
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
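
To make the combinatorial behaviour of the `multiTransformDown` call concrete, here is a minimal, self-contained sketch of the same expansion over plain strings (not Spark code; `expandKeys`, `streamedToBuild`, and `limit` are hypothetical names standing in for the Catalyst machinery and the config value):

def expandKeys(
    keys: Seq[String],
    streamedToBuild: Map[String, Seq[String]],
    limit: Int): Seq[Seq[String]] = {
  // For each key, keep the original and append its build-side equivalents,
  // then take the cartesian product of the per-key alternatives.
  val alternatives = keys.map(k => k +: streamedToBuild.getOrElse(k, Seq.empty))
  alternatives
    .foldLeft(LazyList(Seq.empty[String])) { (acc, alts) =>
      for (prefix <- acc; a <- alts) yield prefix :+ a
    }
    .take(limit)
}

// Reproduces the example in the comment above:
// expandKeys(Seq("a", "b", "c"), Map("b" -> Seq("x"), "c" -> Seq("y")), 8)
// == Seq(Seq("a","b","c"), Seq("a","b","y"), Seq("a","x","c"), Seq("a","x","y"))

Because the product is built on a `LazyList`, only the first `limit` combinations are ever materialized, which is the same role the `LazyList` plus `take(conf.broadcastHashJoinOutputPartitioningExpandLimit)` plays in the diff.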