Skip to content

Commit 1ffc805

Browse files
committed
Add support for Decimal128 datatype
1 parent 2cf5878 commit 1ffc805

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-postgres/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ datafusion = { workspace = true }
2121
futures = "0.3"
2222
async-trait = "0.1"
2323
chrono = { version = "0.4", features = ["std"] }
24+
rust_decimal = { version = "1.35", features = ["db-postgres"] }

datafusion-postgres/src/datatypes.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use pgwire::api::portal::{Format, Portal};
1515
use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse};
1616
use pgwire::api::Type;
1717
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
18+
use rust_decimal::prelude::ToPrimitive;
19+
use rust_decimal::Decimal;
1820
use timezone::Tz;
1921

2022
pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
@@ -38,6 +40,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
3840
DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA,
3941
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
4042
DataType::Float64 => Type::FLOAT8,
43+
DataType::Decimal128(_, _) => Type::NUMERIC,
4144
DataType::Utf8 => Type::VARCHAR,
4245
DataType::LargeUtf8 => Type::TEXT,
4346
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
@@ -83,6 +86,11 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
8386
})
8487
}
8588

89+
fn get_numeric_128_value(arr: &Arc<dyn Array>, idx: usize, scale: u32) -> Decimal {
90+
let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
91+
Decimal::from_i128_with_scale(array.value(idx), scale)
92+
}
93+
8694
fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> bool {
8795
arr.as_any()
8896
.downcast_ref::<BooleanArray>()
@@ -258,6 +266,9 @@ fn encode_value(
258266
DataType::UInt64 => encoder.encode_field(&(get_u64_value(arr, idx) as i64))?,
259267
DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx))?,
260268
DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx))?,
269+
DataType::Decimal128(_, s) => {
270+
encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32))?
271+
}
261272
DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx))?,
262273
DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx))?,
263274
DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx))?,
@@ -361,6 +372,17 @@ fn encode_value(
361372
DataType::UInt64 => encoder.encode_field(&get_u64_list_value(arr, idx))?,
362373
DataType::Float32 => encoder.encode_field(&get_f32_list_value(arr, idx))?,
363374
DataType::Float64 => encoder.encode_field(&get_f64_list_value(arr, idx))?,
375+
DataType::Decimal128(_, s) => {
376+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
377+
let value: Vec<_> = list_arr
378+
.as_any()
379+
.downcast_ref::<Decimal128Array>()
380+
.unwrap()
381+
.iter()
382+
.map(|v| Decimal::from_i128_with_scale(v.unwrap(), *s as u32))
383+
.collect();
384+
encoder.encode_field(&value)?
385+
}
364386
DataType::Utf8 => {
365387
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
366388
let value: Vec<_> = list_arr
@@ -808,6 +830,23 @@ where
808830
let value = portal.parameter::<f64>(i, &pg_type)?;
809831
deserialized_params.push(ScalarValue::Float64(value));
810832
}
833+
Type::NUMERIC => {
834+
let value = match portal.parameter::<Decimal>(i, &pg_type)? {
835+
None => ScalarValue::Decimal128(None, 0, 0),
836+
Some(value) => {
837+
let mantissa = value.mantissa();
838+
// Count digits in the mantissa
839+
let precision = if mantissa == 0 {
840+
1
841+
} else {
842+
(mantissa.abs() as f64).log10().floor() as u8 + 1
843+
};
844+
let scale = value.scale() as i8;
845+
ScalarValue::Decimal128(value.to_i128(), precision, scale)
846+
}
847+
};
848+
deserialized_params.push(value);
849+
}
811850
Type::TIMESTAMP => {
812851
let value = portal.parameter::<NaiveDateTime>(i, &pg_type)?;
813852
deserialized_params.push(ScalarValue::TimestampMicrosecond(

0 commit comments

Comments
 (0)