@@ -8,7 +8,7 @@ use crate::vtab::vector::Inserter;
8
8
use arrow:: array:: {
9
9
as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array,
10
10
as_struct_array, Array , ArrayData , AsArray , BinaryArray , BooleanArray , Decimal128Array , FixedSizeListArray ,
11
- GenericListArray , OffsetSizeTrait , PrimitiveArray , StringArray , StructArray ,
11
+ GenericListArray , GenericStringArray , LargeStringArray , OffsetSizeTrait , PrimitiveArray , StructArray ,
12
12
} ;
13
13
14
14
use arrow:: {
@@ -229,6 +229,15 @@ pub fn record_batch_to_duckdb_data_chunk(
229
229
DataType :: Utf8 => {
230
230
string_array_to_vector ( as_string_array ( col. as_ref ( ) ) , & mut chunk. flat_vector ( i) ) ;
231
231
}
232
+ DataType :: LargeUtf8 => {
233
+ string_array_to_vector (
234
+ col. as_ref ( )
235
+ . as_any ( )
236
+ . downcast_ref :: < LargeStringArray > ( )
237
+ . ok_or_else ( || Box :: < dyn std:: error:: Error > :: from ( "Unable to downcast to LargeStringArray" ) ) ?,
238
+ & mut chunk. flat_vector ( i) ,
239
+ ) ;
240
+ }
232
241
DataType :: Binary => {
233
242
binary_array_to_vector ( as_generic_binary_array ( col. as_ref ( ) ) , & mut chunk. flat_vector ( i) ) ;
234
243
}
@@ -453,7 +462,7 @@ fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) {
453
462
}
454
463
}
455
464
456
- fn string_array_to_vector ( array : & StringArray , out : & mut FlatVector ) {
465
+ fn string_array_to_vector < O : OffsetSizeTrait > ( array : & GenericStringArray < O > , out : & mut FlatVector ) {
457
466
assert ! ( array. len( ) <= out. capacity( ) ) ;
458
467
459
468
// TODO: zero copy assignment
@@ -612,12 +621,12 @@ mod test {
612
621
use arrow:: {
613
622
array:: {
614
623
Array , ArrayRef , AsArray , BinaryArray , Date32Array , Date64Array , Decimal128Array , Decimal256Array ,
615
- FixedSizeListArray , GenericListArray , Int32Array , ListArray , OffsetSizeTrait , PrimitiveArray , StringArray ,
616
- StructArray , Time32SecondArray , Time64MicrosecondArray , TimestampMicrosecondArray ,
617
- TimestampMillisecondArray , TimestampNanosecondArray , TimestampSecondArray ,
624
+ FixedSizeListArray , GenericByteArray , GenericListArray , Int32Array , LargeStringArray , ListArray ,
625
+ OffsetSizeTrait , PrimitiveArray , StringArray , StructArray , Time32SecondArray , Time64MicrosecondArray ,
626
+ TimestampMicrosecondArray , TimestampMillisecondArray , TimestampNanosecondArray , TimestampSecondArray ,
618
627
} ,
619
628
buffer:: { OffsetBuffer , ScalarBuffer } ,
620
- datatypes:: { i256, ArrowPrimitiveType , DataType , Field , Fields , Schema } ,
629
+ datatypes:: { i256, ArrowPrimitiveType , ByteArrayType , DataType , Field , Fields , Schema } ,
621
630
record_batch:: RecordBatch ,
622
631
} ;
623
632
use std:: { error:: Error , sync:: Arc } ;
@@ -784,6 +793,48 @@ mod test {
784
793
Ok ( ( ) )
785
794
}
786
795
796
+ fn check_generic_byte_roundtrip < T1 , T2 > (
797
+ arry_in : GenericByteArray < T1 > ,
798
+ arry_out : GenericByteArray < T2 > ,
799
+ ) -> Result < ( ) , Box < dyn Error > >
800
+ where
801
+ T1 : ByteArrayType ,
802
+ T2 : ByteArrayType ,
803
+ {
804
+ let db = Connection :: open_in_memory ( ) ?;
805
+ db. register_table_function :: < ArrowVTab > ( "arrow" ) ?;
806
+
807
+ // Roundtrip a record batch from Rust to DuckDB and back to Rust
808
+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , arry_in. data_type( ) . clone( ) , false ) ] ) ;
809
+
810
+ let rb = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( arry_in. clone( ) ) ] ) ?;
811
+ let param = arrow_recordbatch_to_query_params ( rb) ;
812
+ let mut stmt = db. prepare ( "select a from arrow(?, ?)" ) ?;
813
+ let rb = stmt. query_arrow ( param) ?. next ( ) . expect ( "no record batch" ) ;
814
+
815
+ let output_any_array = rb. column ( 0 ) ;
816
+
817
+ assert ! (
818
+ output_any_array. data_type( ) . equals_datatype( arry_out. data_type( ) ) ,
819
+ "{} != {}" ,
820
+ output_any_array. data_type( ) ,
821
+ arry_out. data_type( )
822
+ ) ;
823
+
824
+ match output_any_array. as_bytes_opt :: < T2 > ( ) {
825
+ Some ( output_array) => {
826
+ assert_eq ! ( output_array. len( ) , arry_out. len( ) ) ;
827
+ for i in 0 ..output_array. len ( ) {
828
+ assert_eq ! ( output_array. is_valid( i) , arry_out. is_valid( i) ) ;
829
+ assert_eq ! ( output_array. value_data( ) , arry_out. value_data( ) )
830
+ }
831
+ }
832
+ None => panic ! ( "Expected GenericByteArray" ) ,
833
+ }
834
+
835
+ Ok ( ( ) )
836
+ }
837
+
787
838
#[ test]
788
839
fn test_array_roundtrip ( ) -> Result < ( ) , Box < dyn Error > > {
789
840
check_generic_array_roundtrip ( ListArray :: new (
@@ -862,6 +913,21 @@ mod test {
862
913
Ok ( ( ) )
863
914
}
864
915
916
#[test]
fn test_utf8_roundtrip() -> Result<(), Box<dyn Error>> {
    // Same three values are used for every input and expected array.
    let words = vec![Some("foo"), Some("Baz"), Some("bar")];

    // `StringArray` in, `StringArray` out.
    check_generic_byte_roundtrip(StringArray::from(words.clone()), StringArray::from(words.clone()))?;

    // A [`LargeStringArray`] input is expected to come back as a [`StringArray`].
    check_generic_byte_roundtrip(LargeStringArray::from(words.clone()), StringArray::from(words))
}
930
+
865
931
#[ test]
866
932
fn test_timestamp_roundtrip ( ) -> Result < ( ) , Box < dyn Error > > {
867
933
check_rust_primitive_array_roundtrip ( Int32Array :: from ( vec ! [ 1 , 2 , 3 ] ) , Int32Array :: from ( vec ! [ 1 , 2 , 3 ] ) ) ?;
0 commit comments