@@ -513,38 +513,31 @@ pub fn encode_value<T: Encoder>(
513513 if arr. is_null ( idx) {
514514 return encoder. encode_field_with_type_and_format ( & None :: < i8 > , type_, format) ;
515515 }
516- // Get the dictionary values, ignoring keys
517- // We'll use Int32Type as a common key type, but we're only interested in values
518- macro_rules! get_dict_values {
516+ // Get the dictionary values and the mapped row index
517+ macro_rules! get_dict_values_and_index {
519518 ( $key_type: ty) => {
520519 arr. as_any( )
521520 . downcast_ref:: <DictionaryArray <$key_type>>( )
522- . map( |dict| dict. values( ) )
521+ . map( |dict| ( dict. values( ) , dict . keys ( ) . value ( idx ) as usize ) )
523522 } ;
524523 }
525524
526525 // Try to extract values using different key types
527- let values = get_dict_values ! ( Int8Type )
528- . or_else ( || get_dict_values ! ( Int16Type ) )
529- . or_else ( || get_dict_values ! ( Int32Type ) )
530- . or_else ( || get_dict_values ! ( Int64Type ) )
531- . or_else ( || get_dict_values ! ( UInt8Type ) )
532- . or_else ( || get_dict_values ! ( UInt16Type ) )
533- . or_else ( || get_dict_values ! ( UInt32Type ) )
534- . or_else ( || get_dict_values ! ( UInt64Type ) )
526+ let ( values, idx ) = get_dict_values_and_index ! ( Int8Type )
527+ . or_else ( || get_dict_values_and_index ! ( Int16Type ) )
528+ . or_else ( || get_dict_values_and_index ! ( Int32Type ) )
529+ . or_else ( || get_dict_values_and_index ! ( Int64Type ) )
530+ . or_else ( || get_dict_values_and_index ! ( UInt8Type ) )
531+ . or_else ( || get_dict_values_and_index ! ( UInt16Type ) )
532+ . or_else ( || get_dict_values_and_index ! ( UInt32Type ) )
533+ . or_else ( || get_dict_values_and_index ! ( UInt64Type ) )
535534 . ok_or_else ( || {
536535 ToSqlError :: from ( format ! (
537536 "Unsupported dictionary key type for value type {value_type}"
538537 ) )
539538 } ) ?;
540539
541- // If the dictionary has only one value, treat it as a primitive
542- if values. len ( ) == 1 {
543- encode_value ( encoder, values, 0 , type_, format) ?
544- } else {
545- // Otherwise, use value directly indexed by values array
546- encode_value ( encoder, values, idx, type_, format) ?
547- }
540+ encode_value ( encoder, values, idx, type_, format) ?
548541 }
549542 _ => {
550543 return Err ( PgWireError :: ApiError ( ToSqlError :: from ( format ! (
@@ -557,3 +550,48 @@ pub fn encode_value<T: Encoder>(
557550
558551 Ok ( ( ) )
559552}
553+
554+ #[ cfg( test) ]
555+ mod tests {
556+ use super :: * ;
557+
558+ #[ test]
559+ fn encodes_dictionary_array ( ) {
560+ #[ derive( Default ) ]
561+ struct MockEncoder {
562+ encoded_value : String ,
563+ }
564+
565+ impl Encoder for MockEncoder {
566+ fn encode_field_with_type_and_format < T > (
567+ & mut self ,
568+ value : & T ,
569+ data_type : & Type ,
570+ _format : FieldFormat ,
571+ ) -> PgWireResult < ( ) >
572+ where
573+ T : ToSql + ToSqlText + Sized ,
574+ {
575+ let mut bytes = BytesMut :: new ( ) ;
576+ let _sql_text = value. to_sql_text ( data_type, & mut bytes) ;
577+ let string = String :: from_utf8 ( ( & bytes) . to_vec ( ) ) ;
578+ self . encoded_value = string. unwrap ( ) ;
579+ Ok ( ( ) )
580+ }
581+ }
582+
583+ let val = "~!@&$[]()@@!!" ;
584+ let value = StringArray :: from_iter_values ( [ val] ) ;
585+ let keys = Int8Array :: from_iter_values ( [ 0 , 0 , 0 , 0 ] ) ;
586+ let dict_arr: Arc < dyn Array > =
587+ Arc :: new ( DictionaryArray :: < Int8Type > :: try_new ( keys, Arc :: new ( value) ) . unwrap ( ) ) ;
588+
589+ let mut encoder = MockEncoder :: default ( ) ;
590+
591+ let result = encode_value ( & mut encoder, & dict_arr, 2 , & Type :: TEXT , FieldFormat :: Text ) ;
592+
593+ assert ! ( result. is_ok( ) ) ;
594+
595+ assert ! ( encoder. encoded_value == val) ;
596+ }
597+ }
0 commit comments