@@ -15,6 +15,8 @@ use pgwire::api::portal::{Format, Portal};
1515use pgwire:: api:: results:: { DataRowEncoder , FieldInfo , QueryResponse } ;
1616use pgwire:: api:: Type ;
1717use pgwire:: error:: { ErrorInfo , PgWireError , PgWireResult } ;
18+ use rust_decimal:: prelude:: ToPrimitive ;
19+ use rust_decimal:: { Decimal , Error } ;
1820use timezone:: Tz ;
1921
2022pub ( crate ) fn into_pg_type ( df_type : & DataType ) -> PgWireResult < Type > {
@@ -38,6 +40,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
3840 DataType :: Binary | DataType :: FixedSizeBinary ( _) | DataType :: LargeBinary => Type :: BYTEA ,
3941 DataType :: Float16 | DataType :: Float32 => Type :: FLOAT4 ,
4042 DataType :: Float64 => Type :: FLOAT8 ,
43+ DataType :: Decimal128 ( _, _) => Type :: NUMERIC ,
4144 DataType :: Utf8 => Type :: VARCHAR ,
4245 DataType :: LargeUtf8 => Type :: TEXT ,
4346 DataType :: List ( field) | DataType :: FixedSizeList ( field, _) | DataType :: LargeList ( field) => {
@@ -72,6 +75,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
7275 }
7376 }
7477 DataType :: Utf8View => Type :: TEXT ,
78+ DataType :: Dictionary ( _, value_type) => into_pg_type ( value_type) ?,
7579 _ => {
7680 return Err ( PgWireError :: UserError ( Box :: new ( ErrorInfo :: new (
7781 "ERROR" . to_owned ( ) ,
@@ -82,6 +86,24 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
8286 } )
8387}
8488
89+ fn get_numeric_128_value ( arr : & Arc < dyn Array > , idx : usize , scale : u32 ) -> PgWireResult < Decimal > {
90+ let array = arr. as_any ( ) . downcast_ref :: < Decimal128Array > ( ) . unwrap ( ) ;
91+ let value = array. value ( idx) ;
92+ Decimal :: try_from_i128_with_scale ( value, scale) . map_err ( |e| {
93+ let message = match e {
94+ Error :: ExceedsMaximumPossibleValue => "Exceeds maximum possible value" ,
95+ Error :: LessThanMinimumPossibleValue => "Less than minimum possible value" ,
96+ Error :: ScaleExceedsMaximumPrecision ( _) => "Scale exceeds maximum precision" ,
97+ _ => unreachable ! ( ) ,
98+ } ;
99+ PgWireError :: UserError ( Box :: new ( ErrorInfo :: new (
100+ "ERROR" . to_owned ( ) ,
101+ "XX000" . to_owned ( ) ,
102+ message. to_owned ( ) ,
103+ ) ) )
104+ } )
105+ }
106+
85107fn get_bool_value ( arr : & Arc < dyn Array > , idx : usize ) -> bool {
86108 arr. as_any ( )
87109 . downcast_ref :: < BooleanArray > ( )
@@ -257,6 +279,9 @@ fn encode_value(
257279 DataType :: UInt64 => encoder. encode_field ( & ( get_u64_value ( arr, idx) as i64 ) ) ?,
258280 DataType :: Float32 => encoder. encode_field ( & get_f32_value ( arr, idx) ) ?,
259281 DataType :: Float64 => encoder. encode_field ( & get_f64_value ( arr, idx) ) ?,
282+ DataType :: Decimal128 ( _, s) => {
283+ encoder. encode_field ( & get_numeric_128_value ( arr, idx, * s as u32 ) ?) ?
284+ }
260285 DataType :: Utf8 => encoder. encode_field ( & get_utf8_value ( arr, idx) ) ?,
261286 DataType :: Utf8View => encoder. encode_field ( & get_utf8_view_value ( arr, idx) ) ?,
262287 DataType :: LargeUtf8 => encoder. encode_field ( & get_large_utf8_value ( arr, idx) ) ?,
@@ -360,6 +385,17 @@ fn encode_value(
360385 DataType :: UInt64 => encoder. encode_field ( & get_u64_list_value ( arr, idx) ) ?,
361386 DataType :: Float32 => encoder. encode_field ( & get_f32_list_value ( arr, idx) ) ?,
362387 DataType :: Float64 => encoder. encode_field ( & get_f64_list_value ( arr, idx) ) ?,
388+ DataType :: Decimal128 ( _, s) => {
389+ let list_arr = arr. as_any ( ) . downcast_ref :: < ListArray > ( ) . unwrap ( ) . value ( idx) ;
390+ let value: Vec < _ > = list_arr
391+ . as_any ( )
392+ . downcast_ref :: < Decimal128Array > ( )
393+ . unwrap ( )
394+ . iter ( )
395+ . map ( |ov| ov. map ( |v| Decimal :: from_i128_with_scale ( v, * s as u32 ) ) )
396+ . collect ( ) ;
397+ encoder. encode_field ( & value) ?
398+ }
363399 DataType :: Utf8 => {
364400 let list_arr = arr. as_any ( ) . downcast_ref :: < ListArray > ( ) . unwrap ( ) . value ( idx) ;
365401 let value: Vec < _ > = list_arr
@@ -609,6 +645,45 @@ fn encode_value(
609645 }
610646 }
611647 }
648+ DataType :: Dictionary ( _, value_type) => {
649+ // Get the dictionary values, ignoring keys
650+ // We'll use Int32Type as a common key type, but we're only interested in values
651+ macro_rules! get_dict_values {
652+ ( $key_type: ty) => {
653+ arr. as_any( )
654+ . downcast_ref:: <DictionaryArray <$key_type>>( )
655+ . map( |dict| dict. values( ) )
656+ } ;
657+ }
658+
659+ // Try to extract values using different key types
660+ let values = get_dict_values ! ( Int8Type )
661+ . or_else ( || get_dict_values ! ( Int16Type ) )
662+ . or_else ( || get_dict_values ! ( Int32Type ) )
663+ . or_else ( || get_dict_values ! ( Int64Type ) )
664+ . or_else ( || get_dict_values ! ( UInt8Type ) )
665+ . or_else ( || get_dict_values ! ( UInt16Type ) )
666+ . or_else ( || get_dict_values ! ( UInt32Type ) )
667+ . or_else ( || get_dict_values ! ( UInt64Type ) )
668+ . ok_or_else ( || {
669+ PgWireError :: UserError ( Box :: new ( ErrorInfo :: new (
670+ "ERROR" . to_owned ( ) ,
671+ "XX000" . to_owned ( ) ,
672+ format ! (
673+ "Unsupported dictionary key type for value type {}" ,
674+ value_type
675+ ) ,
676+ ) ) )
677+ } ) ?;
678+
679+ // If the dictionary has only one value, treat it as a primitive
680+ if values. len ( ) == 1 {
681+ encode_value ( encoder, values, 0 ) ?
682+ } else {
683+ // Otherwise, use value directly indexed by values array
684+ encode_value ( encoder, values, idx) ?
685+ }
686+ }
612687 _ => {
613688 return Err ( PgWireError :: UserError ( Box :: new ( ErrorInfo :: new (
614689 "ERROR" . to_owned ( ) ,
@@ -671,9 +746,9 @@ pub(crate) async fn encode_dataframe<'a>(
671746 for col in 0 ..cols {
672747 let array = rb. column ( col) ;
673748 if array. is_null ( row) {
674- encoder. encode_field ( & None :: < i8 > ) . unwrap ( ) ;
749+ encoder. encode_field ( & None :: < i8 > ) ? ;
675750 } else {
676- encode_value ( & mut encoder, array, row) . unwrap ( ) ;
751+ encode_value ( & mut encoder, array, row) ?
677752 }
678753 }
679754 encoder. finish ( )
@@ -768,6 +843,20 @@ where
768843 let value = portal. parameter :: < f64 > ( i, & pg_type) ?;
769844 deserialized_params. push ( ScalarValue :: Float64 ( value) ) ;
770845 }
846+ Type :: NUMERIC => {
847+ let value = match portal. parameter :: < Decimal > ( i, & pg_type) ? {
848+ None => ScalarValue :: Decimal128 ( None , 0 , 0 ) ,
849+ Some ( value) => {
850+ let precision = match value. mantissa ( ) {
851+ 0 => 1 ,
852+ m => ( m. abs ( ) as f64 ) . log10 ( ) . floor ( ) as u8 + 1 ,
853+ } ;
854+ let scale = value. scale ( ) as i8 ;
855+ ScalarValue :: Decimal128 ( value. to_i128 ( ) , precision, scale)
856+ }
857+ } ;
858+ deserialized_params. push ( value) ;
859+ }
771860 Type :: TIMESTAMP => {
772861 let value = portal. parameter :: < NaiveDateTime > ( i, & pg_type) ?;
773862 deserialized_params. push ( ScalarValue :: TimestampMicrosecond (
0 commit comments