@@ -802,6 +802,32 @@ fn get_corrected_filter_mask(
802802
803803 Some ( corrected_mask. finish ( ) )
804804 }
805+ JoinType :: LeftAnti => {
806+ for i in 0 ..row_indices_length {
807+ let last_index =
808+ last_index_for_row ( i, row_indices, batch_ids, row_indices_length) ;
809+
810+ if filter_mask. value ( i) {
811+ seen_true = true ;
812+ }
813+
814+ if last_index {
815+ if !seen_true {
816+ corrected_mask. append_value ( true ) ;
817+ } else {
818+ corrected_mask. append_null ( ) ;
819+ }
820+
821+ seen_true = false ;
822+ } else {
823+ corrected_mask. append_null ( ) ;
824+ }
825+ }
826+
827+ let null_matched = expected_size - corrected_mask. len ( ) ;
828+ corrected_mask. extend ( vec ! [ Some ( true ) ; null_matched] ) ;
829+ Some ( corrected_mask. finish ( ) )
830+ }
805831 // Only outer joins needs to keep track of processed rows and apply corrected filter mask
806832 _ => None ,
807833 }
@@ -835,15 +861,18 @@ impl Stream for SMJStream {
835861 JoinType :: Left
836862 | JoinType :: LeftSemi
837863 | JoinType :: Right
864+ | JoinType :: LeftAnti
838865 )
839866 {
840867 self . freeze_all ( ) ?;
841868
842869 if !self . output_record_batches . batches . is_empty ( )
843- && self . buffered_data . scanning_finished ( )
844870 {
845- let out_batch = self . filter_joined_batch ( ) ?;
846- return Poll :: Ready ( Some ( Ok ( out_batch) ) ) ;
871+ let out_filtered_batch =
872+ self . filter_joined_batch ( ) ?;
873+ return Poll :: Ready ( Some ( Ok (
874+ out_filtered_batch,
875+ ) ) ) ;
847876 }
848877 }
849878
@@ -907,15 +936,17 @@ impl Stream for SMJStream {
907936 // because target output batch size can be hit in the middle of
908937 // filtering causing the filtering to be incomplete and causing
909938 // correctness issues
910- let record_batch = if ! ( self . filter . is_some ( )
939+ if self . filter . is_some ( )
911940 && matches ! (
912941 self . join_type,
913- JoinType :: Left | JoinType :: LeftSemi | JoinType :: Right
914- ) ) {
915- record_batch
916- } else {
942+ JoinType :: Left
943+ | JoinType :: LeftSemi
944+ | JoinType :: Right
945+ | JoinType :: LeftAnti
946+ )
947+ {
917948 continue ;
918- } ;
949+ }
919950
920951 return Poll :: Ready ( Some ( Ok ( record_batch) ) ) ;
921952 }
@@ -929,7 +960,10 @@ impl Stream for SMJStream {
929960 if self . filter . is_some ( )
930961 && matches ! (
931962 self . join_type,
932- JoinType :: Left | JoinType :: LeftSemi | JoinType :: Right
963+ JoinType :: Left
964+ | JoinType :: LeftSemi
965+ | JoinType :: Right
966+ | JoinType :: LeftAnti
933967 )
934968 {
935969 let out = self . filter_joined_batch ( ) ?;
@@ -1273,11 +1307,7 @@ impl SMJStream {
12731307 } ;
12741308
12751309 if matches ! ( self . join_type, JoinType :: LeftAnti ) && self . filter . is_some ( ) {
1276- join_streamed = !self
1277- . streamed_batch
1278- . join_filter_matched_idxs
1279- . contains ( & ( self . streamed_batch . idx as u64 ) )
1280- && !self . streamed_joined ;
1310+ join_streamed = !self . streamed_joined ;
12811311 join_buffered = join_streamed;
12821312 }
12831313 }
@@ -1519,7 +1549,10 @@ impl SMJStream {
15191549 // Push the filtered batch which contains rows passing join filter to the output
15201550 if matches ! (
15211551 self . join_type,
1522- JoinType :: Left | JoinType :: LeftSemi | JoinType :: Right
1552+ JoinType :: Left
1553+ | JoinType :: LeftSemi
1554+ | JoinType :: Right
1555+ | JoinType :: LeftAnti
15231556 ) {
15241557 self . output_record_batches
15251558 . batches
@@ -1654,7 +1687,10 @@ impl SMJStream {
16541687 if !( self . filter . is_some ( )
16551688 && matches ! (
16561689 self . join_type,
1657- JoinType :: Left | JoinType :: LeftSemi | JoinType :: Right
1690+ JoinType :: Left
1691+ | JoinType :: LeftSemi
1692+ | JoinType :: Right
1693+ | JoinType :: LeftAnti
16581694 ) )
16591695 {
16601696 self . output_record_batches . batches . clear ( ) ;
@@ -1727,7 +1763,7 @@ impl SMJStream {
17271763 & self . schema ,
17281764 & [ filtered_record_batch, null_joined_streamed_batch] ,
17291765 ) ?;
1730- } else if matches ! ( self . join_type, JoinType :: LeftSemi ) {
1766+ } else if matches ! ( self . join_type, JoinType :: LeftSemi | JoinType :: LeftAnti ) {
17311767 let output_column_indices = ( 0 ..streamed_columns_length) . collect :: < Vec < _ > > ( ) ;
17321768 filtered_record_batch =
17331769 filtered_record_batch. project ( & output_column_indices) ?;
@@ -3349,6 +3385,7 @@ mod tests {
33493385 batch_ids : vec ! [ ] ,
33503386 } ;
33513387
3388+ // Insert already prejoined non-filtered rows
33523389 batches. batches . push ( RecordBatch :: try_new (
33533390 Arc :: clone ( & schema) ,
33543391 vec ! [
@@ -3835,6 +3872,178 @@ mod tests {
38353872 Ok ( ( ) )
38363873 }
38373874
3875+ #[ tokio:: test]
3876+ async fn test_left_anti_join_filtered_mask ( ) -> Result < ( ) > {
3877+ let mut joined_batches = build_joined_record_batches ( ) ?;
3878+ let schema = joined_batches. batches . first ( ) . unwrap ( ) . schema ( ) ;
3879+
3880+ let output = concat_batches ( & schema, & joined_batches. batches ) ?;
3881+ let out_mask = joined_batches. filter_mask . finish ( ) ;
3882+ let out_indices = joined_batches. row_indices . finish ( ) ;
3883+
3884+ assert_eq ! (
3885+ get_corrected_filter_mask(
3886+ LeftAnti ,
3887+ & UInt64Array :: from( vec![ 0 ] ) ,
3888+ & [ 0usize ] ,
3889+ & BooleanArray :: from( vec![ true ] ) ,
3890+ 1
3891+ )
3892+ . unwrap( ) ,
3893+ BooleanArray :: from( vec![ None ] )
3894+ ) ;
3895+
3896+ assert_eq ! (
3897+ get_corrected_filter_mask(
3898+ LeftAnti ,
3899+ & UInt64Array :: from( vec![ 0 ] ) ,
3900+ & [ 0usize ] ,
3901+ & BooleanArray :: from( vec![ false ] ) ,
3902+ 1
3903+ )
3904+ . unwrap( ) ,
3905+ BooleanArray :: from( vec![ Some ( true ) ] )
3906+ ) ;
3907+
3908+ assert_eq ! (
3909+ get_corrected_filter_mask(
3910+ LeftAnti ,
3911+ & UInt64Array :: from( vec![ 0 , 0 ] ) ,
3912+ & [ 0usize ; 2 ] ,
3913+ & BooleanArray :: from( vec![ true , true ] ) ,
3914+ 2
3915+ )
3916+ . unwrap( ) ,
3917+ BooleanArray :: from( vec![ None , None ] )
3918+ ) ;
3919+
3920+ assert_eq ! (
3921+ get_corrected_filter_mask(
3922+ LeftAnti ,
3923+ & UInt64Array :: from( vec![ 0 , 0 , 0 ] ) ,
3924+ & [ 0usize ; 3 ] ,
3925+ & BooleanArray :: from( vec![ true , true , true ] ) ,
3926+ 3
3927+ )
3928+ . unwrap( ) ,
3929+ BooleanArray :: from( vec![ None , None , None ] )
3930+ ) ;
3931+
3932+ assert_eq ! (
3933+ get_corrected_filter_mask(
3934+ LeftAnti ,
3935+ & UInt64Array :: from( vec![ 0 , 0 , 0 ] ) ,
3936+ & [ 0usize ; 3 ] ,
3937+ & BooleanArray :: from( vec![ true , false , true ] ) ,
3938+ 3
3939+ )
3940+ . unwrap( ) ,
3941+ BooleanArray :: from( vec![ None , None , None ] )
3942+ ) ;
3943+
3944+ assert_eq ! (
3945+ get_corrected_filter_mask(
3946+ LeftAnti ,
3947+ & UInt64Array :: from( vec![ 0 , 0 , 0 ] ) ,
3948+ & [ 0usize ; 3 ] ,
3949+ & BooleanArray :: from( vec![ false , false , true ] ) ,
3950+ 3
3951+ )
3952+ . unwrap( ) ,
3953+ BooleanArray :: from( vec![ None , None , None ] )
3954+ ) ;
3955+
3956+ assert_eq ! (
3957+ get_corrected_filter_mask(
3958+ LeftAnti ,
3959+ & UInt64Array :: from( vec![ 0 , 0 , 0 ] ) ,
3960+ & [ 0usize ; 3 ] ,
3961+ & BooleanArray :: from( vec![ false , true , true ] ) ,
3962+ 3
3963+ )
3964+ . unwrap( ) ,
3965+ BooleanArray :: from( vec![ None , None , None ] )
3966+ ) ;
3967+
3968+ assert_eq ! (
3969+ get_corrected_filter_mask(
3970+ LeftAnti ,
3971+ & UInt64Array :: from( vec![ 0 , 0 , 0 ] ) ,
3972+ & [ 0usize ; 3 ] ,
3973+ & BooleanArray :: from( vec![ false , false , false ] ) ,
3974+ 3
3975+ )
3976+ . unwrap( ) ,
3977+ BooleanArray :: from( vec![ None , None , Some ( true ) ] )
3978+ ) ;
3979+
3980+ let corrected_mask = get_corrected_filter_mask (
3981+ LeftAnti ,
3982+ & out_indices,
3983+ & joined_batches. batch_ids ,
3984+ & out_mask,
3985+ output. num_rows ( ) ,
3986+ )
3987+ . unwrap ( ) ;
3988+
3989+ assert_eq ! (
3990+ corrected_mask,
3991+ BooleanArray :: from( vec![
3992+ None ,
3993+ None ,
3994+ None ,
3995+ None ,
3996+ None ,
3997+ Some ( true ) ,
3998+ None ,
3999+ Some ( true )
4000+ ] )
4001+ ) ;
4002+
4003+ let filtered_rb = filter_record_batch ( & output, & corrected_mask) ?;
4004+
4005+ assert_batches_eq ! (
4006+ & [
4007+ "+---+----+---+----+" ,
4008+ "| a | b | x | y |" ,
4009+ "+---+----+---+----+" ,
4010+ "| 1 | 13 | 1 | 12 |" ,
4011+ "| 1 | 14 | 1 | 11 |" ,
4012+ "+---+----+---+----+" ,
4013+ ] ,
4014+ & [ filtered_rb]
4015+ ) ;
4016+
4017+ // output null rows
4018+ let null_mask = arrow:: compute:: not ( & corrected_mask) ?;
4019+ assert_eq ! (
4020+ null_mask,
4021+ BooleanArray :: from( vec![
4022+ None ,
4023+ None ,
4024+ None ,
4025+ None ,
4026+ None ,
4027+ Some ( false ) ,
4028+ None ,
4029+ Some ( false ) ,
4030+ ] )
4031+ ) ;
4032+
4033+ let null_joined_batch = filter_record_batch ( & output, & null_mask) ?;
4034+
4035+ assert_batches_eq ! (
4036+ & [
4037+ "+---+---+---+---+" ,
4038+ "| a | b | x | y |" ,
4039+ "+---+---+---+---+" ,
4040+ "+---+---+---+---+" ,
4041+ ] ,
4042+ & [ null_joined_batch]
4043+ ) ;
4044+ Ok ( ( ) )
4045+ }
4046+
38384047 /// Returns the column names on the schema
38394048 fn columns ( schema : & Schema ) -> Vec < String > {
38404049 schema. fields ( ) . iter ( ) . map ( |f| f. name ( ) . clone ( ) ) . collect ( )
0 commit comments