@@ -23,8 +23,9 @@ use std::sync::Arc;
23
23
use arrow:: array:: GenericStringArray ;
24
24
use arrow:: array:: {
25
25
ArrayRef , BooleanArray , Float32Array , Float64Array , Int16Array , Int32Array ,
26
- Int64Array , Int8Array , StringOffsetSizeTrait , UInt16Array , UInt32Array , UInt64Array ,
27
- UInt8Array ,
26
+ Int64Array , Int8Array , StringOffsetSizeTrait , TimestampMicrosecondArray ,
27
+ TimestampMillisecondArray , TimestampNanosecondArray , TimestampSecondArray ,
28
+ UInt16Array , UInt32Array , UInt64Array , UInt8Array ,
28
29
} ;
29
30
use arrow:: datatypes:: ArrowPrimitiveType ;
30
31
use arrow:: {
@@ -35,6 +36,7 @@ use arrow::{
35
36
use crate :: PhysicalExpr ;
36
37
use arrow:: array:: * ;
37
38
use arrow:: buffer:: { Buffer , MutableBuffer } ;
39
+ use arrow:: datatypes:: TimeUnit ;
38
40
use datafusion_common:: ScalarValue ;
39
41
use datafusion_common:: { DataFusionError , Result } ;
40
42
use datafusion_expr:: ColumnarValue ;
@@ -134,8 +136,8 @@ macro_rules! make_contains_primitive {
134
136
. iter( )
135
137
. flat_map( |expr| match expr {
136
138
ColumnarValue :: Scalar ( s) => match s {
137
- ScalarValue :: $SCALAR_VALUE( Some ( v) ) => Some ( * v) ,
138
- ScalarValue :: $SCALAR_VALUE( None ) => None ,
139
+ ScalarValue :: $SCALAR_VALUE( Some ( v) , .. ) => Some ( * v) ,
140
+ ScalarValue :: $SCALAR_VALUE( None , .. ) => None ,
139
141
ScalarValue :: Utf8 ( None ) => None ,
140
142
datatype => unimplemented!( "Unexpected type {} for InList" , datatype) ,
141
143
} ,
@@ -451,6 +453,36 @@ impl PhysicalExpr for InListExpr {
451
453
DataType :: LargeUtf8 => {
452
454
self . compare_utf8 :: < i64 > ( array, list_values, self . negated )
453
455
}
456
+ DataType :: Timestamp ( unit, _) => match unit {
457
+ TimeUnit :: Second => make_contains_primitive ! (
458
+ array,
459
+ list_values,
460
+ self . negated,
461
+ TimestampSecond ,
462
+ TimestampSecondArray
463
+ ) ,
464
+ TimeUnit :: Millisecond => make_contains_primitive ! (
465
+ array,
466
+ list_values,
467
+ self . negated,
468
+ TimestampMillisecond ,
469
+ TimestampMillisecondArray
470
+ ) ,
471
+ TimeUnit :: Microsecond => make_contains_primitive ! (
472
+ array,
473
+ list_values,
474
+ self . negated,
475
+ TimestampMicrosecond ,
476
+ TimestampMicrosecondArray
477
+ ) ,
478
+ TimeUnit :: Nanosecond => make_contains_primitive ! (
479
+ array,
480
+ list_values,
481
+ self . negated,
482
+ TimestampNanosecond ,
483
+ TimestampNanosecondArray
484
+ ) ,
485
+ } ,
454
486
datatype => Result :: Err ( DataFusionError :: NotImplemented ( format ! (
455
487
"InList does not support datatype {:?}." ,
456
488
datatype
@@ -713,4 +745,108 @@ mod tests {
713
745
714
746
Ok ( ( ) )
715
747
}
748
+
749
+ #[ test]
750
+ fn in_list_set_timestamp ( ) -> Result < ( ) > {
751
+ // Size at which to use a Set rather than Vec for `IN` / `NOT IN`
752
+ // Value chosen by the benchmark at
753
+ // https://github.com/apache/arrow-datafusion/pull/2156#discussion_r845198369
754
+ // TODO: add switch codeGen in In_List
755
+ let optimizer_inset_threshold: usize = 30 ;
756
+
757
+ let schema = Schema :: new ( vec ! [ Field :: new(
758
+ "a" ,
759
+ DataType :: Timestamp ( TimeUnit :: Microsecond , None ) ,
760
+ true ,
761
+ ) ] ) ;
762
+ let a = TimestampMicrosecondArray :: from ( vec ! [
763
+ Some ( 1388588401000000000 ) ,
764
+ Some ( 1288588501000000000 ) ,
765
+ None ,
766
+ ] ) ;
767
+ let col_a = col ( "a" , & schema) ?;
768
+ let batch = RecordBatch :: try_new ( Arc :: new ( schema. clone ( ) ) , vec ! [ Arc :: new( a) ] ) ?;
769
+
770
+ let mut list = vec ! [
771
+ lit( ScalarValue :: TimestampMicrosecond (
772
+ Some ( 1388588401000000000 ) ,
773
+ None ,
774
+ ) ) ,
775
+ lit( ScalarValue :: TimestampMicrosecond ( None , None ) ) ,
776
+ lit( ScalarValue :: TimestampMicrosecond (
777
+ Some ( 1388588401000000001 ) ,
778
+ None ,
779
+ ) ) ,
780
+ ] ;
781
+ let start_ts = 1388588401000000001 ;
782
+ for v in start_ts..( start_ts + optimizer_inset_threshold + 4 ) {
783
+ list. push ( lit ( ScalarValue :: TimestampMicrosecond ( Some ( v as i64 ) , None ) ) ) ;
784
+ }
785
+
786
+ in_list ! (
787
+ batch,
788
+ list. clone( ) ,
789
+ & false ,
790
+ vec![ Some ( true ) , None , None ] ,
791
+ col_a. clone( )
792
+ ) ;
793
+
794
+ in_list ! (
795
+ batch,
796
+ list. clone( ) ,
797
+ & true ,
798
+ vec![ Some ( false ) , None , None ] ,
799
+ col_a. clone( )
800
+ ) ;
801
+
802
+ Ok ( ( ) )
803
+ }
804
+
805
+ #[ test]
806
+ fn in_list_timestamp ( ) -> Result < ( ) > {
807
+ let schema = Schema :: new ( vec ! [ Field :: new(
808
+ "a" ,
809
+ DataType :: Timestamp ( TimeUnit :: Microsecond , None ) ,
810
+ true ,
811
+ ) ] ) ;
812
+ let a = TimestampMicrosecondArray :: from ( vec ! [
813
+ Some ( 1388588401000000000 ) ,
814
+ Some ( 1288588501000000000 ) ,
815
+ None ,
816
+ ] ) ;
817
+ let col_a = col ( "a" , & schema) ?;
818
+ let batch = RecordBatch :: try_new ( Arc :: new ( schema. clone ( ) ) , vec ! [ Arc :: new( a) ] ) ?;
819
+
820
+ let list = vec ! [
821
+ lit( ScalarValue :: TimestampMicrosecond (
822
+ Some ( 1388588401000000000 ) ,
823
+ None ,
824
+ ) ) ,
825
+ lit( ScalarValue :: TimestampMicrosecond (
826
+ Some ( 1388588401000000001 ) ,
827
+ None ,
828
+ ) ) ,
829
+ lit( ScalarValue :: TimestampMicrosecond (
830
+ Some ( 1388588401000000002 ) ,
831
+ None ,
832
+ ) ) ,
833
+ ] ;
834
+
835
+ in_list ! (
836
+ batch,
837
+ list. clone( ) ,
838
+ & false ,
839
+ vec![ Some ( true ) , Some ( false ) , None ] ,
840
+ col_a. clone( )
841
+ ) ;
842
+
843
+ in_list ! (
844
+ batch,
845
+ list. clone( ) ,
846
+ & true ,
847
+ vec![ Some ( false ) , Some ( true ) , None ] ,
848
+ col_a. clone( )
849
+ ) ;
850
+ Ok ( ( ) )
851
+ }
716
852
}
0 commit comments