Skip to content

Commit 2fea269

Browse files
authored
Support decimal128 without casting to double (#328)
* support decimal128 without casting to double
* fix parquet test
* clippy
* clippy... again
1 parent f48a4e3 commit 2fea269

File tree

1 file changed

+84
-41
lines changed

1 file changed

+84
-41
lines changed

crates/duckdb/src/vtab/arrow.rs

Lines changed: 84 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use arrow::{
1717
record_batch::RecordBatch,
1818
};
1919

20-
use num::cast::AsPrimitive;
20+
use num::{cast::AsPrimitive, ToPrimitive};
2121

2222
/// A pointer to the Arrow record batch for the table function.
2323
#[repr(C)]
@@ -165,7 +165,7 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
165165
// duckdb/src/main/capi/helper-c.cpp does not support decimal
166166
// DataType::Decimal128(_, _) => Decimal,
167167
// DataType::Decimal256(_, _) => Decimal,
168-
DataType::Decimal128(_, _) => Double,
168+
DataType::Decimal128(_, _) => Decimal,
169169
DataType::Decimal256(_, _) => Double,
170170
DataType::Map(_, _) => Map,
171171
_ => {
@@ -177,35 +177,34 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
177177

178178
/// Convert arrow DataType to duckdb logical type
179179
pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<dyn std::error::Error>> {
180-
if data_type.is_primitive()
181-
|| matches!(
182-
data_type,
183-
DataType::Boolean | DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary
184-
)
185-
{
186-
Ok(LogicalType::new(to_duckdb_type_id(data_type)?))
187-
} else if let DataType::Dictionary(_, value_type) = data_type {
188-
to_duckdb_logical_type(value_type)
189-
} else if let DataType::Struct(fields) = data_type {
190-
let mut shape = vec![];
191-
for field in fields.iter() {
192-
shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?));
193-
}
194-
Ok(LogicalType::struct_type(shape.as_slice()))
195-
} else if let DataType::List(child) = data_type {
196-
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
197-
} else if let DataType::LargeList(child) = data_type {
198-
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
199-
} else if let DataType::FixedSizeList(child, array_size) = data_type {
200-
Ok(LogicalType::array(
180+
match data_type {
181+
DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type),
182+
DataType::Struct(fields) => {
183+
let mut shape = vec![];
184+
for field in fields.iter() {
185+
shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?));
186+
}
187+
Ok(LogicalType::struct_type(shape.as_slice()))
188+
}
189+
DataType::List(child) | DataType::LargeList(child) => {
190+
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
191+
}
192+
DataType::FixedSizeList(child, array_size) => Ok(LogicalType::array(
201193
&to_duckdb_logical_type(child.data_type())?,
202194
*array_size as u64,
203-
))
204-
} else {
205-
Err(
206-
format!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs")
207-
.into(),
195+
)),
196+
DataType::Decimal128(width, scale) if *scale > 0 => {
197+
// DuckDB does not support negative decimal scales
198+
Ok(LogicalType::decimal(*width, (*scale).try_into().unwrap()))
199+
}
200+
DataType::Boolean | DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
201+
Ok(LogicalType::new(to_duckdb_type_id(data_type)?))
202+
}
203+
dtype if dtype.is_primitive() => Ok(LogicalType::new(to_duckdb_type_id(data_type)?)),
204+
_ => Err(format!(
205+
"Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs"
208206
)
207+
.into()),
209208
}
210209
}
211210

@@ -354,13 +353,11 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<
354353
out.as_mut_any().downcast_mut().unwrap(),
355354
);
356355
}
357-
DataType::Decimal128(_, _) => {
356+
DataType::Decimal128(width, _) => {
358357
decimal_array_to_vector(
359-
array
360-
.as_any()
361-
.downcast_ref::<Decimal128Array>()
362-
.expect("Unable to downcast to BooleanArray"),
358+
as_primitive_array(array),
363359
out.as_mut_any().downcast_mut().unwrap(),
360+
*width,
364361
);
365362
}
366363

@@ -407,10 +404,43 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<
407404
}
408405

409406
/// Convert Arrow [Decimal128Array] to a duckdb vector.
410-
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) {
411-
assert!(array.len() <= out.capacity());
412-
for i in 0..array.len() {
413-
out.as_mut_slice()[i] = array.value_as_string(i).parse::<f64>().unwrap();
407+
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector, width: u8) {
408+
match width {
409+
1..=4 => {
410+
let out_data = out.as_mut_slice();
411+
for (i, value) in array.values().iter().enumerate() {
412+
out_data[i] = value.to_i16().unwrap();
413+
}
414+
}
415+
5..=9 => {
416+
let out_data = out.as_mut_slice();
417+
for (i, value) in array.values().iter().enumerate() {
418+
out_data[i] = value.to_i32().unwrap();
419+
}
420+
}
421+
10..=18 => {
422+
let out_data = out.as_mut_slice();
423+
for (i, value) in array.values().iter().enumerate() {
424+
out_data[i] = value.to_i64().unwrap();
425+
}
426+
}
427+
19..=38 => {
428+
let out_data = out.as_mut_slice();
429+
for (i, value) in array.values().iter().enumerate() {
430+
out_data[i] = value.to_i128().unwrap();
431+
}
432+
}
433+
// This should never happen, arrow only supports 1-38 decimal digits
434+
_ => panic!("Invalid decimal width: {}", width),
435+
}
436+
437+
// Set nulls
438+
if let Some(nulls) = array.nulls() {
439+
for (i, null) in nulls.into_iter().enumerate() {
440+
if !null {
441+
out.set_null(i);
442+
}
443+
}
414444
}
415445
}
416446

@@ -581,8 +611,8 @@ mod test {
581611
use crate::{Connection, Result};
582612
use arrow::{
583613
array::{
584-
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal256Array, FixedSizeListArray,
585-
Float64Array, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
614+
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
615+
FixedSizeListArray, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
586616
StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
587617
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
588618
},
@@ -606,9 +636,9 @@ mod test {
606636
let mut arr = stmt.query_arrow(param)?;
607637
let rb = arr.next().expect("no record batch");
608638
assert_eq!(rb.num_columns(), 1);
609-
let column = rb.column(0).as_any().downcast_ref::<Float64Array>().unwrap();
639+
let column = rb.column(0).as_any().downcast_ref::<Decimal128Array>().unwrap();
610640
assert_eq!(column.len(), 1);
611-
assert_eq!(column.value(0), 300.0);
641+
assert_eq!(column.value(0), i128::from(30000));
612642
Ok(())
613643
}
614644

@@ -896,6 +926,19 @@ mod test {
896926
Ok(())
897927
}
898928

929+
#[test]
930+
fn test_decimal128_roundtrip() -> Result<(), Box<dyn Error>> {
931+
let array: PrimitiveArray<arrow::datatypes::Decimal128Type> =
932+
Decimal128Array::from(vec![i128::from(1), i128::from(2), i128::from(3)]);
933+
check_rust_primitive_array_roundtrip(array.clone(), array)?;
934+
935+
// With width and scale
936+
let array: PrimitiveArray<arrow::datatypes::Decimal128Type> =
937+
Decimal128Array::from(vec![i128::from(12345)]).with_data_type(DataType::Decimal128(5, 2));
938+
check_rust_primitive_array_roundtrip(array.clone(), array)?;
939+
Ok(())
940+
}
941+
899942
#[test]
900943
fn test_timestamp_tz_insert() -> Result<(), Box<dyn Error>> {
901944
// TODO: This test should be reworked once we support TIMESTAMP_TZ properly

0 commit comments

Comments
 (0)