
Commit 9dbe53e

yucaicloud-fan authored and committed
[SPARK-24556][SQL] Always rewrite output partitioning in ReusedExchangeExec and InMemoryTableScanExec
## What changes were proposed in this pull request?

Currently, ReusedExchangeExec and InMemoryTableScanExec rewrite output partitioning only when the child's partitioning is HashPartitioning, and do nothing for any other partitioning, e.g., RangePartitioning. We should always rewrite it; otherwise, an unnecessary shuffle can be introduced, as in https://issues.apache.org/jira/browse/SPARK-24556.

## How was this patch tested?

Added new tests.

Author: yucai <[email protected]>

Closes apache#21564 from yucai/SPARK-24556.
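For context, a hypothetical spark-shell repro (not part of the commit; it mirrors the scenario the new test exercises):

```scala
// Two identically sorted DataFrames share one shuffle via ReuseExchange.
// Before this patch, the ReusedExchangeExec still reported RangePartitioning
// over the original exchange's attribute ids rather than its own output,
// which, per the commit message, could cause an unnecessary extra shuffle
// to be planned on top of the reused exchange.
val df1 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
val df2 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
df1.union(df2).explain()
// With the patch: one Exchange plus one ReusedExchange, no redundant shuffle.
```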
1 parent a78a904 commit 9dbe53e

3 files changed: 67 additions, 7 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

3 additions, 3 deletions

@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.vectorized._
 import org.apache.spark.sql.types._
@@ -169,8 +169,8 @@ case class InMemoryTableScanExec(
   // But the cached version could alias output, so we need to replace output.
   override def outputPartitioning: Partitioning = {
     relation.cachedPlan.outputPartitioning match {
-      case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning]
-      case _ => relation.cachedPlan.outputPartitioning
+      case e: Expression => updateAttribute(e).asInstanceOf[Partitioning]
+      case other => other
     }
   }
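The fix works because every attribute-bearing Partitioning (HashPartitioning, RangePartitioning, PartitioningCollection) is also a catalyst Expression, so a single Expression match covers them all. A minimal sketch of the shared pattern, assuming spark-catalyst on the classpath (rewritePartitioning and attrMap are hypothetical names; in the real code, updateAttribute here and updateAttr in ReusedExchangeExec play this role):

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning

// Remap every attribute inside a partitioning onto this operator's own output.
def rewritePartitioning(p: Partitioning, attrMap: AttributeMap[Attribute]): Partitioning =
  p match {
    // HashPartitioning, RangePartitioning and PartitioningCollection are all
    // catalyst Expressions, so one transform remaps every attribute they hold.
    case e: Expression =>
      e.transform { case a: Attribute => attrMap.getOrElse(a, a) }
        .asInstanceOf[Partitioning]
    // Partitionings without column references (e.g. SinglePartition,
    // UnknownPartitioning) need no rewrite.
    case other => other
  }
```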

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala

2 additions, 2 deletions

@@ -24,7 +24,7 @@ import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.internal.SQLConf
@@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange)
   }

   override def outputPartitioning: Partitioning = child.outputPartitioning match {
-    case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr))
+    case e: Expression => updateAttr(e).asInstanceOf[Partitioning]
     case other => other
   }
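This is the same Expression-based rewrite sketched above: updateAttr remaps the reused exchange's attributes onto this node's own output, so RangePartitioning and PartitioningCollection are now rewritten as well, instead of only HashPartitioning.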

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

62 additions, 2 deletions

@@ -18,13 +18,13 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
@@ -703,6 +703,66 @@ class PlannerSuite extends SharedSQLContext {
       Range(1, 2, 1, 1)))
     df.queryExecution.executedPlan.execute()
   }
+
+  test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " +
+    "and InMemoryTableScanExec") {
+    def checkOutputPartitioningRewrite(
+        plans: Seq[SparkPlan],
+        expectedPartitioningClass: Class[_]): Unit = {
+      assert(plans.size == 1)
+      val plan = plans.head
+      val partitioning = plan.outputPartitioning
+      assert(partitioning.getClass == expectedPartitioningClass)
+      val partitionedAttrs = partitioning.asInstanceOf[Expression].references
+      assert(partitionedAttrs.subsetOf(plan.outputSet))
+    }
+
+    def checkReusedExchangeOutputPartitioningRewrite(
+        df: DataFrame,
+        expectedPartitioningClass: Class[_]): Unit = {
+      val reusedExchange = df.queryExecution.executedPlan.collect {
+        case r: ReusedExchangeExec => r
+      }
+      checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass)
+    }
+
+    def checkInMemoryTableScanOutputPartitioningRewrite(
+        df: DataFrame,
+        expectedPartitioningClass: Class[_]): Unit = {
+      val inMemoryScan = df.queryExecution.executedPlan.collect {
+        case m: InMemoryTableScanExec => m
+      }
+      checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass)
+    }
+
+    // ReusedExchange is HashPartitioning
+    val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
+    val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
+    checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning])
+
+    // ReusedExchange is RangePartitioning
+    val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
+    val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
+    checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning])
+
+    // InMemoryTableScan is HashPartitioning
+    Seq(1 -> "a").toDF("i", "j").repartition($"i").persist()
+    checkInMemoryTableScanOutputPartitioningRewrite(
+      Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning])
+
+    // InMemoryTableScan is RangePartitioning
+    spark.range(1, 100, 1, 10).toDF().persist()
+    checkInMemoryTableScanOutputPartitioningRewrite(
+      spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning])
+
+    // InMemoryTableScan is PartitioningCollection
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist()
+      checkInMemoryTableScanOutputPartitioningRewrite(
+        Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"),
+        classOf[PartitioningCollection])
+    }
+  }
 }

 // Used for unit-testing EnsureRequirements
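For a quick manual check of the same property outside the suite, a hypothetical spark-shell snippet mirroring the test's assertion (the reused exchange's partitioning should reference its own output attributes):

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec

val df1 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
val df2 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
val plan = df1.union(df2).queryExecution.executedPlan
// Collects nothing if exchange reuse did not kick in; with defaults it should
// find exactly one ReusedExchangeExec, whose RangePartitioning now refers to
// the node's own output rather than the original exchange's attributes.
plan.collect { case r: ReusedExchangeExec =>
  val attrs = r.outputPartitioning.asInstanceOf[Expression].references
  assert(attrs.subsetOf(r.outputSet))
}
```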
