@@ -99,11 +99,18 @@ impl StaticFilter for ArrayStaticFilter {
9999 ) ) ;
100100 }
101101
102+ // Unwrap dictionary-encoded needles when the value type matches
103+ // in_array, evaluating against the dictionary values and mapping
104+ // back via keys.
102105 downcast_dictionary_array ! {
103106 v => {
104- let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
105- let result = take( & values_contains, v. keys( ) , None ) ?;
106- return Ok ( downcast_array( result. as_ref( ) ) )
107+ // Only unwrap when the haystack (in_array) type matches
108+ // the dictionary value type
109+ if v. values( ) . data_type( ) == self . in_array. data_type( ) {
110+ let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
111+ let result = take( & values_contains, v. keys( ) , None ) ?;
112+ return Ok ( downcast_array( result. as_ref( ) ) ) ;
113+ }
107114 }
108115 _ => { }
109116 }
@@ -3878,4 +3885,204 @@ mod tests {
38783885 ) ;
38793886 Ok ( ( ) )
38803887 }
3888+
3889+ // -----------------------------------------------------------------------
3890+ // Tests for try_new_from_array: evaluates `needle IN in_array`.
3891+ //
3892+ // This exercises the code path used by HashJoin dynamic filter pushdown,
3893+ // where in_array is built directly from the join's build-side arrays.
3894+ // Unlike try_new (used by SQL IN expressions), which always produces a
3895+ // non-Dictionary in_array because evaluate_list() flattens Dictionary
3896+ // scalars, try_new_from_array passes the array directly and can produce
3897+ // a Dictionary in_array.
3898+ // -----------------------------------------------------------------------
3899+
3900+ fn wrap_in_dict ( array : ArrayRef ) -> ArrayRef {
3901+ let keys = Int32Array :: from ( ( 0 ..array. len ( ) as i32 ) . collect :: < Vec < _ > > ( ) ) ;
3902+ Arc :: new ( DictionaryArray :: new ( keys, array) )
3903+ }
3904+
3905+ /// Evaluates `needle IN in_array` via try_new_from_array, the same
3906+ /// path used by HashJoin dynamic filter pushdown (not the SQL literal
3907+ /// IN path which goes through try_new).
3908+ fn eval_in_list_from_array (
3909+ needle : ArrayRef ,
3910+ in_array : ArrayRef ,
3911+ ) -> Result < BooleanArray > {
3912+ let schema =
3913+ Schema :: new ( vec ! [ Field :: new( "a" , needle. data_type( ) . clone( ) , false ) ] ) ;
3914+ let col_a = col ( "a" , & schema) ?;
3915+ let expr = Arc :: new ( InListExpr :: try_new_from_array ( col_a, in_array, false ) ?)
3916+ as Arc < dyn PhysicalExpr > ;
3917+ let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ needle] ) ?;
3918+ let result = expr. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ?;
3919+ Ok ( as_boolean_array ( & result) . clone ( ) )
3920+ }
3921+
/// Checks `needle IN in_array` via `try_new_from_array` across primitive,
/// Utf8, dictionary and struct combinations of needle / in_array types.
#[test]
fn test_in_list_from_array_type_combinations() -> Result<()> {
    use arrow::compute::cast;

    // In every scenario the first and third needle rows are present in
    // in_array and the second one is not.
    let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);

    // Int64 seed arrays; cast to each primitive type below and reused as the
    // Int64 column of the struct cases.
    let base_in: ArrayRef = Arc::new(Int64Array::from(vec![1i64, 2, 3]));
    let base_needle: ArrayRef = Arc::new(Int64Array::from(vec![1i64, 4, 2]));

    // Test all specializations in instantiate_static_filter
    let primitive_types = [
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Float32,
        DataType::Float64,
    ];

    for dt in &primitive_types {
        let in_array = cast(&base_in, dt)?;
        let needle = cast(&base_needle, dt)?;

        // T in_array, T needle
        let same_type =
            eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array))?;
        assert_eq!(expected, same_type, "same-type failed for {dt:?}");

        // T in_array, Dict(Int32, T) needle
        let dict_needle = eval_in_list_from_array(wrap_in_dict(needle), in_array)?;
        assert_eq!(expected, dict_needle, "dict-needle failed for {dt:?}");
    }

    // Utf8 (falls through to ArrayStaticFilter)
    let utf8_in: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let utf8_needle: ArrayRef = Arc::new(StringArray::from(vec!["a", "d", "b"]));

    // Utf8 in_array, Utf8 needle
    let plain_utf8 =
        eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in))?;
    assert_eq!(expected, plain_utf8);

    // Utf8 in_array, Dict(Utf8) needle
    let dict_utf8_needle = eval_in_list_from_array(
        wrap_in_dict(Arc::clone(&utf8_needle)),
        Arc::clone(&utf8_in),
    )?;
    assert_eq!(expected, dict_utf8_needle);

    // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
    let dict_both_sides = eval_in_list_from_array(
        wrap_in_dict(Arc::clone(&utf8_needle)),
        wrap_in_dict(Arc::clone(&utf8_in)),
    )?;
    assert_eq!(expected, dict_both_sides);

    // Builds a two-column struct array by pairing `fields` with (c0, c1).
    let build_struct = |fields: &Fields, c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
        let pairs: Vec<(FieldRef, ArrayRef)> =
            fields.iter().cloned().zip([c0, c1]).collect();
        Arc::new(StructArray::from(pairs))
    };

    // Struct in_array, Struct needle: multi-column join
    let struct_fields = Fields::from(vec![
        Field::new("c0", DataType::Utf8, true),
        Field::new("c1", DataType::Int64, true),
    ]);
    let struct_needle = build_struct(
        &struct_fields,
        Arc::clone(&utf8_needle),
        Arc::clone(&base_needle),
    );
    let struct_in = build_struct(
        &struct_fields,
        Arc::clone(&utf8_in),
        Arc::clone(&base_in),
    );
    assert_eq!(expected, eval_in_list_from_array(struct_needle, struct_in)?);

    // Struct with Dict fields: multi-column Dict join
    let dict_utf8_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let dict_struct_fields = Fields::from(vec![
        Field::new("c0", dict_utf8_type, true),
        Field::new("c1", DataType::Int64, true),
    ]);
    let dict_struct_needle = build_struct(
        &dict_struct_fields,
        wrap_in_dict(Arc::clone(&utf8_needle)),
        Arc::clone(&base_needle),
    );
    let dict_struct_in = build_struct(
        &dict_struct_fields,
        wrap_in_dict(Arc::clone(&utf8_in)),
        Arc::clone(&base_in),
    );
    assert_eq!(
        expected,
        eval_in_list_from_array(dict_struct_needle, dict_struct_in)?
    );

    Ok(())
}
4048+
/// Checks that incompatible needle / in_array type pairings surface an
/// error from `eval_in_list_from_array` instead of producing a result.
#[test]
fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
    // Runs the evaluation, requires it to fail, and returns the rendered
    // error message.
    let failure_message = |needle: ArrayRef, in_array: ArrayRef| -> String {
        eval_in_list_from_array(needle, in_array)
            .unwrap_err()
            .to_string()
    };

    let utf8_needle: ArrayRef = Arc::new(StringArray::from(vec!["a", "d", "b"]));
    let utf8_in: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));

    // Utf8 needle, Dict(Utf8) in_array
    let err = failure_message(
        Arc::clone(&utf8_needle),
        wrap_in_dict(Arc::clone(&utf8_in)),
    );
    assert!(
        err.contains("Can't compare arrays of different types"),
        "{err}"
    );

    // Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter
    // rejects the Utf8 dictionary values at construction time
    let int64_in: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let err = failure_message(wrap_in_dict(Arc::clone(&utf8_needle)), int64_in);
    assert!(err.contains("Failed to downcast"), "{err}");

    // Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different
    // value types, make_comparator rejects the comparison
    let int64_needle: ArrayRef = Arc::new(Int64Array::from(vec![1, 4, 2]));
    let err = failure_message(
        wrap_in_dict(int64_needle),
        wrap_in_dict(Arc::clone(&utf8_in)),
    );
    assert!(
        err.contains("Can't compare arrays of different types"),
        "{err}"
    );

    Ok(())
}
38814088}
0 commit comments