@@ -294,27 +294,21 @@ pub fn encode_value<T: Encoder>(
294294 type_ : & Type ,
295295 format : FieldFormat ,
296296) -> PgWireResult < ( ) > {
297- println ! ( "[DATAFUSION POSTGRES ENCODER] encode_value: idx={}, type={:?}" , idx, arr. data_type( ) ) ;
298297 match arr. data_type ( ) {
299298 DataType :: Null => encoder. encode_field_with_type_and_format ( & None :: < i8 > , type_, format) ?,
300299 DataType :: Boolean => {
301- println ! ( "BOOLEAN BRANCH" ) ;
302300 encoder. encode_field_with_type_and_format ( & get_bool_value ( arr, idx) , type_, format) ?
303301 }
304302 DataType :: Int8 => {
305- println ! ( "BOOLEAN BRANCH" ) ;
306303 encoder. encode_field_with_type_and_format ( & get_i8_value ( arr, idx) , type_, format) ?
307304 }
308305 DataType :: Int16 => {
309- println ! ( "BOOLEAN BRANCH" ) ;
310306 encoder. encode_field_with_type_and_format ( & get_i16_value ( arr, idx) , type_, format) ?
311307 }
312308 DataType :: Int32 => {
313- println ! ( "BOOLEAN BRANCH" ) ;
314309 encoder. encode_field_with_type_and_format ( & get_i32_value ( arr, idx) , type_, format) ?
315310 }
316311 DataType :: Int64 => {
317- println ! ( "BOOLEAN BRANCH" ) ;
318312 encoder. encode_field_with_type_and_format ( & get_i64_value ( arr, idx) , type_, format) ?
319313 }
320314 DataType :: UInt8 => encoder. encode_field_with_type_and_format (
@@ -496,7 +490,6 @@ pub fn encode_value<T: Encoder>(
496490 }
497491 } ,
498492 DataType :: List ( _) | DataType :: FixedSizeList ( _, _) | DataType :: LargeList ( _) => {
499- println ! ( "LIST BRANCH" ) ;
500493 if arr. is_null ( idx) {
501494 return encoder. encode_field_with_type_and_format ( & None :: < & [ i8 ] > , type_, format) ;
502495 }
@@ -505,7 +498,6 @@ pub fn encode_value<T: Encoder>(
505498 encoder. encode_field_with_type_and_format ( & value, type_, format) ?
506499 }
507500 DataType :: Struct ( _) => {
508- println ! ( "STRUCT BRANCH" ) ;
509501 let fields = match type_. kind ( ) {
510502 postgres_types:: Kind :: Composite ( fields) => fields,
511503 _ => {
@@ -518,66 +510,34 @@ pub fn encode_value<T: Encoder>(
518510 encoder. encode_field_with_type_and_format ( & value, type_, format) ?
519511 }
520512 DataType :: Dictionary ( _, value_type) => {
521- println ! ( "DICT BRANCH" ) ;
522513 if arr. is_null ( idx) {
523514 return encoder. encode_field_with_type_and_format ( & None :: < i8 > , type_, format) ;
524515 }
525-
526- macro_rules! get_value_index {
527- ( $key_ty: ty) => {
528- arr. as_any( )
529- . downcast_ref:: <DictionaryArray <$key_ty>>( )
530- . map( |dict| dict. keys( ) . value( idx) as usize )
531- } ;
532- }
533-
534- let value_idx = get_value_index ! ( Int8Type )
535- . or_else ( || get_value_index ! ( Int16Type ) )
536- . or_else ( || get_value_index ! ( Int32Type ) )
537- . or_else ( || get_value_index ! ( Int64Type ) )
538- . or_else ( || get_value_index ! ( UInt8Type ) )
539- . or_else ( || get_value_index ! ( UInt16Type ) )
540- . or_else ( || get_value_index ! ( UInt32Type ) )
541- . or_else ( || get_value_index ! ( UInt64Type ) )
542- . ok_or_else ( || {
543- ToSqlError :: from ( format ! (
544- "Unsupported dictionary key type"
545- ) )
546- } ) ?;
547-
548-
549- // Get the dictionary values, ignoring keys
550- // We'll use Int32Type as a common key type, but we're only interested in values
551- macro_rules! get_dict_values {
516+ // Get the dictionary values and the mapped row index
517+ macro_rules! get_dict_values_and_index {
552518 ( $key_type: ty) => {
553519 arr. as_any( )
554520 . downcast_ref:: <DictionaryArray <$key_type>>( )
555- . map( |dict| dict. values( ) . clone ( ) )
521+ . map( |dict| ( dict. values( ) , dict . keys ( ) . value ( idx ) as usize ) )
556522 } ;
557523 }
558524
559525 // Try to extract values using different key types
560- let values = get_dict_values ! ( Int8Type )
561- . or_else ( || get_dict_values ! ( Int16Type ) )
562- . or_else ( || get_dict_values ! ( Int32Type ) )
563- . or_else ( || get_dict_values ! ( Int64Type ) )
564- . or_else ( || get_dict_values ! ( UInt8Type ) )
565- . or_else ( || get_dict_values ! ( UInt16Type ) )
566- . or_else ( || get_dict_values ! ( UInt32Type ) )
567- . or_else ( || get_dict_values ! ( UInt64Type ) )
526+ let ( values, idx ) = get_dict_values_and_index ! ( Int8Type )
527+ . or_else ( || get_dict_values_and_index ! ( Int16Type ) )
528+ . or_else ( || get_dict_values_and_index ! ( Int32Type ) )
529+ . or_else ( || get_dict_values_and_index ! ( Int64Type ) )
530+ . or_else ( || get_dict_values_and_index ! ( UInt8Type ) )
531+ . or_else ( || get_dict_values_and_index ! ( UInt16Type ) )
532+ . or_else ( || get_dict_values_and_index ! ( UInt32Type ) )
533+ . or_else ( || get_dict_values_and_index ! ( UInt64Type ) )
568534 . ok_or_else ( || {
569535 ToSqlError :: from ( format ! (
570- "Unsupported dictionary key type"
536+ "Unsupported dictionary key type for value type {value_type} "
571537 ) )
572538 } ) ?;
573539
574- // If the dictionary has only one value, treat it as a primitive
575- if values. len ( ) == 1 {
576- encode_value ( encoder, & values, 0 , type_, format) ?
577- } else {
578- // Otherwise, use value directly indexed by values array
579- encode_value ( encoder, & values, idx, type_, format) ?
580- }
540+ encode_value ( encoder, values, idx, type_, format) ?
581541 }
582542 _ => {
583543 return Err ( PgWireError :: ApiError ( ToSqlError :: from ( format ! (
@@ -590,3 +550,56 @@ pub fn encode_value<T: Encoder>(
590550
591551 Ok ( ( ) )
592552}
553+
554+ #[ cfg( test) ]
555+ mod tests{
556+ use super :: * ;
557+
558+ #[ test]
559+ fn encodes_dictionary_array ( ) {
560+
561+ #[ derive( Default ) ]
562+ struct MockEncoder {
563+ encoded_value : String ,
564+ }
565+
566+ impl Encoder for MockEncoder {
567+ fn encode_field_with_type_and_format < T > (
568+ & mut self ,
569+ value : & T ,
570+ data_type : & Type ,
571+ _format : FieldFormat ,
572+ ) -> PgWireResult < ( ) >
573+ where
574+ T : ToSql + ToSqlText + Sized {
575+ let mut bytes = BytesMut :: new ( ) ;
576+ let _sql_text = value. to_sql_text ( data_type, & mut bytes) ;
577+ let string = String :: from_utf8 ( ( & bytes) . to_vec ( ) ) ;
578+ self . encoded_value = string. unwrap ( ) ;
579+ Ok ( ( ) )
580+ }
581+ }
582+
583+ let val = "~!@&$[]()@@!!" ;
584+ let value = StringArray :: from_iter_values ( [ val] ) ;
585+ let keys = Int8Array :: from_iter_values ( [ 0 , 0 , 0 , 0 ] ) ;
586+ let dict_arr: Arc < dyn Array > = Arc :: new ( DictionaryArray :: < Int8Type > :: try_new ( keys, Arc :: new ( value) ) . unwrap ( ) ) ;
587+
588+
589+ let mut encoder = MockEncoder :: default ( ) ;
590+
591+ let result = encode_value (
592+ & mut encoder,
593+ & dict_arr,
594+ 2 ,
595+ & Type :: TEXT ,
596+ FieldFormat :: Text ,
597+ ) ;
598+
599+ assert ! ( result. is_ok( ) ) ;
600+
601+ assert ! ( encoder. encoded_value == val) ;
602+
603+ }
604+ }
605+
0 commit comments