@@ -18,13 +18,13 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
@@ -703,6 +703,66 @@ class PlannerSuite extends SharedSQLContext {
       Range(1, 2, 1, 1)))
     df.queryExecution.executedPlan.execute()
   }
+
+  test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " +
+    "and InMemoryTableScanExec") {
+    def checkOutputPartitioningRewrite(
+        plans: Seq[SparkPlan],
+        expectedPartitioningClass: Class[_]): Unit = {
+      assert(plans.size == 1)
+      val plan = plans.head
+      val partitioning = plan.outputPartitioning
+      assert(partitioning.getClass == expectedPartitioningClass)
+      val partitionedAttrs = partitioning.asInstanceOf[Expression].references
+      assert(partitionedAttrs.subsetOf(plan.outputSet))
+    }
+
+    def checkReusedExchangeOutputPartitioningRewrite(
+        df: DataFrame,
+        expectedPartitioningClass: Class[_]): Unit = {
+      val reusedExchange = df.queryExecution.executedPlan.collect {
+        case r: ReusedExchangeExec => r
+      }
+      checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass)
+    }
+
+    def checkInMemoryTableScanOutputPartitioningRewrite(
+        df: DataFrame,
+        expectedPartitioningClass: Class[_]): Unit = {
+      val inMemoryScan = df.queryExecution.executedPlan.collect {
+        case m: InMemoryTableScanExec => m
+      }
+      checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass)
+    }
+
+    // ReusedExchange is HashPartitioning
+    val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
+    val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
+    checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning])
+
+    // ReusedExchange is RangePartitioning
+    val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
+    val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
+    checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning])
+
+    // InMemoryTableScan is HashPartitioning
+    Seq(1 -> "a").toDF("i", "j").repartition($"i").persist()
+    checkInMemoryTableScanOutputPartitioningRewrite(
+      Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning])
+
+    // InMemoryTableScan is RangePartitioning
+    spark.range(1, 100, 1, 10).toDF().persist()
+    checkInMemoryTableScanOutputPartitioningRewrite(
+      spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning])
+
+    // InMemoryTableScan is PartitioningCollection
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist()
+      checkInMemoryTableScanOutputPartitioningRewrite(
+        Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"),
+        classOf[PartitioningCollection])
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements