Skip to content

Commit 47325f6

Browse files
authored
Add support for Decimal128 datatype (#68)
* Add support for Decimal128 datatype * Improve error handling
1 parent 2cf5878 commit 47325f6

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-postgres/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ datafusion = { workspace = true }
2121
futures = "0.3"
2222
async-trait = "0.1"
2323
chrono = { version = "0.4", features = ["std"] }
24+
rust_decimal = { version = "1.35", features = ["db-postgres"] }

datafusion-postgres/src/datatypes.rs

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use pgwire::api::portal::{Format, Portal};
1515
use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse};
1616
use pgwire::api::Type;
1717
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
18+
use rust_decimal::prelude::ToPrimitive;
19+
use rust_decimal::{Decimal, Error};
1820
use timezone::Tz;
1921

2022
pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
@@ -38,6 +40,7 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
3840
DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA,
3941
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
4042
DataType::Float64 => Type::FLOAT8,
43+
DataType::Decimal128(_, _) => Type::NUMERIC,
4144
DataType::Utf8 => Type::VARCHAR,
4245
DataType::LargeUtf8 => Type::TEXT,
4346
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
@@ -83,6 +86,24 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
8386
})
8487
}
8588

89+
fn get_numeric_128_value(arr: &Arc<dyn Array>, idx: usize, scale: u32) -> PgWireResult<Decimal> {
90+
let array = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
91+
let value = array.value(idx);
92+
Decimal::try_from_i128_with_scale(value, scale).map_err(|e| {
93+
let message = match e {
94+
Error::ExceedsMaximumPossibleValue => "Exceeds maximum possible value",
95+
Error::LessThanMinimumPossibleValue => "Less than minimum possible value",
96+
Error::ScaleExceedsMaximumPrecision(_) => "Scale exceeds maximum precision",
97+
_ => unreachable!(),
98+
};
99+
PgWireError::UserError(Box::new(ErrorInfo::new(
100+
"ERROR".to_owned(),
101+
"XX000".to_owned(),
102+
message.to_owned(),
103+
)))
104+
})
105+
}
106+
86107
fn get_bool_value(arr: &Arc<dyn Array>, idx: usize) -> bool {
87108
arr.as_any()
88109
.downcast_ref::<BooleanArray>()
@@ -258,6 +279,9 @@ fn encode_value(
258279
DataType::UInt64 => encoder.encode_field(&(get_u64_value(arr, idx) as i64))?,
259280
DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx))?,
260281
DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx))?,
282+
DataType::Decimal128(_, s) => {
283+
encoder.encode_field(&get_numeric_128_value(arr, idx, *s as u32)?)?
284+
}
261285
DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx))?,
262286
DataType::Utf8View => encoder.encode_field(&get_utf8_view_value(arr, idx))?,
263287
DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx))?,
@@ -361,6 +385,17 @@ fn encode_value(
361385
DataType::UInt64 => encoder.encode_field(&get_u64_list_value(arr, idx))?,
362386
DataType::Float32 => encoder.encode_field(&get_f32_list_value(arr, idx))?,
363387
DataType::Float64 => encoder.encode_field(&get_f64_list_value(arr, idx))?,
388+
DataType::Decimal128(_, s) => {
389+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
390+
let value: Vec<_> = list_arr
391+
.as_any()
392+
.downcast_ref::<Decimal128Array>()
393+
.unwrap()
394+
.iter()
395+
.map(|ov| ov.map(|v| Decimal::from_i128_with_scale(v, *s as u32)))
396+
.collect();
397+
encoder.encode_field(&value)?
398+
}
364399
DataType::Utf8 => {
365400
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
366401
let value: Vec<_> = list_arr
@@ -711,9 +746,9 @@ pub(crate) async fn encode_dataframe<'a>(
711746
for col in 0..cols {
712747
let array = rb.column(col);
713748
if array.is_null(row) {
714-
encoder.encode_field(&None::<i8>).unwrap();
749+
encoder.encode_field(&None::<i8>)?;
715750
} else {
716-
encode_value(&mut encoder, array, row).unwrap();
751+
encode_value(&mut encoder, array, row)?
717752
}
718753
}
719754
encoder.finish()
@@ -808,6 +843,20 @@ where
808843
let value = portal.parameter::<f64>(i, &pg_type)?;
809844
deserialized_params.push(ScalarValue::Float64(value));
810845
}
846+
Type::NUMERIC => {
847+
let value = match portal.parameter::<Decimal>(i, &pg_type)? {
848+
None => ScalarValue::Decimal128(None, 0, 0),
849+
Some(value) => {
850+
let precision = match value.mantissa() {
851+
0 => 1,
852+
m => (m.abs() as f64).log10().floor() as u8 + 1,
853+
};
854+
let scale = value.scale() as i8;
855+
ScalarValue::Decimal128(value.to_i128(), precision, scale)
856+
}
857+
};
858+
deserialized_params.push(value);
859+
}
811860
Type::TIMESTAMP => {
812861
let value = portal.parameter::<NaiveDateTime>(i, &pg_type)?;
813862
deserialized_params.push(ScalarValue::TimestampMicrosecond(

0 commit comments

Comments
 (0)