
Commit d70ba8b

Improve the read performance of Postgres' Decimals (#438)

* Improve the read performance of Postgres' Decimals
* Avoid unnecessary cast

Co-authored-by: Phillip LeBlanc <[email protected]>
1 parent c4086ec commit d70ba8b

File tree

3 files changed (+23 −123 lines)

Cargo.lock

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default.

core/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@ tracing = "0.1"
 trust-dns-resolver = "0.23.2"
 url = "2.5.4"
 uuid = { version = "1.18", optional = true }
+rust_decimal = { version = "1.38.0", features = ["db-postgres"] }
 
 [dev-dependencies]
 anyhow = "1.0"

core/src/sql/arrow_sql_gen/postgres.rs

Lines changed: 18 additions & 121 deletions
@@ -15,14 +15,12 @@ use arrow::datatypes::{
     DataType, Date32Type, Field, Int8Type, IntervalMonthDayNanoType, IntervalUnit, Schema,
     SchemaRef, TimeUnit,
 };
-use bigdecimal::num_bigint::BigInt;
-use bigdecimal::num_bigint::Sign;
 use bigdecimal::BigDecimal;
-use bigdecimal::ToPrimitive;
 use byteorder::{BigEndian, ReadBytesExt};
 use chrono::{DateTime, Timelike, Utc};
 use composite::CompositeType;
 use geo_types::geometry::Point;
+use rust_decimal::Decimal;
 use sea_query::{Alias, ColumnType, SeaRc};
 use serde_json::Value;
 use snafu::prelude::*;
@@ -196,7 +194,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Resu
     let mut arrow_fields: Vec<Option<Field>> = Vec::new();
     let mut arrow_columns_builders: Vec<Option<Box<dyn ArrayBuilder>>> = Vec::new();
     let mut postgres_types: Vec<Type> = Vec::new();
-    let mut postgres_numeric_scales: Vec<Option<u16>> = Vec::new();
+    let mut postgres_numeric_scales: Vec<Option<u32>> = Vec::new();
     let mut column_names: Vec<String> = Vec::new();
 
     if !rows.is_empty() {
@@ -205,13 +203,13 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Resu
             let column_name = column.name();
             let column_type = column.type_();
 
-            let mut numeric_scale: Option<u16> = None;
+            let mut numeric_scale: Option<u32> = None;
 
             let data_type = if *column_type == Type::NUMERIC {
                 if let Some(schema) = projected_schema.as_ref() {
                     match get_decimal_column_precision_and_scale(column_name, schema) {
                         Some((precision, scale)) => {
-                            numeric_scale = Some(u16::try_from(scale).unwrap_or_default());
+                            numeric_scale = Some(u32::try_from(scale).unwrap_or_default());
                             Some(DataType::Decimal128(precision, scale))
                         }
                         None => None,
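
The scale bookkeeping above moves from `Option<u16>` to `Option<u32>` to match rust_decimal's API, which expresses scale as a `u32` (both `Decimal::scale` and `Decimal::rescale` use it). A standalone illustration of that surface (not repository code):

```rust
use rust_decimal::Decimal;
use std::str::FromStr;

fn main() {
    let d = Decimal::from_str("12.345").expect("valid decimal literal");
    // scale() returns u32, which is why the column-scale vectors in
    // rows_to_arrow now hold Option<u32> instead of Option<u16>.
    let s: u32 = d.scale();
    assert_eq!(s, 3);
    println!("scale = {s}");
}
```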
@@ -465,10 +463,9 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Resu
                     }
                 }
                 Type::NUMERIC => {
-                    let v: Option<BigDecimalFromSql> =
-                        row.try_get(i).context(FailedToGetRowValueSnafu {
-                            pg_type: Type::NUMERIC,
-                        })?;
+                    let v: Option<Decimal> = row.try_get(i).context(FailedToGetRowValueSnafu {
+                        pg_type: Type::NUMERIC,
+                    })?;
                     let scale = {
                         if let Some(v) = &v {
                             v.scale()
@@ -511,21 +508,16 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Resu
                         *postgres_numeric_scale = Some(scale);
                     };
 
-                    let Some(v) = v else {
+                    let Some(mut v) = v else {
                         dec_builder.append_null();
                         continue;
                     };
 
                     // Record Batch Scale is determined by first row, while Postgres Numeric Type doesn't have fixed scale
                     // Resolve scale difference for incoming records
                     let dest_scale = postgres_numeric_scale.unwrap_or_default();
-                    let Some(v_i128) = v.to_decimal_128_with_scale(dest_scale) else {
-                        return FailedToConvertBigDecimalToI128Snafu {
-                            big_decimal: v.inner,
-                        }
-                        .fail();
-                    };
-                    dec_builder.append_value(v_i128);
+                    v.rescale(dest_scale);
+                    dec_builder.append_value(v.mantissa());
                 }
                 Type::TIMESTAMP => {
                     let Some(builder) = builder else {
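
The replaced block is the core of the change: rather than multiplying a `BigDecimal` by a power of ten and converting it to `i128`, the incoming value is rescaled in place to the record batch's fixed scale and its `i128` mantissa is appended directly. A self-contained sketch of that path, using an illustrative precision of 38 and destination scale of 4 (the builder setup is simplified, not this crate's exact code):

```rust
use arrow::array::Decimal128Builder;
use arrow::datatypes::DataType;
use rust_decimal::Decimal;
use std::str::FromStr;

fn main() {
    // Destination scale as the record batch would fix it from the first row
    // or the projected schema (4 is an arbitrary example).
    let dest_scale: u32 = 4;
    let mut dec_builder =
        Decimal128Builder::new().with_data_type(DataType::Decimal128(38, dest_scale as i8));

    // A value as it might arrive from a NUMERIC column with a smaller scale.
    let mut v = Decimal::from_str("9345129329031293.09").expect("valid decimal literal");

    // Align the value with the fixed Decimal128 scale, then store the raw
    // unscaled integer directly, with no BigDecimal/BigInt round trip.
    v.rescale(dest_scale);
    dec_builder.append_value(v.mantissa());

    let array = dec_builder.finish();
    println!("{array:?}");
}
```

`rescale` pads or rounds the fractional digits to the requested scale, and `mantissa` returns the unscaled integer as `i128`, which is exactly what the Decimal128 builder stores.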
@@ -1005,101 +997,6 @@ pub(crate) fn map_data_type_to_column_type_postgres(
 pub(crate) fn get_postgres_composite_type_name(table_name: &str, field_name: &str) -> String {
     format!("struct_{table_name}_{field_name}")
 }
-
-struct BigDecimalFromSql {
-    inner: BigDecimal,
-    scale: u16,
-}
-
-impl BigDecimalFromSql {
-    fn to_decimal_128_with_scale(&self, dest_scale: u16) -> Option<i128> {
-        // Resolve scale difference by upscaling / downscaling to the scale of arrow Decimal128 type
-        if dest_scale != self.scale {
-            return (&self.inner * 10i128.pow(u32::from(dest_scale))).to_i128();
-        }
-
-        (&self.inner * 10i128.pow(u32::from(self.scale))).to_i128()
-    }
-
-    fn scale(&self) -> u16 {
-        self.scale
-    }
-}
-
-#[allow(clippy::cast_sign_loss)]
-#[allow(clippy::cast_possible_wrap)]
-#[allow(clippy::cast_possible_truncation)]
-impl<'a> FromSql<'a> for BigDecimalFromSql {
-    fn from_sql(
-        _ty: &Type,
-        raw: &'a [u8],
-    ) -> std::prelude::v1::Result<Self, Box<dyn std::error::Error + Sync + Send>> {
-        let raw_u16: Vec<u16> = raw
-            .chunks(2)
-            .map(|chunk| {
-                if chunk.len() == 2 {
-                    u16::from_be_bytes([chunk[0], chunk[1]])
-                } else {
-                    u16::from_be_bytes([chunk[0], 0])
-                }
-            })
-            .collect();
-
-        let base_10_000_digit_count = raw_u16[0];
-        let weight = raw_u16[1] as i16;
-        let sign = raw_u16[2];
-        let scale = raw_u16[3];
-
-        let mut base_10_000_digits = Vec::new();
-        for i in 4..4 + base_10_000_digit_count {
-            base_10_000_digits.push(raw_u16[i as usize]);
-        }
-
-        let mut u8_digits = Vec::new();
-        for &base_10_000_digit in base_10_000_digits.iter().rev() {
-            let mut base_10_000_digit = base_10_000_digit;
-            let mut temp_result = Vec::new();
-            while base_10_000_digit > 0 {
-                temp_result.push((base_10_000_digit % 10) as u8);
-                base_10_000_digit /= 10;
-            }
-            while temp_result.len() < 4 {
-                temp_result.push(0);
-            }
-            u8_digits.extend(temp_result);
-        }
-        u8_digits.reverse();
-
-        let value_scale = 4 * (i64::from(base_10_000_digit_count) - i64::from(weight) - 1);
-        let size = i64::try_from(u8_digits.len())? + i64::from(scale) - value_scale;
-        u8_digits.resize(size as usize, 0);
-
-        let sign = match sign {
-            0x4000 => Sign::Minus,
-            0x0000 => Sign::Plus,
-            _ => {
-                return Err(Box::new(Error::FailedToParseBigDecimalFromPostgres {
-                    bytes: raw.to_vec(),
-                }))
-            }
-        };
-
-        let Some(digits) = BigInt::from_radix_be(sign, u8_digits.as_slice(), 10) else {
-            return Err(Box::new(Error::FailedToParseBigDecimalFromPostgres {
-                bytes: raw.to_vec(),
-            }));
-        };
-        Ok(BigDecimalFromSql {
-            inner: BigDecimal::new(digits, i64::from(scale)),
-            scale,
-        })
-    }
-
-    fn accepts(ty: &Type) -> bool {
-        matches!(*ty, Type::NUMERIC)
-    }
-}
-
 // interval_send - Postgres C (https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/timestamp.c#L1032)
 // interval values are internally stored as three integral fields: months, days, and microseconds
 struct IntervalFromSql {
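
For orientation when reading the raw `u16` vectors in the test below: Postgres' binary `NUMERIC` representation starts with four big-endian `u16` header fields (number of base-10000 digits, weight of the first digit, sign, display scale), followed by the digits themselves. That is what the removed `FromSql` impl above decoded by hand and what rust_decimal's feature-gated impl now takes care of. A small, hypothetical helper (illustration only, not part of any crate) that peeks at that header:

```rust
// Decode only the four-u16 header of a binary NUMERIC payload:
// (ndigits, weight, sign, dscale). Sign is 0x0000 for plus, 0x4000 for minus.
fn numeric_header(raw: &[u8]) -> Option<(u16, i16, u16, u16)> {
    if raw.len() < 8 {
        return None;
    }
    let u16_at = |i: usize| u16::from_be_bytes([raw[2 * i], raw[2 * i + 1]]);
    Some((u16_at(0), u16_at(1) as i16, u16_at(2), u16_at(3)))
}

fn main() {
    // Bytes built the same way as the positive test vector below.
    let raw: Vec<u8> = [5u16, 3, 0, 5, 9345, 1293, 2903, 1293, 932]
        .iter()
        .flat_map(|&x| x.to_be_bytes())
        .collect();
    println!("{:?}", numeric_header(&raw)); // Some((5, 3, 0, 5))
}
```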
@@ -1225,28 +1122,28 @@ mod tests {
 
     #[allow(clippy::cast_possible_truncation)]
     #[tokio::test]
-    async fn test_big_decimal_from_sql() {
+    async fn test_decimal_from_sql() {
         let positive_u16: Vec<u16> = vec![5, 3, 0, 5, 9345, 1293, 2903, 1293, 932];
         let positive_raw: Vec<u8> = positive_u16
             .iter()
             .flat_map(|&x| vec![(x >> 8) as u8, x as u8])
             .collect();
-        let positive =
-            BigDecimal::from_str("9345129329031293.0932").expect("Failed to parse big decimal");
-        let positive_result = BigDecimalFromSql::from_sql(&Type::NUMERIC, positive_raw.as_slice())
+        let positive = Decimal::from_str("9345129329031293.0932").expect("Failed to parse decimal");
+        let positive_result = Decimal::from_sql(&Type::NUMERIC, positive_raw.as_slice())
             .expect("Failed to run FromSql");
-        assert_eq!(positive_result.inner, positive);
+        assert_eq!(positive_result, positive);
 
         let negative_u16: Vec<u16> = vec![5, 3, 0x4000, 5, 9345, 1293, 2903, 1293, 932];
         let negative_raw: Vec<u8> = negative_u16
             .iter()
             .flat_map(|&x| vec![(x >> 8) as u8, x as u8])
             .collect();
+
         let negative =
-            BigDecimal::from_str("-9345129329031293.0932").expect("Failed to parse big decimal");
-        let negative_result = BigDecimalFromSql::from_sql(&Type::NUMERIC, negative_raw.as_slice())
+            Decimal::from_str("-9345129329031293.0932").expect("Failed to parse decimal");
+        let negative_result = Decimal::from_sql(&Type::NUMERIC, negative_raw.as_slice())
             .expect("Failed to run FromSql");
-        assert_eq!(negative_result.inner, negative);
+        assert_eq!(negative_result, negative);
     }
 
     #[test]
