@@ -17,7 +17,7 @@ use arrow::{
17
17
record_batch:: RecordBatch ,
18
18
} ;
19
19
20
- use num:: cast:: AsPrimitive ;
20
+ use num:: { cast:: AsPrimitive , ToPrimitive } ;
21
21
22
22
/// A pointer to the Arrow record batch for the table function.
23
23
#[ repr( C ) ]
@@ -165,7 +165,7 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
165
165
// duckdb/src/main/capi/helper-c.cpp does not support decimal
166
166
// DataType::Decimal128(_, _) => Decimal,
167
167
// DataType::Decimal256(_, _) => Decimal,
168
- DataType :: Decimal128 ( _, _) => Double ,
168
+ DataType :: Decimal128 ( _, _) => Decimal ,
169
169
DataType :: Decimal256 ( _, _) => Double ,
170
170
DataType :: Map ( _, _) => Map ,
171
171
_ => {
@@ -177,35 +177,34 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
177
177
178
178
/// Convert arrow DataType to duckdb logical type
179
179
pub fn to_duckdb_logical_type ( data_type : & DataType ) -> Result < LogicalType , Box < dyn std:: error:: Error > > {
180
- if data_type. is_primitive ( )
181
- || matches ! (
182
- data_type,
183
- DataType :: Boolean | DataType :: Utf8 | DataType :: LargeUtf8 | DataType :: Binary | DataType :: LargeBinary
184
- )
185
- {
186
- Ok ( LogicalType :: new ( to_duckdb_type_id ( data_type) ?) )
187
- } else if let DataType :: Dictionary ( _, value_type) = data_type {
188
- to_duckdb_logical_type ( value_type)
189
- } else if let DataType :: Struct ( fields) = data_type {
190
- let mut shape = vec ! [ ] ;
191
- for field in fields. iter ( ) {
192
- shape. push ( ( field. name ( ) . as_str ( ) , to_duckdb_logical_type ( field. data_type ( ) ) ?) ) ;
193
- }
194
- Ok ( LogicalType :: struct_type ( shape. as_slice ( ) ) )
195
- } else if let DataType :: List ( child) = data_type {
196
- Ok ( LogicalType :: list ( & to_duckdb_logical_type ( child. data_type ( ) ) ?) )
197
- } else if let DataType :: LargeList ( child) = data_type {
198
- Ok ( LogicalType :: list ( & to_duckdb_logical_type ( child. data_type ( ) ) ?) )
199
- } else if let DataType :: FixedSizeList ( child, array_size) = data_type {
200
- Ok ( LogicalType :: array (
180
+ match data_type {
181
+ DataType :: Dictionary ( _, value_type) => to_duckdb_logical_type ( value_type) ,
182
+ DataType :: Struct ( fields) => {
183
+ let mut shape = vec ! [ ] ;
184
+ for field in fields. iter ( ) {
185
+ shape. push ( ( field. name ( ) . as_str ( ) , to_duckdb_logical_type ( field. data_type ( ) ) ?) ) ;
186
+ }
187
+ Ok ( LogicalType :: struct_type ( shape. as_slice ( ) ) )
188
+ }
189
+ DataType :: List ( child) | DataType :: LargeList ( child) => {
190
+ Ok ( LogicalType :: list ( & to_duckdb_logical_type ( child. data_type ( ) ) ?) )
191
+ }
192
+ DataType :: FixedSizeList ( child, array_size) => Ok ( LogicalType :: array (
201
193
& to_duckdb_logical_type ( child. data_type ( ) ) ?,
202
194
* array_size as u64 ,
203
- ) )
204
- } else {
205
- Err (
206
- format ! ( "Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs" )
207
- . into ( ) ,
195
+ ) ) ,
196
+ DataType :: Decimal128 ( width, scale) if * scale > 0 => {
197
+ // DuckDB does not support negative decimal scales
198
+ Ok ( LogicalType :: decimal ( * width, ( * scale) . try_into ( ) . unwrap ( ) ) )
199
+ }
200
+ DataType :: Boolean | DataType :: Utf8 | DataType :: LargeUtf8 | DataType :: Binary | DataType :: LargeBinary => {
201
+ Ok ( LogicalType :: new ( to_duckdb_type_id ( data_type) ?) )
202
+ }
203
+ dtype if dtype. is_primitive ( ) => Ok ( LogicalType :: new ( to_duckdb_type_id ( data_type) ?) ) ,
204
+ _ => Err ( format ! (
205
+ "Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs"
208
206
)
207
+ . into ( ) ) ,
209
208
}
210
209
}
211
210
@@ -354,13 +353,11 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<
354
353
out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
355
354
) ;
356
355
}
357
- DataType :: Decimal128 ( _ , _) => {
356
+ DataType :: Decimal128 ( width , _) => {
358
357
decimal_array_to_vector (
359
- array
360
- . as_any ( )
361
- . downcast_ref :: < Decimal128Array > ( )
362
- . expect ( "Unable to downcast to BooleanArray" ) ,
358
+ as_primitive_array ( array) ,
363
359
out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
360
+ * width,
364
361
) ;
365
362
}
366
363
@@ -407,10 +404,43 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<
407
404
}
408
405
409
406
/// Convert Arrow [Decimal128Array] to a duckdb vector.
410
- fn decimal_array_to_vector ( array : & Decimal128Array , out : & mut FlatVector ) {
411
- assert ! ( array. len( ) <= out. capacity( ) ) ;
412
- for i in 0 ..array. len ( ) {
413
- out. as_mut_slice ( ) [ i] = array. value_as_string ( i) . parse :: < f64 > ( ) . unwrap ( ) ;
407
+ fn decimal_array_to_vector ( array : & Decimal128Array , out : & mut FlatVector , width : u8 ) {
408
+ match width {
409
+ 1 ..=4 => {
410
+ let out_data = out. as_mut_slice ( ) ;
411
+ for ( i, value) in array. values ( ) . iter ( ) . enumerate ( ) {
412
+ out_data[ i] = value. to_i16 ( ) . unwrap ( ) ;
413
+ }
414
+ }
415
+ 5 ..=9 => {
416
+ let out_data = out. as_mut_slice ( ) ;
417
+ for ( i, value) in array. values ( ) . iter ( ) . enumerate ( ) {
418
+ out_data[ i] = value. to_i32 ( ) . unwrap ( ) ;
419
+ }
420
+ }
421
+ 10 ..=18 => {
422
+ let out_data = out. as_mut_slice ( ) ;
423
+ for ( i, value) in array. values ( ) . iter ( ) . enumerate ( ) {
424
+ out_data[ i] = value. to_i64 ( ) . unwrap ( ) ;
425
+ }
426
+ }
427
+ 19 ..=38 => {
428
+ let out_data = out. as_mut_slice ( ) ;
429
+ for ( i, value) in array. values ( ) . iter ( ) . enumerate ( ) {
430
+ out_data[ i] = value. to_i128 ( ) . unwrap ( ) ;
431
+ }
432
+ }
433
+ // This should never happen, arrow only supports 1-38 decimal digits
434
+ _ => panic ! ( "Invalid decimal width: {}" , width) ,
435
+ }
436
+
437
+ // Set nulls
438
+ if let Some ( nulls) = array. nulls ( ) {
439
+ for ( i, null) in nulls. into_iter ( ) . enumerate ( ) {
440
+ if !null {
441
+ out. set_null ( i) ;
442
+ }
443
+ }
414
444
}
415
445
}
416
446
@@ -581,8 +611,8 @@ mod test {
581
611
use crate :: { Connection , Result } ;
582
612
use arrow:: {
583
613
array:: {
584
- Array , ArrayRef , AsArray , BinaryArray , Date32Array , Date64Array , Decimal256Array , FixedSizeListArray ,
585
- Float64Array , GenericListArray , Int32Array , ListArray , OffsetSizeTrait , PrimitiveArray , StringArray ,
614
+ Array , ArrayRef , AsArray , BinaryArray , Date32Array , Date64Array , Decimal128Array , Decimal256Array ,
615
+ FixedSizeListArray , GenericListArray , Int32Array , ListArray , OffsetSizeTrait , PrimitiveArray , StringArray ,
586
616
StructArray , Time32SecondArray , Time64MicrosecondArray , TimestampMicrosecondArray ,
587
617
TimestampMillisecondArray , TimestampNanosecondArray , TimestampSecondArray ,
588
618
} ,
@@ -606,9 +636,9 @@ mod test {
606
636
let mut arr = stmt. query_arrow ( param) ?;
607
637
let rb = arr. next ( ) . expect ( "no record batch" ) ;
608
638
assert_eq ! ( rb. num_columns( ) , 1 ) ;
609
- let column = rb. column ( 0 ) . as_any ( ) . downcast_ref :: < Float64Array > ( ) . unwrap ( ) ;
639
+ let column = rb. column ( 0 ) . as_any ( ) . downcast_ref :: < Decimal128Array > ( ) . unwrap ( ) ;
610
640
assert_eq ! ( column. len( ) , 1 ) ;
611
- assert_eq ! ( column. value( 0 ) , 300.0 ) ;
641
+ assert_eq ! ( column. value( 0 ) , i128 :: from ( 30000 ) ) ;
612
642
Ok ( ( ) )
613
643
}
614
644
@@ -896,6 +926,19 @@ mod test {
896
926
Ok ( ( ) )
897
927
}
898
928
929
+ #[ test]
930
+ fn test_decimal128_roundtrip ( ) -> Result < ( ) , Box < dyn Error > > {
931
+ let array: PrimitiveArray < arrow:: datatypes:: Decimal128Type > =
932
+ Decimal128Array :: from ( vec ! [ i128 :: from( 1 ) , i128 :: from( 2 ) , i128 :: from( 3 ) ] ) ;
933
+ check_rust_primitive_array_roundtrip ( array. clone ( ) , array) ?;
934
+
935
+ // With width and scale
936
+ let array: PrimitiveArray < arrow:: datatypes:: Decimal128Type > =
937
+ Decimal128Array :: from ( vec ! [ i128 :: from( 12345 ) ] ) . with_data_type ( DataType :: Decimal128 ( 5 , 2 ) ) ;
938
+ check_rust_primitive_array_roundtrip ( array. clone ( ) , array) ?;
939
+ Ok ( ( ) )
940
+ }
941
+
899
942
#[ test]
900
943
fn test_timestamp_tz_insert ( ) -> Result < ( ) , Box < dyn Error > > {
901
944
// TODO: This test should be reworked once we support TIMESTAMP_TZ properly
0 commit comments