@@ -6,8 +6,8 @@ use arrow::{
6
6
} ;
7
7
8
8
use crate :: {
9
- core:: { DataChunkHandle , LogicalTypeId } ,
10
- vtab:: arrow:: { data_chunk_to_arrow, write_arrow_array_to_vector, WritableVector } ,
9
+ core:: DataChunkHandle ,
10
+ vtab:: arrow:: { data_chunk_to_arrow, to_duckdb_logical_type , write_arrow_array_to_vector, WritableVector } ,
11
11
} ;
12
12
13
13
use super :: { ScalarFunctionSignature , ScalarParams , VScalar } ;
@@ -35,14 +35,12 @@ impl From<ArrowScalarParams> for ScalarParams {
35
35
ArrowScalarParams :: Exact ( params) => Self :: Exact (
36
36
params
37
37
. into_iter ( )
38
- . map ( |v| LogicalTypeId :: try_from ( & v) . expect ( "type should be converted" ) . into ( ) )
38
+ . map ( |v| to_duckdb_logical_type ( & v) . expect ( "type should be converted" ) )
39
39
. collect ( ) ,
40
40
) ,
41
- ArrowScalarParams :: Variadic ( param) => Self :: Variadic (
42
- LogicalTypeId :: try_from ( & param)
43
- . expect ( "type should be converted" )
44
- . into ( ) ,
45
- ) ,
41
+ ArrowScalarParams :: Variadic ( param) => {
42
+ Self :: Variadic ( to_duckdb_logical_type ( & param) . expect ( "type should be converted" ) )
43
+ }
46
44
}
47
45
}
48
46
}
@@ -107,9 +105,7 @@ where
107
105
. into_iter ( )
108
106
. map ( |sig| ScalarFunctionSignature {
109
107
parameters : sig. parameters . map ( Into :: into) ,
110
- return_type : LogicalTypeId :: try_from ( & sig. return_type )
111
- . expect ( "type should be converted" )
112
- . into ( ) ,
108
+ return_type : to_duckdb_logical_type ( & sig. return_type ) . expect ( "type should be converted" ) ,
113
109
} )
114
110
. collect ( )
115
111
}
@@ -333,4 +329,58 @@ mod test {
333
329
334
330
Ok ( ( ) )
335
331
}
332
+
333
/// Registers an Arrow-backed scalar function whose return type is `List(Utf8)` and checks
/// that the list value survives the round trip through DuckDB for a single input string.
#[test]
fn test_split_function() -> Result<(), Box<dyn Error>> {
    struct SplitFunction {}

    impl VArrowScalar for SplitFunction {
        type State = ();

        /// Splits every input string on single spaces, producing one list of tokens per row.
        ///
        /// A NULL input row yields a NULL list entry, and an unexpected column type is
        /// reported as an error instead of a panic.
        fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
            let strings = input
                .column(0)
                .as_any()
                .downcast_ref::<StringArray>()
                .ok_or("split_string expects a Utf8 column as its first argument")?;

            // Capacity hints only: one item per row and ten bytes of string data per row.
            let mut builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::with_capacity(
                strings.len(),
                strings.len() * 10,
            ));

            for row in strings.iter() {
                match row {
                    Some(text) => {
                        // Feed the tokens straight into the child builder; no temporary Vec.
                        for token in text.split(' ') {
                            builder.values().append_value(token);
                        }
                        builder.append(true);
                    }
                    // Keep the output row count aligned with the input by emitting a NULL list.
                    None => builder.append(false),
                }
            }

            Ok(Arc::new(builder.finish()))
        }

        /// One signature: `(Utf8) -> List<Utf8>` with nullable items.
        fn signatures() -> Vec<ArrowFunctionSignature> {
            vec![ArrowFunctionSignature::exact(
                vec![DataType::Utf8],
                DataType::List(Arc::new(arrow::datatypes::Field::new("item", DataType::Utf8, true))),
            )]
        }
    }

    let conn = Connection::open_in_memory()?;
    conn.register_scalar_function::<SplitFunction>("split_string")?;

    // Test with single string
    let batches = conn
        .prepare("select split_string('hello world') as result")?
        .query_arrow([])?
        .collect::<Vec<_>>();

    let array = batches[0].column(0);
    let list_array = array.as_any().downcast_ref::<arrow::array::ListArray>().unwrap();
    assert_eq!(list_array.len(), 1);

    let values = list_array.value(0);
    let string_values = values.as_any().downcast_ref::<StringArray>().unwrap();

    // Pin the token count as well, so stray extra values cannot slip through.
    assert_eq!(string_values.len(), 2);
    assert_eq!(string_values.value(0), "hello");
    assert_eq!(string_values.value(1), "world");

    Ok(())
}
336
386
}
0 commit comments