@@ -6,8 +6,8 @@ use super::{
66use crate :: vtab:: vector:: Inserter ;
77use arrow:: array:: {
88 as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array , ArrayData ,
9- BooleanArray , Decimal128Array , FixedSizeListArray , GenericListArray , OffsetSizeTrait , PrimitiveArray , StringArray ,
10- StructArray ,
9+ AsArray , BooleanArray , Decimal128Array , FixedSizeListArray , GenericListArray , OffsetSizeTrait , PrimitiveArray ,
10+ StringArray , StructArray ,
1111} ;
1212
1313use arrow:: {
@@ -138,9 +138,15 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
138138 DataType :: UInt64 => UBigint ,
139139 DataType :: Float32 => Float ,
140140 DataType :: Float64 => Double ,
141- DataType :: Timestamp ( _, _) => Timestamp ,
142- DataType :: Date32 => Time ,
143- DataType :: Date64 => Time ,
141+ DataType :: Timestamp ( unit, None ) => match unit {
142+ TimeUnit :: Second => TimestampS ,
143+ TimeUnit :: Millisecond => TimestampMs ,
144+ TimeUnit :: Microsecond => Timestamp ,
145+ TimeUnit :: Nanosecond => TimestampNs ,
146+ } ,
147+ DataType :: Timestamp ( _, Some ( _) ) => TimestampTZ ,
148+ DataType :: Date32 => Date ,
149+ DataType :: Date64 => Date ,
144150 DataType :: Time32 ( _) => Time ,
145151 DataType :: Time64 ( _) => Time ,
146152 DataType :: Duration ( _) => Interval ,
@@ -250,6 +256,16 @@ fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<
250256 out_vector. copy :: < T :: Native > ( array. values ( ) ) ;
251257}
252258
259+ fn primitive_array_to_flat_vector_cast < T : ArrowPrimitiveType > (
260+ data_type : DataType ,
261+ array : & dyn Array ,
262+ out_vector : & mut dyn Vector ,
263+ ) {
264+ let array = arrow:: compute:: kernels:: cast:: cast ( array, & data_type) . unwrap ( ) ;
265+ let out_vector: & mut FlatVector = out_vector. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ;
266+ out_vector. copy :: < T :: Native > ( array. as_primitive :: < T > ( ) . values ( ) ) ;
267+ }
268+
253269fn primitive_array_to_vector ( array : & dyn Array , out : & mut dyn Vector ) {
254270 match array. data_type ( ) {
255271 DataType :: Boolean => {
@@ -303,6 +319,7 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
303319 out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
304320 ) ;
305321 }
322+ DataType :: Float16 => todo ! ( "Float16 is not supported yet" ) ,
306323 DataType :: Float32 => {
307324 primitive_array_to_flat_vector :: < Float32Type > (
308325 as_primitive_array ( array) ,
@@ -324,22 +341,55 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
324341 out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
325342 ) ;
326343 }
327- // DataType::Decimal256(_, _) => {
328- // primitive_array_to_flat_vector::<Decimal256Type>(
329- // as_primitive_array(array),
330- // out.as_mut_any().downcast_mut().unwrap(),
331- // );
332- // }
333- _ => {
334- todo ! ( )
344+ DataType :: Decimal256 ( _, _) => todo ! ( "Decimal256 is not supported yet" ) ,
345+
346+ // DuckDB only supports timestamp_tz in microsecond precision
347+ DataType :: Timestamp ( _, Some ( tz) ) => primitive_array_to_flat_vector_cast :: < TimestampMicrosecondType > (
348+ DataType :: Timestamp ( TimeUnit :: Microsecond , Some ( tz. clone ( ) ) ) ,
349+ array,
350+ out,
351+ ) ,
352+ DataType :: Timestamp ( unit, None ) => match unit {
353+ TimeUnit :: Second => primitive_array_to_flat_vector :: < TimestampSecondType > (
354+ as_primitive_array ( array) ,
355+ out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
356+ ) ,
357+ TimeUnit :: Millisecond => primitive_array_to_flat_vector :: < TimestampMillisecondType > (
358+ as_primitive_array ( array) ,
359+ out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
360+ ) ,
361+ TimeUnit :: Microsecond => primitive_array_to_flat_vector :: < TimestampMicrosecondType > (
362+ as_primitive_array ( array) ,
363+ out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
364+ ) ,
365+ TimeUnit :: Nanosecond => primitive_array_to_flat_vector :: < TimestampNanosecondType > (
366+ as_primitive_array ( array) ,
367+ out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
368+ ) ,
369+ } ,
370+ DataType :: Date32 => {
371+ primitive_array_to_flat_vector :: < Date32Type > (
372+ as_primitive_array ( array) ,
373+ out. as_mut_any ( ) . downcast_mut ( ) . unwrap ( ) ,
374+ ) ;
375+ }
376+ DataType :: Date64 => primitive_array_to_flat_vector_cast :: < Date32Type > ( Date32Type :: DATA_TYPE , array, out) ,
377+ DataType :: Time32 ( _) => {
378+ primitive_array_to_flat_vector_cast :: < Time64MicrosecondType > ( Time64MicrosecondType :: DATA_TYPE , array, out)
335379 }
380+ DataType :: Time64 ( _) => {
381+ primitive_array_to_flat_vector_cast :: < Time64MicrosecondType > ( Time64MicrosecondType :: DATA_TYPE , array, out)
382+ }
383+ _ => todo ! (
384+ "Converting '{dtype:#?}' to primitive flat vector is not supported" ,
385+ dtype = array. data_type( )
386+ ) ,
336387 }
337388}
338389
339- /// Convert Arrow [BooleanArray ] to a duckdb vector.
390+ /// Convert Arrow [Decimal128Array ] to a duckdb vector.
340391fn decimal_array_to_vector ( array : & Decimal128Array , out : & mut FlatVector ) {
341392 assert ! ( array. len( ) <= out. capacity( ) ) ;
342-
343393 for i in 0 ..array. len ( ) {
344394 out. as_mut_slice ( ) [ i] = array. value_as_string ( i) . parse :: < f64 > ( ) . unwrap ( ) ;
345395 }
@@ -488,8 +538,12 @@ mod test {
488538 use super :: { arrow_recordbatch_to_query_params, ArrowVTab } ;
489539 use crate :: { Connection , Result } ;
490540 use arrow:: {
491- array:: { Float64Array , Int32Array } ,
492- datatypes:: { DataType , Field , Schema } ,
541+ array:: {
542+ Array , AsArray , Date32Array , Date64Array , Float64Array , Int32Array , PrimitiveArray , StringArray ,
543+ Time32SecondArray , Time64MicrosecondArray , TimestampMicrosecondArray , TimestampMillisecondArray ,
544+ TimestampNanosecondArray , TimestampSecondArray ,
545+ } ,
546+ datatypes:: { ArrowPrimitiveType , DataType , Field , Schema } ,
493547 record_batch:: RecordBatch ,
494548 } ;
495549 use std:: { error:: Error , sync:: Arc } ;
@@ -534,4 +588,137 @@ mod test {
534588 assert_eq ! ( column. value( 0 ) , 15 ) ;
535589 Ok ( ( ) )
536590 }
591+
592+ fn check_rust_primitive_array_roundtrip < T1 , T2 > (
593+ input_array : PrimitiveArray < T1 > ,
594+ expected_array : PrimitiveArray < T2 > ,
595+ ) -> Result < ( ) , Box < dyn Error > >
596+ where
597+ T1 : ArrowPrimitiveType ,
598+ T2 : ArrowPrimitiveType ,
599+ {
600+ let db = Connection :: open_in_memory ( ) ?;
601+ db. register_table_function :: < ArrowVTab > ( "arrow" ) ?;
602+
603+ // Roundtrip a record batch from Rust to DuckDB and back to Rust
604+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , input_array. data_type( ) . clone( ) , false ) ] ) ;
605+
606+ let rb = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( input_array. clone( ) ) ] ) ?;
607+ let param = arrow_recordbatch_to_query_params ( rb) ;
608+ let mut stmt = db. prepare ( "select a from arrow(?, ?)" ) ?;
609+ let rb = stmt. query_arrow ( param) ?. next ( ) . expect ( "no record batch" ) ;
610+
611+ let output_any_array = rb. column ( 0 ) ;
612+ match ( output_any_array. data_type ( ) , expected_array. data_type ( ) ) {
613+ // TODO: DuckDB doesnt return timestamp_tz properly yet, so we just check that the units are the same
614+ ( DataType :: Timestamp ( unit_a, _) , DataType :: Timestamp ( unit_b, _) ) => assert_eq ! ( unit_a, unit_b) ,
615+ ( a, b) => assert_eq ! ( a, b) ,
616+ }
617+
618+ let maybe_output_array = output_any_array. as_primitive_opt :: < T2 > ( ) ;
619+
620+ match maybe_output_array {
621+ Some ( output_array) => {
622+ // Check that the output array is the same as the input array
623+ assert_eq ! ( output_array. len( ) , expected_array. len( ) ) ;
624+ for i in 0 ..output_array. len ( ) {
625+ assert_eq ! ( output_array. is_valid( i) , expected_array. is_valid( i) ) ;
626+ if output_array. is_valid ( i) {
627+ assert_eq ! ( output_array. value( i) , expected_array. value( i) ) ;
628+ }
629+ }
630+ }
631+ None => {
632+ panic ! ( "Output array is not a PrimitiveArray {:?}" , rb. column( 0 ) . data_type( ) ) ;
633+ }
634+ }
635+
636+ Ok ( ( ) )
637+ }
638+
/// Roundtrips integer, timestamp, date and time arrays through DuckDB and checks the
/// values (and, where a conversion is involved, the converted unit) that come back.
#[test]
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
    // Plain integers come back unchanged.
    let ints = Int32Array::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(ints.clone(), ints)?;

    // Timestamps without a time zone are expected back in the unit they were sent in.
    let micros = TimestampMicrosecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(micros.clone(), micros)?;

    let nanos = TimestampNanosecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(nanos.clone(), nanos)?;

    let seconds = TimestampSecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(seconds.clone(), seconds)?;

    let millis = TimestampMillisecondArray::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(millis.clone(), millis)?;

    // DuckDB can only return timestamp_tz in microseconds.
    // Note: DuckDB by default returns timestamp_tz with UTC because the Rust
    // driver doesn't support timestamp_tz properly when reading. In the
    // future we should be able to roundtrip timestamp_tz with other time zones too.
    let nanos_utc = TimestampNanosecondArray::from(vec![1000, 2000, 3000]).with_timezone_utc();
    let nanos_as_micros_utc = TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(nanos_utc, nanos_as_micros_utc)?;

    let millis_utc = TimestampMillisecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    let millis_as_micros_utc = TimestampMicrosecondArray::from(vec![1000, 2000, 3000]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(millis_utc, millis_as_micros_utc)?;

    let seconds_utc = TimestampSecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    let seconds_as_micros_utc =
        TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(seconds_utc, seconds_as_micros_utc)?;

    let micros_utc = TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc();
    check_rust_primitive_array_roundtrip(micros_utc.clone(), micros_utc)?;

    // Date32 passes through as-is; Date64 (milliseconds) is expected back as Date32 (days).
    let days = Date32Array::from(vec![1, 2, 3]);
    check_rust_primitive_array_roundtrip(days.clone(), days)?;

    let ms_per_day = arrow::temporal_conversions::MILLISECONDS_IN_DAY;
    let days_in_millis = Date64Array::from(vec![ms_per_day, 2 * ms_per_day, 3 * ms_per_day]);
    check_rust_primitive_array_roundtrip(days_in_millis, Date32Array::from(vec![1, 2, 3]))?;

    // Time32 in seconds is expected back as Time64 in microseconds.
    let time_seconds = Time32SecondArray::from(vec![1, 2, 3]);
    let time_micros = Time64MicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]);
    check_rust_primitive_array_roundtrip(time_seconds, time_micros)?;

    Ok(())
}
702+
/// Checks that a timestamp array carrying a time zone is seen by DuckDB as
/// `TIMESTAMP WITH TIME ZONE` when passed through the `arrow` table function.
#[test]
fn test_timestamp_tz_insert() -> Result<(), Box<dyn Error>> {
    // TODO: This test should be reworked once we support TIMESTAMP_TZ properly

    let conn = Connection::open_in_memory()?;
    conn.register_table_function::<ArrowVTab>("arrow")?;

    let tz_array = TimestampMicrosecondArray::from(vec![1]).with_timezone("+05:00");
    let field = Field::new("a", tz_array.data_type().clone(), false);
    let schema = Arc::new(Schema::new(vec![field]));

    // Since we can't get TIMESTAMP_TZ from the Rust client yet, we only check that it can be
    // inserted, by asking DuckDB for the type name of the column.
    let batch = RecordBatch::try_new(schema, vec![Arc::new(tz_array)]).expect("failed to create record batch");
    let params = arrow_recordbatch_to_query_params(batch);
    let mut stmt = conn.prepare("select typeof(a)::VARCHAR from arrow(?, ?)")?;
    let mut batches = stmt.query_arrow(params)?;
    let result = batches.next().expect("no record batch");

    assert_eq!(result.num_columns(), 1);
    let type_names = result.column(0).as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(type_names.value(0), "TIMESTAMP WITH TIME ZONE");
    Ok(())
}
537724}
0 commit comments