diff --git a/crates/duckdb/src/vscalar/arrow.rs b/crates/duckdb/src/vscalar/arrow.rs
index 1b80089b..4edd4731 100644
--- a/crates/duckdb/src/vscalar/arrow.rs
+++ b/crates/duckdb/src/vscalar/arrow.rs
@@ -6,8 +6,8 @@ use arrow::{
 };
 
 use crate::{
-    core::{DataChunkHandle, LogicalTypeId},
-    vtab::arrow::{data_chunk_to_arrow, write_arrow_array_to_vector, WritableVector},
+    core::DataChunkHandle,
+    vtab::arrow::{data_chunk_to_arrow, to_duckdb_logical_type, write_arrow_array_to_vector, WritableVector},
 };
 
 use super::{ScalarFunctionSignature, ScalarParams, VScalar};
@@ -35,14 +35,12 @@ impl From<ArrowScalarParams> for ScalarParams {
             ArrowScalarParams::Exact(params) => Self::Exact(
                 params
                     .into_iter()
-                    .map(|v| LogicalTypeId::try_from(&v).expect("type should be converted").into())
+                    .map(|v| to_duckdb_logical_type(&v).expect("type should be converted"))
                     .collect(),
             ),
-            ArrowScalarParams::Variadic(param) => Self::Variadic(
-                LogicalTypeId::try_from(&param)
-                    .expect("type should be converted")
-                    .into(),
-            ),
+            ArrowScalarParams::Variadic(param) => {
+                Self::Variadic(to_duckdb_logical_type(&param).expect("type should be converted"))
+            }
         }
     }
 }
@@ -107,7 +105,7 @@ where
             .into_iter()
             .map(|sig| ScalarFunctionSignature {
                 parameters: sig.parameters.map(Into::into),
-                return_type: LogicalTypeId::try_from(&sig.return_type)
+                return_type: to_duckdb_logical_type(&sig.return_type)
                     .expect("type should be converted")
                     .into(),
             })
@@ -333,4 +331,58 @@ mod test {
 
         Ok(())
     }
+
+    #[test]
+    fn test_split_function() -> Result<(), Box<dyn std::error::Error>> {
+        struct SplitFunction {}
+
+        impl VArrowScalar for SplitFunction {
+            type State = ();
+
+            fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
+                let strings = input.column(0).as_any().downcast_ref::<StringArray>().unwrap();
+
+                let mut builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::with_capacity(
+                    strings.len(),
+                    strings.len() * 10,
+                ));
+
+                for s in strings.iter() {
+                    let s = s.unwrap();
+                    for split_value in s.split(' ').collect::<Vec<_>>() {
+                        builder.values().append_value(split_value);
+                    }
+                    builder.append(true);
+                }
+
+                Ok(Arc::new(builder.finish()))
+            }
+
+            fn signatures() -> Vec<ArrowFunctionSignature> {
+                vec![ArrowFunctionSignature::exact(
+                    vec![DataType::Utf8],
+                    DataType::List(Arc::new(arrow::datatypes::Field::new("item", DataType::Utf8, true))),
+                )]
+            }
+        }
+
+        let conn = Connection::open_in_memory()?;
+        conn.register_scalar_function::<SplitFunction>("split_string")?;
+
+        // Test with single string
+        let batches = conn
+            .prepare("select split_string('hello world') as result")?
+            .query_arrow([])?
+            .collect::<Vec<_>>();
+
+        let array = batches[0].column(0);
+        let list_array = array.as_any().downcast_ref::<arrow::array::ListArray>().unwrap();
+        let values = list_array.value(0);
+        let string_values = values.as_any().downcast_ref::<StringArray>().unwrap();
+
+        assert_eq!(string_values.value(0), "hello");
+        assert_eq!(string_values.value(1), "world");
+
+        Ok(())
+    }
 }