@@ -25,7 +25,7 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
25
25
import org .apache .spark .sql .catalyst .TableIdentifier
26
26
import org .apache .spark .sql .catalyst .analysis .UnresolvedRelation
27
27
import org .apache .spark .sql .catalyst .expressions .{Ascending , SortOrder }
28
- import org .apache .spark .sql .execution .SortExec
28
+ import org .apache .spark .sql .execution .{ BinaryExecNode , SortExec }
29
29
import org .apache .spark .sql .execution .joins ._
30
30
import org .apache .spark .sql .internal .SQLConf
31
31
import org .apache .spark .sql .test .SharedSQLContext
@@ -857,4 +857,29 @@ class JoinSuite extends QueryTest with SharedSQLContext {
857
857
858
858
joinQueries.foreach(assertJoinOrdering)
859
859
}
860
+
861
+ test(" SPARK-22445 Respect stream-side child's needCopyResult in BroadcastHashJoin" ) {
862
+ val df1 = Seq ((2 , 3 ), (2 , 5 ), (2 , 2 ), (3 , 8 ), (2 , 1 )).toDF(" k" , " v1" )
863
+ val df2 = Seq ((2 , 8 ), (3 , 7 ), (3 , 4 ), (1 , 2 )).toDF(" k" , " v2" )
864
+ val df3 = Seq ((1 , 1 ), (3 , 2 ), (4 , 3 ), (5 , 1 )).toDF(" k" , " v3" )
865
+
866
+ withSQLConf(
867
+ SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ,
868
+ SQLConf .JOIN_REORDER_ENABLED .key -> " false" ) {
869
+ val df = df1.join(df2, " k" ).join(functions.broadcast(df3), " k" )
870
+ val plan = df.queryExecution.sparkPlan
871
+
872
+ // Check if `needCopyResult` in `BroadcastHashJoin` is correct when smj->bhj
873
+ val joins = new collection.mutable.ArrayBuffer [BinaryExecNode ]()
874
+ plan.foreachUp {
875
+ case j : BroadcastHashJoinExec => joins += j
876
+ case j : SortMergeJoinExec => joins += j
877
+ case _ =>
878
+ }
879
+ assert(joins.size == 2 )
880
+ assert(joins(0 ).isInstanceOf [SortMergeJoinExec ])
881
+ assert(joins(1 ).isInstanceOf [BroadcastHashJoinExec ])
882
+ checkAnswer(df, Row (3 , 8 , 7 , 2 ) :: Row (3 , 8 , 4 , 2 ) :: Nil )
883
+ }
884
+ }
860
885
}
0 commit comments