@@ -34,7 +34,7 @@ mod tests {
3434 physical_plan:: { ExecutionPlan , common, joins:: utils:: * , test:: TestMemoryExec } ,
3535 prelude:: SessionContext ,
3636 } ;
37-
37+ use datafusion :: prelude :: SessionConfig ;
3838 use crate :: {
3939 broadcast_join_build_hash_map_exec:: BroadcastJoinBuildHashMapExec ,
4040 broadcast_join_exec:: BroadcastJoinExec ,
@@ -264,6 +264,92 @@ mod tests {
264264 Ok ( ( columns, batches) )
265265 }
266266
267+ async fn join_collect_with_batch_size (
268+ test_type : TestType ,
269+ left : Arc < dyn ExecutionPlan > ,
270+ right : Arc < dyn ExecutionPlan > ,
271+ on : JoinOn ,
272+ join_type : JoinType ,
273+ batch_size : usize
274+ ) -> Result < ( Vec < String > , Vec < RecordBatch > ) > {
275+ MemManager :: init ( 1000000 ) ;
276+ let session_config = SessionConfig :: new ( ) . with_batch_size ( batch_size) ;
277+ let session_ctx = SessionContext :: new_with_config ( session_config) ;
278+ let session_ctx = SessionContext :: new ( ) ;
279+ let task_ctx = session_ctx. task_ctx ( ) ;
280+ let schema = build_join_schema_for_test ( & left. schema ( ) , & right. schema ( ) , join_type) ?;
281+
282+ let join: Arc < dyn ExecutionPlan > = match test_type {
283+ SMJ => {
284+ let sort_options = vec ! [ SortOptions :: default ( ) ; on. len( ) ] ;
285+ Arc :: new ( SortMergeJoinExec :: try_new (
286+ schema,
287+ left,
288+ right,
289+ on,
290+ join_type,
291+ sort_options,
292+ ) ?)
293+ }
294+ BHJLeftProbed => {
295+ let right = Arc :: new ( BroadcastJoinBuildHashMapExec :: new (
296+ right,
297+ on. iter ( ) . map ( |( _, right_key) | right_key. clone ( ) ) . collect ( ) ,
298+ ) ) ;
299+ Arc :: new ( BroadcastJoinExec :: try_new (
300+ schema,
301+ left,
302+ right,
303+ on,
304+ join_type,
305+ JoinSide :: Right ,
306+ true ,
307+ None ,
308+ ) ?)
309+ }
310+ BHJRightProbed => {
311+ let left = Arc :: new ( BroadcastJoinBuildHashMapExec :: new (
312+ left,
313+ on. iter ( ) . map ( |( left_key, _) | left_key. clone ( ) ) . collect ( ) ,
314+ ) ) ;
315+ Arc :: new ( BroadcastJoinExec :: try_new (
316+ schema,
317+ left,
318+ right,
319+ on,
320+ join_type,
321+ JoinSide :: Left ,
322+ true ,
323+ None ,
324+ ) ?)
325+ }
326+ SHJLeftProbed => Arc :: new ( BroadcastJoinExec :: try_new (
327+ schema,
328+ left,
329+ right,
330+ on,
331+ join_type,
332+ JoinSide :: Right ,
333+ false ,
334+ None ,
335+ ) ?) ,
336+ SHJRightProbed => Arc :: new ( BroadcastJoinExec :: try_new (
337+ schema,
338+ left,
339+ right,
340+ on,
341+ join_type,
342+ JoinSide :: Left ,
343+ false ,
344+ None ,
345+ ) ?) ,
346+ } ;
347+ let columns = columns ( & join. schema ( ) ) ;
348+ let stream = join. execute ( 0 , task_ctx) ?;
349+ let batches = common:: collect ( stream) . await ?;
350+ Ok ( ( columns, batches) )
351+ }
352+
267353 const ALL_TEST_TYPE : [ TestType ; 5 ] = [
268354 SMJ ,
269355 BHJLeftProbed ,
@@ -428,6 +514,78 @@ mod tests {
428514 Ok ( ( ) )
429515 }
430516
517+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 1 ) ]
518+ async fn join_inner_batchsize ( ) -> Result < ( ) > {
519+ for test_type in ALL_TEST_TYPE {
520+ let left = build_table (
521+ ( "a1" , & vec ! [ 1 , 1 , 1 , 1 , 1 ] ) ,
522+ ( "b1" , & vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ,
523+ ( "c1" , & vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ,
524+ ) ;
525+ let right = build_table (
526+ ( "a2" , & vec ! [ 1 , 1 , 1 , 1 , 1 , 1 , 1 ] ) ,
527+ ( "b2" , & vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 ] ) ,
528+ ( "c2" , & vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 ] ) ,
529+ ) ;
530+ let on: JoinOn = vec ! [ (
531+ Arc :: new( Column :: new_with_schema( "a1" , & left. schema( ) ) ?) ,
532+ Arc :: new( Column :: new_with_schema( "a2" , & right. schema( ) ) ?) ,
533+ ) ] ;
534+ let expected = vec ! [
535+ "+----+----+----+----+----+----+" ,
536+ "| a1 | b1 | c1 | a2 | b2 | c2 |" ,
537+ "+----+----+----+----+----+----+" ,
538+ "| 1 | 1 | 1 | 1 | 1 | 1 |" ,
539+ "| 1 | 1 | 1 | 1 | 2 | 2 |" ,
540+ "| 1 | 1 | 1 | 1 | 3 | 3 |" ,
541+ "| 1 | 1 | 1 | 1 | 4 | 4 |" ,
542+ "| 1 | 1 | 1 | 1 | 5 | 5 |" ,
543+ "| 1 | 1 | 1 | 1 | 6 | 6 |" ,
544+ "| 1 | 1 | 1 | 1 | 7 | 7 |" ,
545+ "| 1 | 2 | 2 | 1 | 1 | 1 |" ,
546+ "| 1 | 2 | 2 | 1 | 2 | 2 |" ,
547+ "| 1 | 2 | 2 | 1 | 3 | 3 |" ,
548+ "| 1 | 2 | 2 | 1 | 4 | 4 |" ,
549+ "| 1 | 2 | 2 | 1 | 5 | 5 |" ,
550+ "| 1 | 2 | 2 | 1 | 6 | 6 |" ,
551+ "| 1 | 2 | 2 | 1 | 7 | 7 |" ,
552+ "| 1 | 3 | 3 | 1 | 1 | 1 |" ,
553+ "| 1 | 3 | 3 | 1 | 2 | 2 |" ,
554+ "| 1 | 3 | 3 | 1 | 3 | 3 |" ,
555+ "| 1 | 3 | 3 | 1 | 4 | 4 |" ,
556+ "| 1 | 3 | 3 | 1 | 5 | 5 |" ,
557+ "| 1 | 3 | 3 | 1 | 6 | 6 |" ,
558+ "| 1 | 3 | 3 | 1 | 7 | 7 |" ,
559+ "| 1 | 4 | 4 | 1 | 1 | 1 |" ,
560+ "| 1 | 4 | 4 | 1 | 2 | 2 |" ,
561+ "| 1 | 4 | 4 | 1 | 3 | 3 |" ,
562+ "| 1 | 4 | 4 | 1 | 4 | 4 |" ,
563+ "| 1 | 4 | 4 | 1 | 5 | 5 |" ,
564+ "| 1 | 4 | 4 | 1 | 6 | 6 |" ,
565+ "| 1 | 4 | 4 | 1 | 7 | 7 |" ,
566+ "| 1 | 5 | 5 | 1 | 1 | 1 |" ,
567+ "| 1 | 5 | 5 | 1 | 2 | 2 |" ,
568+ "| 1 | 5 | 5 | 1 | 3 | 3 |" ,
569+ "| 1 | 5 | 5 | 1 | 4 | 4 |" ,
570+ "| 1 | 5 | 5 | 1 | 5 | 5 |" ,
571+ "| 1 | 5 | 5 | 1 | 6 | 6 |" ,
572+ "| 1 | 5 | 5 | 1 | 7 | 7 |" ,
573+ "+----+----+----+----+----+----+" ,
574+ ] ;
575+ let ( _, batches) = join_collect_with_batch_size ( test_type, left. clone ( ) , right. clone ( ) , on. clone ( ) , Inner , 2 ) . await ?;
576+ assert_batches_sorted_eq ! ( expected, & batches) ;
577+ let ( _, batches) = join_collect_with_batch_size ( test_type, left. clone ( ) , right. clone ( ) , on. clone ( ) , Inner , 3 ) . await ?;
578+ assert_batches_sorted_eq ! ( expected, & batches) ;
579+ let ( _, batches) = join_collect_with_batch_size ( test_type, left. clone ( ) , right. clone ( ) , on. clone ( ) , Inner , 4 ) . await ?;
580+ assert_batches_sorted_eq ! ( expected, & batches) ;
581+ let ( _, batches) = join_collect_with_batch_size ( test_type, left. clone ( ) , right. clone ( ) , on. clone ( ) , Inner , 5 ) . await ?;
582+ assert_batches_sorted_eq ! ( expected, & batches) ;
583+ let ( _, batches) = join_collect_with_batch_size ( test_type, left. clone ( ) , right. clone ( ) , on. clone ( ) , Inner , 7 ) . await ?;
584+ assert_batches_sorted_eq ! ( expected, & batches) ;
585+ }
586+ Ok ( ( ) )
587+ }
588+
431589 #[ tokio:: test( flavor = "multi_thread" , worker_threads = 1 ) ]
432590 async fn join_left_one ( ) -> Result < ( ) > {
433591 for test_type in ALL_TEST_TYPE {
0 commit comments