@@ -6,8 +6,8 @@ use arrow::{
66} ;
77
88use crate :: {
9- core:: { DataChunkHandle , LogicalTypeId } ,
10- vtab:: arrow:: { data_chunk_to_arrow, write_arrow_array_to_vector, WritableVector } ,
9+ core:: DataChunkHandle ,
10+ vtab:: arrow:: { data_chunk_to_arrow, to_duckdb_logical_type , write_arrow_array_to_vector, WritableVector } ,
1111} ;
1212
1313use super :: { ScalarFunctionSignature , ScalarParams , VScalar } ;
@@ -35,14 +35,12 @@ impl From<ArrowScalarParams> for ScalarParams {
3535 ArrowScalarParams :: Exact ( params) => Self :: Exact (
3636 params
3737 . into_iter ( )
38- . map ( |v| LogicalTypeId :: try_from ( & v) . expect ( "type should be converted" ) . into ( ) )
38+ . map ( |v| to_duckdb_logical_type ( & v) . expect ( "type should be converted" ) )
3939 . collect ( ) ,
4040 ) ,
41- ArrowScalarParams :: Variadic ( param) => Self :: Variadic (
42- LogicalTypeId :: try_from ( & param)
43- . expect ( "type should be converted" )
44- . into ( ) ,
45- ) ,
41+ ArrowScalarParams :: Variadic ( param) => {
42+ Self :: Variadic ( to_duckdb_logical_type ( & param) . expect ( "type should be converted" ) )
43+ }
4644 }
4745 }
4846}
@@ -107,9 +105,7 @@ where
107105 . into_iter ( )
108106 . map ( |sig| ScalarFunctionSignature {
109107 parameters : sig. parameters . map ( Into :: into) ,
110- return_type : LogicalTypeId :: try_from ( & sig. return_type )
111- . expect ( "type should be converted" )
112- . into ( ) ,
108+ return_type : to_duckdb_logical_type ( & sig. return_type ) . expect ( "type should be converted" ) ,
113109 } )
114110 . collect ( )
115111 }
@@ -333,4 +329,58 @@ mod test {
333329
334330 Ok ( ( ) )
335331 }
332+
333+ #[ test]
334+ fn test_split_function ( ) -> Result < ( ) , Box < dyn Error > > {
335+ struct SplitFunction { }
336+
337+ impl VArrowScalar for SplitFunction {
338+ type State = ( ) ;
339+
340+ fn invoke ( _: & Self :: State , input : RecordBatch ) -> Result < Arc < dyn Array > , Box < dyn std:: error:: Error > > {
341+ let strings = input. column ( 0 ) . as_any ( ) . downcast_ref :: < StringArray > ( ) . unwrap ( ) ;
342+
343+ let mut builder = arrow:: array:: ListBuilder :: new ( arrow:: array:: StringBuilder :: with_capacity (
344+ strings. len ( ) ,
345+ strings. len ( ) * 10 ,
346+ ) ) ;
347+
348+ for s in strings. iter ( ) {
349+ let s = s. unwrap ( ) ;
350+ for split_value in s. split ( ' ' ) . collect :: < Vec < _ > > ( ) {
351+ builder. values ( ) . append_value ( split_value) ;
352+ }
353+ builder. append ( true ) ;
354+ }
355+
356+ Ok ( Arc :: new ( builder. finish ( ) ) )
357+ }
358+
359+ fn signatures ( ) -> Vec < ArrowFunctionSignature > {
360+ vec ! [ ArrowFunctionSignature :: exact(
361+ vec![ DataType :: Utf8 ] ,
362+ DataType :: List ( Arc :: new( arrow:: datatypes:: Field :: new( "item" , DataType :: Utf8 , true ) ) ) ,
363+ ) ]
364+ }
365+ }
366+
367+ let conn = Connection :: open_in_memory ( ) ?;
368+ conn. register_scalar_function :: < SplitFunction > ( "split_string" ) ?;
369+
370+ // Test with single string
371+ let batches = conn
372+ . prepare ( "select split_string('hello world') as result" ) ?
373+ . query_arrow ( [ ] ) ?
374+ . collect :: < Vec < _ > > ( ) ;
375+
376+ let array = batches[ 0 ] . column ( 0 ) ;
377+ let list_array = array. as_any ( ) . downcast_ref :: < arrow:: array:: ListArray > ( ) . unwrap ( ) ;
378+ let values = list_array. value ( 0 ) ;
379+ let string_values = values. as_any ( ) . downcast_ref :: < StringArray > ( ) . unwrap ( ) ;
380+
381+ assert_eq ! ( string_values. value( 0 ) , "hello" ) ;
382+ assert_eq ! ( string_values. value( 1 ) , "world" ) ;
383+
384+ Ok ( ( ) )
385+ }
336386}
0 commit comments