Skip to content

Commit f5a687c

Browse files
authored
Fix arrow list type to duckdb logical type conversion (#574)
When registering a scalar function which returns a list value, duckdb returns an error: ``` thread '<unnamed>' panicked at src/lib.rs:31:10: Failed to register split_words function: DuckDBFailure(Error { code: Unknown, extended_code: 1 }, None) ``` This PR fixes this issue by replacing `LogicalTypeId::try_from(..).into()` by `to_duckdb_logical_type`.
2 parents ef394fc + 8c56ecc commit f5a687c

File tree

1 file changed

+61
-11
lines changed

1 file changed

+61
-11
lines changed

crates/duckdb/src/vscalar/arrow.rs

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use arrow::{
66
};
77

88
use crate::{
9-
core::{DataChunkHandle, LogicalTypeId},
10-
vtab::arrow::{data_chunk_to_arrow, write_arrow_array_to_vector, WritableVector},
9+
core::DataChunkHandle,
10+
vtab::arrow::{data_chunk_to_arrow, to_duckdb_logical_type, write_arrow_array_to_vector, WritableVector},
1111
};
1212

1313
use super::{ScalarFunctionSignature, ScalarParams, VScalar};
@@ -35,14 +35,12 @@ impl From<ArrowScalarParams> for ScalarParams {
3535
ArrowScalarParams::Exact(params) => Self::Exact(
3636
params
3737
.into_iter()
38-
.map(|v| LogicalTypeId::try_from(&v).expect("type should be converted").into())
38+
.map(|v| to_duckdb_logical_type(&v).expect("type should be converted"))
3939
.collect(),
4040
),
41-
ArrowScalarParams::Variadic(param) => Self::Variadic(
42-
LogicalTypeId::try_from(&param)
43-
.expect("type should be converted")
44-
.into(),
45-
),
41+
ArrowScalarParams::Variadic(param) => {
42+
Self::Variadic(to_duckdb_logical_type(&param).expect("type should be converted"))
43+
}
4644
}
4745
}
4846
}
@@ -107,9 +105,7 @@ where
107105
.into_iter()
108106
.map(|sig| ScalarFunctionSignature {
109107
parameters: sig.parameters.map(Into::into),
110-
return_type: LogicalTypeId::try_from(&sig.return_type)
111-
.expect("type should be converted")
112-
.into(),
108+
return_type: to_duckdb_logical_type(&sig.return_type).expect("type should be converted"),
113109
})
114110
.collect()
115111
}
@@ -333,4 +329,58 @@ mod test {
333329

334330
Ok(())
335331
}
332+
333+
#[test]
334+
fn test_split_function() -> Result<(), Box<dyn Error>> {
335+
struct SplitFunction {}
336+
337+
impl VArrowScalar for SplitFunction {
338+
type State = ();
339+
340+
fn invoke(_: &Self::State, input: RecordBatch) -> Result<Arc<dyn Array>, Box<dyn std::error::Error>> {
341+
let strings = input.column(0).as_any().downcast_ref::<StringArray>().unwrap();
342+
343+
let mut builder = arrow::array::ListBuilder::new(arrow::array::StringBuilder::with_capacity(
344+
strings.len(),
345+
strings.len() * 10,
346+
));
347+
348+
for s in strings.iter() {
349+
let s = s.unwrap();
350+
for split_value in s.split(' ').collect::<Vec<_>>() {
351+
builder.values().append_value(split_value);
352+
}
353+
builder.append(true);
354+
}
355+
356+
Ok(Arc::new(builder.finish()))
357+
}
358+
359+
fn signatures() -> Vec<ArrowFunctionSignature> {
360+
vec![ArrowFunctionSignature::exact(
361+
vec![DataType::Utf8],
362+
DataType::List(Arc::new(arrow::datatypes::Field::new("item", DataType::Utf8, true))),
363+
)]
364+
}
365+
}
366+
367+
let conn = Connection::open_in_memory()?;
368+
conn.register_scalar_function::<SplitFunction>("split_string")?;
369+
370+
// Test with single string
371+
let batches = conn
372+
.prepare("select split_string('hello world') as result")?
373+
.query_arrow([])?
374+
.collect::<Vec<_>>();
375+
376+
let array = batches[0].column(0);
377+
let list_array = array.as_any().downcast_ref::<arrow::array::ListArray>().unwrap();
378+
let values = list_array.value(0);
379+
let string_values = values.as_any().downcast_ref::<StringArray>().unwrap();
380+
381+
assert_eq!(string_values.value(0), "hello");
382+
assert_eq!(string_values.value(1), "world");
383+
384+
Ok(())
385+
}
336386
}

0 commit comments

Comments
 (0)