Skip to content

Commit 244a1c0

Browse files
committed
add unit test
Signed-off-by: Runji Wang <[email protected]>
1 parent 589ca61 commit 244a1c0

File tree

1 file changed

+55
-1
lines changed

1 file changed

+55
-1
lines changed

crates/duckdb/src/vscalar/arrow.rs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use arrow::{
66
};
77

88
use crate::{
9-
core::{DataChunkHandle, LogicalTypeId},
9+
core::DataChunkHandle,
1010
vtab::arrow::{data_chunk_to_arrow, to_duckdb_logical_type, write_arrow_array_to_vector, WritableVector},
1111
};
1212

@@ -331,4 +331,58 @@ mod test {
331331

332332
Ok(())
333333
}
334+
335+
#[test]
336+
fn test_split_function() -> Result<(), Box<dyn Error>> {
337+
struct SplitFunction {}
338+
339+
impl VArrowScalar for SplitFunction {
340+
type State = ();
341+
342+
fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
343+
let strings = input.column(0).as_any().downcast_ref::<StringArray>().unwrap();
344+
345+
let mut builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::with_capacity(
346+
strings.len(),
347+
strings.len() * 10,
348+
));
349+
350+
for s in strings.iter() {
351+
let s = s.unwrap();
352+
for split_value in s.split(' ').collect::<Vec<_>>() {
353+
builder.values().append_value(split_value);
354+
}
355+
builder.append(true);
356+
}
357+
358+
Ok(Arc::new(builder.finish()))
359+
}
360+
361+
fn signatures() -> Vec<ArrowFunctionSignature> {
362+
vec![ArrowFunctionSignature::exact(
363+
vec![DataType::Utf8],
364+
DataType::List(Arc::new(arrow::datatypes::Field::new("item", DataType::Utf8, true))),
365+
)]
366+
}
367+
}
368+
369+
let conn = Connection::open_in_memory()?;
370+
conn.register_scalar_function::<SplitFunction>("split_string")?;
371+
372+
// Test with single string
373+
let batches = conn
374+
.prepare("select split_string('hello world') as result")?
375+
.query_arrow([])?
376+
.collect::<Vec<_>>();
377+
378+
let array = batches[0].column(0);
379+
let list_array = array.as_any().downcast_ref::<arrow::array::ListArray>().unwrap();
380+
let values = list_array.value(0);
381+
let string_values = values.as_any().downcast_ref::<StringArray>().unwrap();
382+
383+
assert_eq!(string_values.value(0), "hello");
384+
assert_eq!(string_values.value(1), "world");
385+
386+
Ok(())
387+
}
334388
}

0 commit comments

Comments
 (0)