@@ -15,6 +15,8 @@ use pgwire::api::portal::{Format, Portal};
1515use pgwire:: api:: results:: { DataRowEncoder , FieldInfo , QueryResponse } ;
1616use pgwire:: api:: Type ;
1717use pgwire:: error:: { ErrorInfo , PgWireError , PgWireResult } ;
18+ use rust_decimal:: prelude:: ToPrimitive ;
19+ use rust_decimal:: { Decimal , Error } ;
1820use timezone:: Tz ;
1921
2022pub ( crate ) fn into_pg_type ( df_type : & DataType ) -> PgWireResult < Type > {
@@ -38,6 +40,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
3840 DataType :: Binary | DataType :: FixedSizeBinary ( _) | DataType :: LargeBinary => Type :: BYTEA ,
3941 DataType :: Float16 | DataType :: Float32 => Type :: FLOAT4 ,
4042 DataType :: Float64 => Type :: FLOAT8 ,
43+ DataType :: Decimal128 ( _, _) => Type :: NUMERIC ,
4144 DataType :: Utf8 => Type :: VARCHAR ,
4245 DataType :: LargeUtf8 => Type :: TEXT ,
4346 DataType :: List ( field) | DataType :: FixedSizeList ( field, _) | DataType :: LargeList ( field) => {
@@ -83,6 +86,24 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
8386 } )
8487}
8588
89+ fn get_numeric_128_value ( arr : & Arc < dyn Array > , idx : usize , scale : u32 ) -> PgWireResult < Decimal > {
90+ let array = arr. as_any ( ) . downcast_ref :: < Decimal128Array > ( ) . unwrap ( ) ;
91+ let value = array. value ( idx) ;
92+ Decimal :: try_from_i128_with_scale ( value, scale) . map_err ( |e| {
93+ let message = match e {
94+ Error :: ExceedsMaximumPossibleValue => "Exceeds maximum possible value" ,
95+ Error :: LessThanMinimumPossibleValue => "Less than minimum possible value" ,
96+ Error :: ScaleExceedsMaximumPrecision ( _) => "Scale exceeds maximum precision" ,
97+ _ => unreachable ! ( ) ,
98+ } ;
99+ PgWireError :: UserError ( Box :: new ( ErrorInfo :: new (
100+ "ERROR" . to_owned ( ) ,
101+ "XX000" . to_owned ( ) ,
102+ message. to_owned ( ) ,
103+ ) ) )
104+ } )
105+ }
106+
86107fn get_bool_value ( arr : & Arc < dyn Array > , idx : usize ) -> bool {
87108 arr. as_any ( )
88109 . downcast_ref :: < BooleanArray > ( )
@@ -258,6 +279,9 @@ fn encode_value(
258279 DataType :: UInt64 => encoder. encode_field ( & ( get_u64_value ( arr, idx) as i64 ) ) ?,
259280 DataType :: Float32 => encoder. encode_field ( & get_f32_value ( arr, idx) ) ?,
260281 DataType :: Float64 => encoder. encode_field ( & get_f64_value ( arr, idx) ) ?,
282+ DataType :: Decimal128 ( _, s) => {
283+ encoder. encode_field ( & get_numeric_128_value ( arr, idx, * s as u32 ) ?) ?
284+ }
261285 DataType :: Utf8 => encoder. encode_field ( & get_utf8_value ( arr, idx) ) ?,
262286 DataType :: Utf8View => encoder. encode_field ( & get_utf8_view_value ( arr, idx) ) ?,
263287 DataType :: LargeUtf8 => encoder. encode_field ( & get_large_utf8_value ( arr, idx) ) ?,
@@ -361,6 +385,17 @@ fn encode_value(
361385 DataType :: UInt64 => encoder. encode_field ( & get_u64_list_value ( arr, idx) ) ?,
362386 DataType :: Float32 => encoder. encode_field ( & get_f32_list_value ( arr, idx) ) ?,
363387 DataType :: Float64 => encoder. encode_field ( & get_f64_list_value ( arr, idx) ) ?,
388+ DataType :: Decimal128 ( _, s) => {
389+ let list_arr = arr. as_any ( ) . downcast_ref :: < ListArray > ( ) . unwrap ( ) . value ( idx) ;
390+ let value: Vec < _ > = list_arr
391+ . as_any ( )
392+ . downcast_ref :: < Decimal128Array > ( )
393+ . unwrap ( )
394+ . iter ( )
395+ . map ( |ov| ov. map ( |v| Decimal :: from_i128_with_scale ( v, * s as u32 ) ) )
396+ . collect ( ) ;
397+ encoder. encode_field ( & value) ?
398+ }
364399 DataType :: Utf8 => {
365400 let list_arr = arr. as_any ( ) . downcast_ref :: < ListArray > ( ) . unwrap ( ) . value ( idx) ;
366401 let value: Vec < _ > = list_arr
@@ -711,9 +746,9 @@ pub(crate) async fn encode_dataframe<'a>(
711746 for col in 0 ..cols {
712747 let array = rb. column ( col) ;
713748 if array. is_null ( row) {
714- encoder. encode_field ( & None :: < i8 > ) . unwrap ( ) ;
749+ encoder. encode_field ( & None :: < i8 > ) ? ;
715750 } else {
716- encode_value ( & mut encoder, array, row) . unwrap ( ) ;
751+ encode_value ( & mut encoder, array, row) ?
717752 }
718753 }
719754 encoder. finish ( )
@@ -808,6 +843,20 @@ where
808843 let value = portal. parameter :: < f64 > ( i, & pg_type) ?;
809844 deserialized_params. push ( ScalarValue :: Float64 ( value) ) ;
810845 }
846+ Type :: NUMERIC => {
847+ let value = match portal. parameter :: < Decimal > ( i, & pg_type) ? {
848+ None => ScalarValue :: Decimal128 ( None , 0 , 0 ) ,
849+ Some ( value) => {
850+ let precision = match value. mantissa ( ) {
851+ 0 => 1 ,
852+ m => ( m. abs ( ) as f64 ) . log10 ( ) . floor ( ) as u8 + 1 ,
853+ } ;
854+ let scale = value. scale ( ) as i8 ;
855+ ScalarValue :: Decimal128 ( value. to_i128 ( ) , precision, scale)
856+ }
857+ } ;
858+ deserialized_params. push ( value) ;
859+ }
811860 Type :: TIMESTAMP => {
812861 let value = portal. parameter :: < NaiveDateTime > ( i, & pg_type) ?;
813862 deserialized_params. push ( ScalarValue :: TimestampMicrosecond (
0 commit comments