Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 61 additions & 11 deletions crates/duckdb/src/vscalar/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use arrow::{
};

use crate::{
core::{DataChunkHandle, LogicalTypeId},
vtab::arrow::{data_chunk_to_arrow, write_arrow_array_to_vector, WritableVector},
core::DataChunkHandle,
vtab::arrow::{data_chunk_to_arrow, to_duckdb_logical_type, write_arrow_array_to_vector, WritableVector},
};

use super::{ScalarFunctionSignature, ScalarParams, VScalar};
Expand Down Expand Up @@ -35,14 +35,12 @@ impl From<ArrowScalarParams> for ScalarParams {
ArrowScalarParams::Exact(params) => Self::Exact(
params
.into_iter()
.map(|v| LogicalTypeId::try_from(&v).expect("type should be converted").into())
.map(|v| to_duckdb_logical_type(&v).expect("type should be converted"))
.collect(),
),
ArrowScalarParams::Variadic(param) => Self::Variadic(
LogicalTypeId::try_from(&param)
.expect("type should be converted")
.into(),
),
ArrowScalarParams::Variadic(param) => {
Self::Variadic(to_duckdb_logical_type(&param).expect("type should be converted"))
}
}
}
}
Expand Down Expand Up @@ -107,9 +105,7 @@ where
.into_iter()
.map(|sig| ScalarFunctionSignature {
parameters: sig.parameters.map(Into::into),
return_type: LogicalTypeId::try_from(&sig.return_type)
.expect("type should be converted")
.into(),
return_type: to_duckdb_logical_type(&sig.return_type).expect("type should be converted"),
})
.collect()
}
Expand Down Expand Up @@ -333,4 +329,58 @@ mod test {

Ok(())
}

/// End-to-end check that an Arrow scalar function may return a nested type:
/// `split_string` takes a `Utf8` argument and returns a `List<Utf8>` holding
/// the space-separated pieces of the input.
#[test]
fn test_split_function() -> Result<(), Box<dyn Error>> {
    struct SplitFunction {}

    impl VArrowScalar for SplitFunction {
        type State = ();

        /// Splits every string of the first input column on single spaces and
        /// returns one list entry per input row.
        fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
            let strings = input
                .column(0)
                .as_any()
                .downcast_ref::<StringArray>()
                .expect("signature declares a single Utf8 argument");

            let mut builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::with_capacity(
                strings.len(),
                strings.len() * 10,
            ));

            for s in strings.iter() {
                let s = s.expect("test input contains no NULL strings");
                // Feed the pieces straight into the child builder; there is no
                // need to materialise them in an intermediate Vec first.
                for split_value in s.split(' ') {
                    builder.values().append_value(split_value);
                }
                // Close the list for this row (valid, i.e. non-NULL).
                builder.append(true);
            }

            Ok(Arc::new(builder.finish()))
        }

        fn signatures() -> Vec<ArrowFunctionSignature> {
            vec![ArrowFunctionSignature::exact(
                vec![DataType::Utf8],
                DataType::List(Arc::new(arrow::datatypes::Field::new("item", DataType::Utf8, true))),
            )]
        }
    }

    let conn = Connection::open_in_memory()?;
    conn.register_scalar_function::<SplitFunction>("split_string")?;

    // Test with single string
    let batches = conn
        .prepare("select split_string('hello world') as result")?
        .query_arrow([])?
        .collect::<Vec<_>>();

    let array = batches[0].column(0);
    let list_array = array
        .as_any()
        .downcast_ref::<arrow::array::ListArray>()
        .expect("result column should be a list array");
    // One input row must produce exactly one list.
    assert_eq!(list_array.len(), 1);

    let values = list_array.value(0);
    let string_values = values
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("list items should be strings");

    // Pin the element count as well, so stray extra tokens fail the test.
    assert_eq!(string_values.len(), 2);
    assert_eq!(string_values.value(0), "hello");
    assert_eq!(string_values.value(1), "world");

    Ok(())
}
}
Loading