@@ -3033,4 +3033,302 @@ mod tests {
30333033
30343034 Ok ( ( ) )
30353035 }
3036+
3037+ /// Exhaustive structured test that verifies all lookup strategies, boundaries, and DataFusion types.
3038+ #[ test]
3039+ fn test_in_list_exhaustive ( ) -> Result < ( ) > {
3040+ let types = vec ! [
3041+ DataType :: Int8 ,
3042+ DataType :: Int16 ,
3043+ DataType :: Int32 ,
3044+ DataType :: Int64 ,
3045+ DataType :: Float32 ,
3046+ DataType :: Float64 ,
3047+ DataType :: Utf8 ,
3048+ DataType :: LargeUtf8 ,
3049+ DataType :: Binary ,
3050+ DataType :: LargeBinary ,
3051+ DataType :: Utf8View ,
3052+ DataType :: BinaryView ,
3053+ DataType :: Date32 ,
3054+ DataType :: Date64 ,
3055+ DataType :: Timestamp ( TimeUnit :: Nanosecond , None ) ,
3056+ DataType :: Timestamp ( TimeUnit :: Microsecond , Some ( "UTC" . into( ) ) ) ,
3057+ DataType :: Decimal128 ( 10 , 2 ) ,
3058+ DataType :: Decimal256 ( 38 , 10 ) ,
3059+ DataType :: Dictionary ( Box :: new( DataType :: Int32 ) , Box :: new( DataType :: Utf8 ) ) ,
3060+ DataType :: Dictionary ( Box :: new( DataType :: Int8 ) , Box :: new( DataType :: Int32 ) ) ,
3061+ DataType :: Struct ( Fields :: from( vec![ Field :: new( "f1" , DataType :: Int32 , true ) ] ) ) ,
3062+ ] ;
3063+
3064+ // Critical thresholds: 4 (Decimal boundary), 16 (8B boundary), 32 (4B boundary), 256 (Bloom boundary)
3065+ // Test all sizes 1..33 to cover ALL unrolled branchless filters, plus saturation boundaries.
3066+ let mut sizes: Vec < usize > = ( 1 ..=33 ) . collect ( ) ;
3067+ sizes. extend ( vec ! [ 100 , 256 , 257 , 512 ] ) ;
3068+
3069+ for dt in & types {
3070+ for size in & sizes {
3071+ for input_has_nulls in [ false , true ] {
3072+ for list_has_nulls in [ false , true ] {
3073+ for negated in [ false , true ] {
3074+ for input_is_scalar in [ false , true ] {
3075+ run_exhaustive_test_case (
3076+ dt,
3077+ * size,
3078+ input_has_nulls,
3079+ list_has_nulls,
3080+ negated,
3081+ input_is_scalar,
3082+ ) ?;
3083+ }
3084+ }
3085+ }
3086+ }
3087+ }
3088+ }
3089+ Ok ( ( ) )
3090+ }
3091+
3092+ fn run_exhaustive_test_case (
3093+ dt : & DataType ,
3094+ list_size : usize ,
3095+ input_has_nulls : bool ,
3096+ list_has_nulls : bool ,
3097+ negated : bool ,
3098+ input_is_scalar : bool ,
3099+ ) -> Result < ( ) > {
3100+ let array_len = 100 ;
3101+ let ( input_array, in_list_scalars) = generate_deterministic_data (
3102+ dt,
3103+ array_len,
3104+ list_size,
3105+ input_has_nulls,
3106+ list_has_nulls,
3107+ ) ;
3108+
3109+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , dt. clone( ) , true ) ] ) ;
3110+ let batch = RecordBatch :: try_new (
3111+ Arc :: new ( schema. clone ( ) ) ,
3112+ vec ! [ Arc :: clone( & input_array) ] ,
3113+ ) ?;
3114+ let col_a = col ( "a" , & schema) ?;
3115+ let list_exprs: Vec < Arc < dyn PhysicalExpr > > =
3116+ in_list_scalars. iter ( ) . map ( |s| lit ( s. clone ( ) ) ) . collect ( ) ;
3117+
3118+ let opt_result_arr = if input_is_scalar {
3119+ let scalar = ScalarValue :: try_from_array ( & input_array, 0 ) ?;
3120+ let expr =
3121+ InListExpr :: try_new ( lit ( scalar) , list_exprs. clone ( ) , negated, & schema) ?;
3122+ expr. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ?
3123+ } else {
3124+ let expr = InListExpr :: try_new (
3125+ Arc :: clone ( & col_a) ,
3126+ list_exprs. clone ( ) ,
3127+ negated,
3128+ & schema,
3129+ ) ?;
3130+ expr. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ?
3131+ } ;
3132+
3133+ let opt_result_bool = opt_result_arr
3134+ . as_any ( )
3135+ . downcast_ref :: < BooleanArray > ( )
3136+ . unwrap ( ) ;
3137+ let ref_result = compute_reference_result (
3138+ & batch,
3139+ & col_a,
3140+ & in_list_scalars,
3141+ negated,
3142+ input_is_scalar,
3143+ ) ?;
3144+
3145+ assert_eq ! (
3146+ opt_result_bool, & ref_result,
3147+ "FAIL: Type={dt:?}, Size={list_size}, InNull={input_has_nulls}, ListNull={list_has_nulls}, Neg={negated}, Scalar={input_is_scalar}"
3148+ ) ;
3149+
3150+ Ok ( ( ) )
3151+ }
3152+
3153+ fn generate_deterministic_data (
3154+ dt : & DataType ,
3155+ len : usize ,
3156+ list_size : usize ,
3157+ input_has_nulls : bool ,
3158+ list_has_nulls : bool ,
3159+ ) -> ( ArrayRef , Vec < ScalarValue > ) {
3160+ let base_dt = if let DataType :: Dictionary ( _, v) = dt {
3161+ v. as_ref ( )
3162+ } else {
3163+ dt
3164+ } ;
3165+
3166+ let input_scalars: Vec < ScalarValue > = ( 0 ..len)
3167+ . map ( |i| {
3168+ if input_has_nulls && i % 10 == 0 {
3169+ ScalarValue :: try_new_null ( base_dt) . unwrap ( )
3170+ } else {
3171+ deterministic_scalar ( base_dt, i)
3172+ }
3173+ } )
3174+ . collect ( ) ;
3175+ let mut input_array = ScalarValue :: iter_to_array ( input_scalars) . unwrap ( ) ;
3176+
3177+ if let DataType :: Dictionary ( k, _) = dt {
3178+ match k. as_ref ( ) {
3179+ DataType :: Int32 => {
3180+ let keys = Int32Array :: from (
3181+ ( 0 ..len as i32 ) . map ( |i| i % 10 ) . collect :: < Vec < _ > > ( ) ,
3182+ ) ;
3183+ input_array =
3184+ Arc :: new ( DictionaryArray :: < Int32Type > :: new ( keys, input_array) )
3185+ as ArrayRef ;
3186+ }
3187+ DataType :: Int8 => {
3188+ let keys = Int8Array :: from (
3189+ ( 0 ..len as i8 ) . map ( |i| i % 10 ) . collect :: < Vec < _ > > ( ) ,
3190+ ) ;
3191+ input_array =
3192+ Arc :: new ( DictionaryArray :: < Int8Type > :: new ( keys, input_array) )
3193+ as ArrayRef ;
3194+ }
3195+ _ => unreachable ! ( ) ,
3196+ } ;
3197+ }
3198+
3199+ let mut in_list: Vec < ScalarValue > = ( 0 ..list_size)
3200+ . map ( |i| deterministic_scalar ( base_dt, i * 2 ) )
3201+ . collect ( ) ;
3202+
3203+ if list_has_nulls && !in_list. is_empty ( ) {
3204+ in_list[ 0 ] = ScalarValue :: try_new_null ( base_dt) . unwrap ( ) ;
3205+ }
3206+
3207+ ( input_array, in_list)
3208+ }
3209+
3210+ fn deterministic_scalar ( dt : & DataType , i : usize ) -> ScalarValue {
3211+ match dt {
3212+ DataType :: Int8 => ScalarValue :: Int8 ( Some ( i as i8 ) ) ,
3213+ DataType :: Int16 => ScalarValue :: Int16 ( Some ( i as i16 ) ) ,
3214+ DataType :: Int32 => ScalarValue :: Int32 ( Some ( i as i32 ) ) ,
3215+ DataType :: Int64 => ScalarValue :: Int64 ( Some ( i as i64 ) ) ,
3216+ DataType :: Float32 => ScalarValue :: Float32 ( Some ( i as f32 ) ) ,
3217+ DataType :: Float64 => ScalarValue :: Float64 ( Some ( i as f64 ) ) ,
3218+ DataType :: Utf8 | DataType :: Utf8View | DataType :: LargeUtf8 => {
3219+ let s = format ! ( "val-{i}" ) ;
3220+ if dt == & DataType :: Utf8View {
3221+ ScalarValue :: Utf8View ( Some ( s) )
3222+ } else if dt == & DataType :: LargeUtf8 {
3223+ ScalarValue :: LargeUtf8 ( Some ( s) )
3224+ } else {
3225+ ScalarValue :: Utf8 ( Some ( s) )
3226+ }
3227+ }
3228+ DataType :: Binary | DataType :: LargeBinary | DataType :: BinaryView => {
3229+ let b = format ! ( "bin-{i}" ) . into_bytes ( ) ;
3230+ if dt == & DataType :: BinaryView {
3231+ ScalarValue :: BinaryView ( Some ( b) )
3232+ } else if dt == & DataType :: LargeBinary {
3233+ ScalarValue :: LargeBinary ( Some ( b) )
3234+ } else {
3235+ ScalarValue :: Binary ( Some ( b) )
3236+ }
3237+ }
3238+ DataType :: Date32 => ScalarValue :: Date32 ( Some ( i as i32 ) ) ,
3239+ DataType :: Date64 => ScalarValue :: Date64 ( Some ( i as i64 ) ) ,
3240+ DataType :: Timestamp ( unit, tz) => {
3241+ let val = ( i * 1000 ) as i64 ;
3242+ match unit {
3243+ TimeUnit :: Nanosecond => {
3244+ ScalarValue :: TimestampNanosecond ( Some ( val) , tz. clone ( ) )
3245+ }
3246+ TimeUnit :: Microsecond => {
3247+ ScalarValue :: TimestampMicrosecond ( Some ( val) , tz. clone ( ) )
3248+ }
3249+ TimeUnit :: Millisecond => {
3250+ ScalarValue :: TimestampMillisecond ( Some ( val) , tz. clone ( ) )
3251+ }
3252+ TimeUnit :: Second => {
3253+ ScalarValue :: TimestampSecond ( Some ( val) , tz. clone ( ) )
3254+ }
3255+ }
3256+ }
3257+ DataType :: Decimal128 ( p, s) => {
3258+ ScalarValue :: Decimal128 ( Some ( i as i128 ) , * p, * s)
3259+ }
3260+ DataType :: Decimal256 ( p, s) => {
3261+ ScalarValue :: Decimal256 ( Some ( i256:: from_i128 ( i as i128 ) ) , * p, * s)
3262+ }
3263+ DataType :: Struct ( fields) => {
3264+ let values: Vec < ScalarValue > = fields
3265+ . iter ( )
3266+ . map ( |f| deterministic_scalar ( f. data_type ( ) , i) )
3267+ . collect ( ) ;
3268+ ScalarValue :: Struct ( Arc :: new (
3269+ StructArray :: try_new (
3270+ fields. clone ( ) ,
3271+ values. iter ( ) . map ( |s| s. to_array ( ) . unwrap ( ) ) . collect ( ) ,
3272+ None ,
3273+ )
3274+ . unwrap ( ) ,
3275+ ) )
3276+ }
3277+ _ => ScalarValue :: Int32 ( Some ( i as i32 ) ) ,
3278+ }
3279+ }
3280+
3281+ fn compute_reference_result (
3282+ batch : & RecordBatch ,
3283+ col_a : & Arc < dyn PhysicalExpr > ,
3284+ list : & [ ScalarValue ] ,
3285+ negated : bool ,
3286+ input_is_scalar : bool ,
3287+ ) -> Result < BooleanArray > {
3288+ let input_array = if input_is_scalar {
3289+ let val = col_a. evaluate ( batch) ?;
3290+ let arr = val. into_array ( batch. num_rows ( ) ) ?;
3291+ let s = ScalarValue :: try_from_array ( & arr, 0 ) ?;
3292+ s. to_array_of_size ( batch. num_rows ( ) ) ?
3293+ } else {
3294+ col_a. evaluate ( batch) ?. into_array ( batch. num_rows ( ) ) ?
3295+ } ;
3296+
3297+ let num_rows = batch. num_rows ( ) ;
3298+ let mut result = Vec :: with_capacity ( num_rows) ;
3299+ for i in 0 ..num_rows {
3300+ let mut val = ScalarValue :: try_from_array ( & input_array, i) ?;
3301+ if let ScalarValue :: Dictionary ( _, v) = val {
3302+ val = * v;
3303+ }
3304+
3305+ if val. is_null ( ) {
3306+ result. push ( None ) ;
3307+ continue ;
3308+ }
3309+
3310+ let mut found = false ;
3311+ let mut has_null = false ;
3312+ for list_val in list {
3313+ if list_val. is_null ( ) {
3314+ has_null = true ;
3315+ continue ;
3316+ }
3317+ if list_val == & val {
3318+ found = true ;
3319+ break ;
3320+ }
3321+ }
3322+
3323+ let res = if found {
3324+ Some ( !negated)
3325+ } else if has_null {
3326+ None
3327+ } else {
3328+ Some ( negated)
3329+ } ;
3330+ result. push ( res) ;
3331+ }
3332+ Ok ( BooleanArray :: from ( result) )
3333+ }
30363334}
0 commit comments