@@ -6,8 +6,8 @@ use super::{
6
6
use crate :: vtab:: vector:: Inserter ;
7
7
use arrow:: array:: {
8
8
as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array , ArrayData ,
9
- BooleanArray , Decimal128Array , FixedSizeListArray , GenericListArray , OffsetSizeTrait , PrimitiveArray , StringArray ,
10
- StructArray ,
9
+ AsArray , BooleanArray , Decimal128Array , FixedSizeListArray , GenericListArray , OffsetSizeTrait , PrimitiveArray ,
10
+ StringArray , StructArray ,
11
11
} ;
12
12
13
13
use arrow:: {
@@ -138,9 +138,15 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
138
138
DataType :: UInt64 => UBigint ,
139
139
DataType :: Float32 => Float ,
140
140
DataType :: Float64 => Double ,
141
- DataType :: Timestamp ( _, _) => Timestamp ,
142
- DataType :: Date32 => Time ,
143
- DataType :: Date64 => Time ,
141
+ DataType :: Timestamp ( unit, None ) => match unit {
142
+ TimeUnit :: Second => TimestampS ,
143
+ TimeUnit :: Millisecond => TimestampMs ,
144
+ TimeUnit :: Microsecond => Timestamp ,
145
+ TimeUnit :: Nanosecond => TimestampNs ,
146
+ } ,
147
+ DataType :: Timestamp ( _, Some ( _) ) => TimestampTZ ,
148
+ DataType :: Date32 => Date ,
149
+ DataType :: Date64 => Date ,
144
150
DataType :: Time32 ( _) => Time ,
145
151
DataType :: Time64 ( _) => Time ,
146
152
DataType :: Duration ( _) => Interval ,
@@ -250,6 +256,16 @@ fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<
250
256
out_vector. copy :: < T :: Native > ( array. values ( ) ) ;
251
257
}
252
258
259
+ fn primitive_array_to_flat_vector_cast < T : ArrowPrimitiveType > (
260
+ data_type : DataType ,
261
+ array : & dyn Array ,
262
+ out_vector : & mut dyn Vector ,
263
+ ) {
264
+ let array = arrow:: compute:: kernels:: cast:: cast ( array, & data_type) . unwrap ( ) ;
265
+ let out_vector: & mut FlatVector = out_vector. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ;
266
+ out_vector. copy :: < T :: Native > ( array. as_primitive :: < T > ( ) . values ( ) ) ;
267
+ }
268
+
253
269
fn primitive_array_to_vector ( array : & dyn Array , out : & mut dyn Vector ) {
254
270
match array. data_type ( ) {
255
271
DataType :: Boolean => {
@@ -303,6 +319,7 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
303
319
out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
304
320
) ;
305
321
}
322
+ DataType :: Float16 => todo ! ( "Float16 is not supported yet" ) ,
306
323
DataType :: Float32 => {
307
324
primitive_array_to_flat_vector :: < Float32Type > (
308
325
as_primitive_array ( array) ,
@@ -324,22 +341,55 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
324
341
out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
325
342
) ;
326
343
}
327
- // DataType::Decimal256(_, _) => {
328
- // primitive_array_to_flat_vector::<Decimal256Type>(
329
- // as_primitive_array(array),
330
- // out.as_mut_any().downcast_mut().unwrap(),
331
- // );
332
- // }
333
- _ => {
334
- todo ! ( )
344
+ DataType :: Decimal256 ( _, _) => todo ! ( "Decimal256 is not supported yet" ) ,
345
+
346
+ // DuckDB only supports timestamp_tz in microsecond precision
347
+ DataType :: Timestamp ( _, Some ( tz) ) => primitive_array_to_flat_vector_cast :: < TimestampMicrosecondType > (
348
+ DataType :: Timestamp ( TimeUnit :: Microsecond , Some ( tz. clone ( ) ) ) ,
349
+ array,
350
+ out,
351
+ ) ,
352
+ DataType :: Timestamp ( unit, None ) => match unit {
353
+ TimeUnit :: Second => primitive_array_to_flat_vector :: < TimestampSecondType > (
354
+ as_primitive_array ( array) ,
355
+ out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
356
+ ) ,
357
+ TimeUnit :: Millisecond => primitive_array_to_flat_vector :: < TimestampMillisecondType > (
358
+ as_primitive_array ( array) ,
359
+ out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
360
+ ) ,
361
+ TimeUnit :: Microsecond => primitive_array_to_flat_vector :: < TimestampMicrosecondType > (
362
+ as_primitive_array ( array) ,
363
+ out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
364
+ ) ,
365
+ TimeUnit :: Nanosecond => primitive_array_to_flat_vector :: < TimestampNanosecondType > (
366
+ as_primitive_array ( array) ,
367
+ out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
368
+ ) ,
369
+ } ,
370
+ DataType :: Date32 => {
371
+ primitive_array_to_flat_vector :: < Date32Type > (
372
+ as_primitive_array ( array) ,
373
+ out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
374
+ ) ;
375
+ }
376
+ DataType :: Date64 => primitive_array_to_flat_vector_cast :: < Date32Type > ( Date32Type :: DATA_TYPE , array, out) ,
377
+ DataType :: Time32 ( _) => {
378
+ primitive_array_to_flat_vector_cast :: < Time64MicrosecondType > ( Time64MicrosecondType :: DATA_TYPE , array, out)
335
379
}
380
+ DataType :: Time64 ( _) => {
381
+ primitive_array_to_flat_vector_cast :: < Time64MicrosecondType > ( Time64MicrosecondType :: DATA_TYPE , array, out)
382
+ }
383
+ _ => todo ! (
384
+ "Converting '{dtype:#?}' to primitive flat vector is not supported" ,
385
+ dtype = array. data_type( )
386
+ ) ,
336
387
}
337
388
}
338
389
339
- /// Convert Arrow [BooleanArray ] to a duckdb vector.
390
+ /// Convert Arrow [Decimal128Array ] to a duckdb vector.
340
391
fn decimal_array_to_vector ( array : & Decimal128Array , out : & mut FlatVector ) {
341
392
assert ! ( array. len( ) <= out. capacity( ) ) ;
342
-
343
393
for i in 0 ..array. len ( ) {
344
394
out. as_mut_slice ( ) [ i] = array. value_as_string ( i) . parse :: < f64 > ( ) . unwrap ( ) ;
345
395
}
@@ -488,8 +538,12 @@ mod test {
488
538
use super :: { arrow_recordbatch_to_query_params, ArrowVTab } ;
489
539
use crate :: { Connection , Result } ;
490
540
use arrow:: {
491
- array:: { Float64Array , Int32Array } ,
492
- datatypes:: { DataType , Field , Schema } ,
541
+ array:: {
542
+ Array , AsArray , Date32Array , Date64Array , Float64Array , Int32Array , PrimitiveArray , StringArray ,
543
+ Time32SecondArray , Time64MicrosecondArray , TimestampMicrosecondArray , TimestampMillisecondArray ,
544
+ TimestampNanosecondArray , TimestampSecondArray ,
545
+ } ,
546
+ datatypes:: { ArrowPrimitiveType , DataType , Field , Schema } ,
493
547
record_batch:: RecordBatch ,
494
548
} ;
495
549
use std:: { error:: Error , sync:: Arc } ;
@@ -534,4 +588,137 @@ mod test {
534
588
assert_eq ! ( column. value( 0 ) , 15 ) ;
535
589
Ok ( ( ) )
536
590
}
591
+
592
+ fn check_rust_primitive_array_roundtrip < T1 , T2 > (
593
+ input_array : PrimitiveArray < T1 > ,
594
+ expected_array : PrimitiveArray < T2 > ,
595
+ ) -> Result < ( ) , Box < dyn Error > >
596
+ where
597
+ T1 : ArrowPrimitiveType ,
598
+ T2 : ArrowPrimitiveType ,
599
+ {
600
+ let db = Connection :: open_in_memory ( ) ?;
601
+ db. register_table_function :: < ArrowVTab > ( "arrow" ) ?;
602
+
603
+ // Roundtrip a record batch from Rust to DuckDB and back to Rust
604
+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , input_array. data_type( ) . clone( ) , false ) ] ) ;
605
+
606
+ let rb = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( input_array. clone( ) ) ] ) ?;
607
+ let param = arrow_recordbatch_to_query_params ( rb) ;
608
+ let mut stmt = db. prepare ( "select a from arrow(?, ?)" ) ?;
609
+ let rb = stmt. query_arrow ( param) ?. next ( ) . expect ( "no record batch" ) ;
610
+
611
+ let output_any_array = rb. column ( 0 ) ;
612
+ match ( output_any_array. data_type ( ) , expected_array. data_type ( ) ) {
613
+ // TODO: DuckDB doesnt return timestamp_tz properly yet, so we just check that the units are the same
614
+ ( DataType :: Timestamp ( unit_a, _) , DataType :: Timestamp ( unit_b, _) ) => assert_eq ! ( unit_a, unit_b) ,
615
+ ( a, b) => assert_eq ! ( a, b) ,
616
+ }
617
+
618
+ let maybe_output_array = output_any_array. as_primitive_opt :: < T2 > ( ) ;
619
+
620
+ match maybe_output_array {
621
+ Some ( output_array) => {
622
+ // Check that the output array is the same as the input array
623
+ assert_eq ! ( output_array. len( ) , expected_array. len( ) ) ;
624
+ for i in 0 ..output_array. len ( ) {
625
+ assert_eq ! ( output_array. is_valid( i) , expected_array. is_valid( i) ) ;
626
+ if output_array. is_valid ( i) {
627
+ assert_eq ! ( output_array. value( i) , expected_array. value( i) ) ;
628
+ }
629
+ }
630
+ }
631
+ None => {
632
+ panic ! ( "Output array is not a PrimitiveArray {:?}" , rb. column( 0 ) . data_type( ) ) ;
633
+ }
634
+ }
635
+
636
+ Ok ( ( ) )
637
+ }
638
+
639
#[test]
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
    // Plain integers come back unchanged.
    let input = Int32Array::from(vec![1, 2, 3]);
    let expected = Int32Array::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(input, expected)?;

    // Timestamps without a timezone keep their unit.
    let input = TimestampMicrosecondArray::from(vec![1, 2, 3]);
    let expected = TimestampMicrosecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(input, expected)?;

    let input = TimestampNanosecondArray::from(vec![1, 2, 3]);
    let expected = TimestampNanosecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(input, expected)?;

    let input = TimestampSecondArray::from(vec![1, 2, 3]);
    let expected = TimestampSecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(input, expected)?;

    let input = TimestampMillisecondArray::from(vec![1, 2, 3]);
    let expected = TimestampMillisecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(input, expected)?;

    // DuckDB can only return timestamp_tz in microseconds.
    // Note: DuckDB by default returns timestamp_tz with UTC because the rust
    // driver doesn't support timestamp_tz properly when reading. In the
    // future we should be able to roundtrip timestamp_tz with other timezones too.
    let input = TimestampNanosecondArray::from(vec![1000, 2000, 3000]).with_timezone_utc();
    let expected = TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(input, expected)?;

    let input = TimestampMillisecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    let expected = TimestampMicrosecondArray::from(vec![1000, 2000, 3000]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(input, expected)?;

    let input = TimestampSecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    let expected = TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(input, expected)?;

    let input = TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    let expected = TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(input, expected)?;

    // Dates: Date32 is unchanged, Date64 is expected back as Date32 day counts.
    let input = Date32Array::from(vec![1, 2, 3]);
    let expected = Date32Array::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(input, expected)?;

    let millis_per_day = arrow::temporal_conversions::MILLISECONDS_IN_DAY;
    let input = Date64Array::from(vec![millis_per_day, 2 * millis_per_day, 3 * millis_per_day]);
    let expected = Date32Array::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(input, expected)?;

    // Second-resolution times are expected back in microseconds.
    let input = Time32SecondArray::from(vec![1, 2, 3]);
    let expected = Time64MicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]);
    check_rust_primitive_array_roundtrip(input, expected)?;

    Ok(())
}
702
+
703
#[test]
fn test_timestamp_tz_insert() -> Result<(), Box<dyn Error>> {
    // TODO: This test should be reworked once we support TIMESTAMP_TZ properly

    let conn = Connection::open_in_memory()?;
    conn.register_table_function::<ArrowVTab>("arrow")?;

    // One-row timestamp column carrying a non-UTC offset.
    let tz_array = TimestampMicrosecondArray::from(vec![1]).with_timezone("+05:00");
    let field = Field::new("a", tz_array.data_type().clone(), false);
    let schema = Arc::new(Schema::new(vec![field]));

    // Since we can't get TIMESTAMP_TZ from the rust client yet, we only check
    // here that the column can be inserted and is seen with the right type.
    let batch = RecordBatch::try_new(schema, vec![Arc::new(tz_array)]).expect("failed to create record batch");
    let params = arrow_recordbatch_to_query_params(batch);

    let mut stmt = conn.prepare("select typeof(a)::VARCHAR from arrow(?, ?)")?;
    let mut batches = stmt.query_arrow(params)?;
    let result = batches.next().expect("no record batch");
    assert_eq!(result.num_columns(), 1);

    let type_names = result.column(0).as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(type_names.value(0), "TIMESTAMP WITH TIME ZONE");
    Ok(())
}
537
724
}
0 commit comments