@@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming
20
20
import java .util .UUID
21
21
22
22
import org .apache .spark .rdd .RDD
23
+ import org .apache .spark .sql .DataFrame
23
24
import org .apache .spark .sql .catalyst .InternalRow
24
25
import org .apache .spark .sql .catalyst .analysis .UnresolvedAttribute
25
26
import org .apache .spark .sql .catalyst .expressions .Attribute
@@ -32,66 +33,71 @@ import org.apache.spark.sql.test.SharedSQLContext
32
33
class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext {
33
34
34
35
import testImplicits ._
35
- super .beforeAll()
36
36
37
- private val baseDf = Seq (( 1 , " A " ), ( 2 , " b " )).toDF( " num " , " char " )
37
+ private var baseDf : DataFrame = null
38
38
39
- testEnsureStatefulOpPartitioning(
40
- " ClusteredDistribution generates Exchange with HashPartitioning" ,
41
- baseDf.queryExecution.sparkPlan,
42
- requiredDistribution = keys => ClusteredDistribution (keys),
43
- expectedPartitioning =
44
- keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
45
- expectShuffle = true )
39
+ override def beforeAll (): Unit = {
40
+ super .beforeAll()
41
+ baseDf = Seq ((1 , " A" ), (2 , " b" )).toDF(" num" , " char" )
42
+ }
43
+
44
+ test(" ClusteredDistribution generates Exchange with HashPartitioning" ) {
45
+ testEnsureStatefulOpPartitioning(
46
+ baseDf.queryExecution.sparkPlan,
47
+ requiredDistribution = keys => ClusteredDistribution (keys),
48
+ expectedPartitioning =
49
+ keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
50
+ expectShuffle = true )
51
+ }
46
52
47
- testEnsureStatefulOpPartitioning(
48
- " ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning" ,
49
- baseDf.coalesce(1 ).queryExecution.sparkPlan,
50
- requiredDistribution = keys => ClusteredDistribution (keys),
51
- expectedPartitioning =
52
- keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
53
- expectShuffle = true )
53
+ test(" ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning" ) {
54
+ testEnsureStatefulOpPartitioning(
55
+ baseDf.coalesce(1 ).queryExecution.sparkPlan,
56
+ requiredDistribution = keys => ClusteredDistribution (keys),
57
+ expectedPartitioning =
58
+ keys => HashPartitioning (keys, spark.sessionState.conf.numShufflePartitions),
59
+ expectShuffle = true )
60
+ }
54
61
55
- testEnsureStatefulOpPartitioning(
56
- " AllTuples generates Exchange with SinglePartition" ,
57
- baseDf.queryExecution.sparkPlan,
58
- requiredDistribution = _ => AllTuples ,
59
- expectedPartitioning = _ => SinglePartition ,
60
- expectShuffle = true )
62
+ test(" AllTuples generates Exchange with SinglePartition" ) {
63
+ testEnsureStatefulOpPartitioning(
64
+ baseDf.queryExecution.sparkPlan,
65
+ requiredDistribution = _ => AllTuples ,
66
+ expectedPartitioning = _ => SinglePartition ,
67
+ expectShuffle = true )
68
+ }
61
69
62
- testEnsureStatefulOpPartitioning(
63
- " AllTuples with coalesce(1) doesn't need Exchange" ,
64
- baseDf.coalesce(1 ).queryExecution.sparkPlan,
65
- requiredDistribution = _ => AllTuples ,
66
- expectedPartitioning = _ => SinglePartition ,
67
- expectShuffle = false )
70
+ test(" AllTuples with coalesce(1) doesn't need Exchange" ) {
71
+ testEnsureStatefulOpPartitioning(
72
+ baseDf.coalesce(1 ).queryExecution.sparkPlan,
73
+ requiredDistribution = _ => AllTuples ,
74
+ expectedPartitioning = _ => SinglePartition ,
75
+ expectShuffle = false )
76
+ }
68
77
69
78
/**
70
79
* For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan
71
80
* `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to
72
81
* ensure the expected partitioning.
73
82
*/
74
83
private def testEnsureStatefulOpPartitioning (
75
- testName : String ,
76
84
inputPlan : SparkPlan ,
77
85
requiredDistribution : Seq [Attribute ] => Distribution ,
78
86
expectedPartitioning : Seq [Attribute ] => Partitioning ,
79
87
expectShuffle : Boolean ): Unit = {
80
- test(testName) {
81
- val operator = TestStatefulOperator (inputPlan, requiredDistribution(inputPlan.output.take(1 )))
82
- val executed = executePlan(operator, OutputMode .Complete ())
83
- if (expectShuffle) {
84
- val exchange = executed.children.find(_.isInstanceOf [Exchange ])
85
- if (exchange.isEmpty) {
86
- fail(s " Was expecting an exchange but didn't get one in: \n $executed" )
87
- }
88
- assert(exchange.get ===
89
- ShuffleExchange (expectedPartitioning(inputPlan.output.take(1 )), inputPlan),
90
- s " Exchange didn't have expected properties: \n ${exchange.get}" )
91
- } else {
92
- assert(! executed.children.exists(_.isInstanceOf [Exchange ]),
93
- s " Unexpected exchange found in: \n $executed" )
88
+ val operator = TestStatefulOperator (inputPlan, requiredDistribution(inputPlan.output.take(1 )))
89
+ val executed = executePlan(operator, OutputMode .Complete ())
90
+ if (expectShuffle) {
91
+ val exchange = executed.children.find(_.isInstanceOf [Exchange ])
92
+ if (exchange.isEmpty) {
93
+ fail(s " Was expecting an exchange but didn't get one in: \n $executed" )
94
94
}
95
+ assert(exchange.get ===
96
+ ShuffleExchange (expectedPartitioning(inputPlan.output.take(1 )), inputPlan),
97
+ s " Exchange didn't have expected properties: \n ${exchange.get}" )
98
+ } else {
99
+ assert(! executed.children.exists(_.isInstanceOf [Exchange ]),
100
+ s " Unexpected exchange found in: \n $executed" )
95
101
}
96
102
}
97
103
0 commit comments