@@ -15,6 +15,8 @@ use pgwire::api::portal::{Format, Portal};
1515use pgwire:: api:: results:: { DataRowEncoder , FieldInfo , QueryResponse } ;
1616use pgwire:: api:: Type ;
1717use pgwire:: error:: { ErrorInfo , PgWireError , PgWireResult } ;
18+ use rust_decimal:: prelude:: ToPrimitive ;
19+ use rust_decimal:: Decimal ;
1820use timezone:: Tz ;
1921
2022pub ( crate ) fn into_pg_type ( df_type : & DataType ) -> PgWireResult < Type > {
@@ -38,6 +40,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
3840 DataType :: Binary | DataType :: FixedSizeBinary ( _) | DataType :: LargeBinary => Type :: BYTEA ,
3941 DataType :: Float16 | DataType :: Float32 => Type :: FLOAT4 ,
4042 DataType :: Float64 => Type :: FLOAT8 ,
43+ DataType :: Decimal128 ( _, _) => Type :: NUMERIC ,
4144 DataType :: Utf8 => Type :: VARCHAR ,
4245 DataType :: LargeUtf8 => Type :: TEXT ,
4346 DataType :: List ( field) | DataType :: FixedSizeList ( field, _) | DataType :: LargeList ( field) => {
@@ -83,6 +86,11 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
8386 } )
8487}
8588
89+ fn get_numeric_128_value ( arr : & Arc < dyn Array > , idx : usize , scale : u32 ) -> Decimal {
90+ let array = arr. as_any ( ) . downcast_ref :: < Decimal128Array > ( ) . unwrap ( ) ;
91+ Decimal :: from_i128_with_scale ( array. value ( idx) , scale)
92+ }
93+
8694fn get_bool_value ( arr : & Arc < dyn Array > , idx : usize ) -> bool {
8795 arr. as_any ( )
8896 . downcast_ref :: < BooleanArray > ( )
@@ -258,6 +266,9 @@ fn encode_value(
258266 DataType :: UInt64 => encoder. encode_field ( & ( get_u64_value ( arr, idx) as i64 ) ) ?,
259267 DataType :: Float32 => encoder. encode_field ( & get_f32_value ( arr, idx) ) ?,
260268 DataType :: Float64 => encoder. encode_field ( & get_f64_value ( arr, idx) ) ?,
269+ DataType :: Decimal128 ( _, s) => {
270+ encoder. encode_field ( & get_numeric_128_value ( arr, idx, * s as u32 ) ) ?
271+ }
261272 DataType :: Utf8 => encoder. encode_field ( & get_utf8_value ( arr, idx) ) ?,
262273 DataType :: Utf8View => encoder. encode_field ( & get_utf8_view_value ( arr, idx) ) ?,
263274 DataType :: LargeUtf8 => encoder. encode_field ( & get_large_utf8_value ( arr, idx) ) ?,
@@ -361,6 +372,17 @@ fn encode_value(
361372 DataType :: UInt64 => encoder. encode_field ( & get_u64_list_value ( arr, idx) ) ?,
362373 DataType :: Float32 => encoder. encode_field ( & get_f32_list_value ( arr, idx) ) ?,
363374 DataType :: Float64 => encoder. encode_field ( & get_f64_list_value ( arr, idx) ) ?,
375+ DataType :: Decimal128 ( _, s) => {
376+ let list_arr = arr. as_any ( ) . downcast_ref :: < ListArray > ( ) . unwrap ( ) . value ( idx) ;
377+ let value: Vec < _ > = list_arr
378+ . as_any ( )
379+ . downcast_ref :: < Decimal128Array > ( )
380+ . unwrap ( )
381+ . iter ( )
382+ . map ( |v| Decimal :: from_i128_with_scale ( v. unwrap ( ) , * s as u32 ) )
383+ . collect ( ) ;
384+ encoder. encode_field ( & value) ?
385+ }
364386 DataType :: Utf8 => {
365387 let list_arr = arr. as_any ( ) . downcast_ref :: < ListArray > ( ) . unwrap ( ) . value ( idx) ;
366388 let value: Vec < _ > = list_arr
@@ -808,6 +830,23 @@ where
808830 let value = portal. parameter :: < f64 > ( i, & pg_type) ?;
809831 deserialized_params. push ( ScalarValue :: Float64 ( value) ) ;
810832 }
833+ Type :: NUMERIC => {
834+ let value = match portal. parameter :: < Decimal > ( i, & pg_type) ? {
835+ None => ScalarValue :: Decimal128 ( None , 0 , 0 ) ,
836+ Some ( value) => {
837+ let mantissa = value. mantissa ( ) ;
838+ // Count digits in the mantissa
839+ let precision = if mantissa == 0 {
840+ 1
841+ } else {
842+ ( mantissa. abs ( ) as f64 ) . log10 ( ) . floor ( ) as u8 + 1
843+ } ;
844+ let scale = value. scale ( ) as i8 ;
845+ ScalarValue :: Decimal128 ( value. to_i128 ( ) , precision, scale)
846+ }
847+ } ;
848+ deserialized_params. push ( value) ;
849+ }
811850 Type :: TIMESTAMP => {
812851 let value = portal. parameter :: < NaiveDateTime > ( i, & pg_type) ?;
813852 deserialized_params. push ( ScalarValue :: TimestampMicrosecond (
0 commit comments