Skip to content

Commit f628e5a

Browse files
JeadieMaxxen
andauthored
Add support for DuckDB arrays when using Arrow's FixedSizeList (#323)
* support UTF8[] * add tests * fix test * format * clippy * bump cause github is broken * add support for DuckDB arrays when using Arrow's FixedSizeList * fmt * add ArrayVector * update path in remote test --------- Co-authored-by: Max Gabrielsson <[email protected]>
1 parent 74fce0f commit f628e5a

File tree

5 files changed

+117
-25
lines changed

5 files changed

+117
-25
lines changed

crates/duckdb/src/extension.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ mod test {
3838
let db = Connection::open_in_memory()?;
3939
assert_eq!(
4040
300f32,
41-
db.query_row::<f32, _, _>(r#"SELECT SUM(value) FROM read_parquet('https://github.com/wangfenjin/duckdb-rs/raw/main/examples/int32_decimal.parquet');"#, [], |r| r.get(0))?
41+
db.query_row::<f32, _, _>(r#"SELECT SUM(value) FROM read_parquet('https://github.com/duckdb/duckdb-rs/raw/main/crates/duckdb/examples/int32_decimal.parquet');"#, [], |r| r.get(0))?
4242
);
4343
Ok(())
4444
}

crates/duckdb/src/vtab/arrow.rs

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::{
2-
vector::{FlatVector, ListVector, Vector},
2+
vector::{ArrayVector, FlatVector, ListVector, Vector},
33
BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab,
44
};
55
use std::ptr::null_mut;
@@ -196,8 +196,11 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<d
196196
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
197197
} else if let DataType::LargeList(child) = data_type {
198198
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
199-
} else if let DataType::FixedSizeList(child, _) = data_type {
200-
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
199+
} else if let DataType::FixedSizeList(child, array_size) = data_type {
200+
Ok(LogicalType::array(
201+
&to_duckdb_logical_type(child.data_type())?,
202+
*array_size as u64,
203+
))
201204
} else {
202205
Err(
203206
format!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs")
@@ -234,7 +237,7 @@ pub fn record_batch_to_duckdb_data_chunk(
234237
list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
235238
}
236239
DataType::FixedSizeList(_, _) => {
237-
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
240+
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.array_vector(i))?;
238241
}
239242
DataType::Struct(_) => {
240243
let struct_array = as_struct_array(col.as_ref());
@@ -455,33 +458,21 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
455458

456459
fn fixed_size_list_array_to_vector(
457460
array: &FixedSizeListArray,
458-
out: &mut ListVector,
461+
out: &mut ArrayVector,
459462
) -> Result<(), Box<dyn std::error::Error>> {
460463
let value_array = array.values();
461464
let mut child = out.child(value_array.len());
462465
match value_array.data_type() {
463466
dt if dt.is_primitive() => {
464467
primitive_array_to_vector(value_array.as_ref(), &mut child)?;
465-
for i in 0..array.len() {
466-
let offset = array.value_offset(i);
467-
let length = array.value_length();
468-
out.set_entry(i, offset as usize, length as usize);
469-
}
470-
out.set_len(value_array.len());
471468
}
472469
DataType::Utf8 => {
473470
string_array_to_vector(as_string_array(value_array.as_ref()), &mut child);
474471
}
475472
_ => {
476-
return Err("Nested list is not supported yet.".into());
473+
return Err("Nested array is not supported yet.".into());
477474
}
478475
}
479-
for i in 0..array.len() {
480-
let offset = array.value_offset(i);
481-
let length = array.value_length();
482-
out.set_entry(i, offset as usize, length as usize);
483-
}
484-
out.set_len(value_array.len());
485476

486477
Ok(())
487478
}
@@ -511,7 +502,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result
511502
DataType::FixedSizeList(_, _) => {
512503
fixed_size_list_array_to_vector(
513504
as_fixed_size_list_array(column.as_ref()),
514-
&mut out.list_vector_child(i),
505+
&mut out.array_vector_child(i),
515506
)?;
516507
}
517508
DataType::Struct(_) => {
@@ -569,10 +560,10 @@ mod test {
569560
use crate::{Connection, Result};
570561
use arrow::{
571562
array::{
572-
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, GenericListArray,
573-
Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray,
574-
Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
575-
TimestampSecondArray,
563+
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, FixedSizeListArray, Float64Array,
564+
GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray,
565+
Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
566+
TimestampNanosecondArray, TimestampSecondArray,
576567
},
577568
buffer::{OffsetBuffer, ScalarBuffer},
578569
datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
@@ -760,6 +751,50 @@ mod test {
760751
Ok(())
761752
}
762753

754+
//field: FieldRef, size: i32, values: ArrayRef, nulls: Option<NullBuffer>
755+
#[test]
756+
fn test_fixed_array_roundtrip() -> Result<(), Box<dyn Error>> {
757+
let array = FixedSizeListArray::new(
758+
Arc::new(Field::new("item", DataType::Int32, true)),
759+
2,
760+
Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)])),
761+
None,
762+
);
763+
764+
let expected_output_array = array.clone();
765+
766+
let db = Connection::open_in_memory()?;
767+
db.register_table_function::<ArrowVTab>("arrow")?;
768+
769+
// Roundtrip a record batch from Rust to DuckDB and back to Rust
770+
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]);
771+
772+
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())])?;
773+
let param = arrow_recordbatch_to_query_params(rb);
774+
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
775+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
776+
777+
let output_any_array = rb.column(0);
778+
assert!(output_any_array
779+
.data_type()
780+
.equals_datatype(expected_output_array.data_type()));
781+
782+
match output_any_array.as_fixed_size_list_opt() {
783+
Some(output_array) => {
784+
assert_eq!(output_array.len(), expected_output_array.len());
785+
for i in 0..output_array.len() {
786+
assert_eq!(output_array.is_valid(i), expected_output_array.is_valid(i));
787+
if output_array.is_valid(i) {
788+
assert!(expected_output_array.value(i).eq(&output_array.value(i)));
789+
}
790+
}
791+
}
792+
None => panic!("Expected FixedSizeListArray"),
793+
}
794+
795+
Ok(())
796+
}
797+
763798
#[test]
764799
fn test_primitive_roundtrip_contains_nulls() -> Result<(), Box<dyn Error>> {
765800
let mut builder = arrow::array::PrimitiveBuilder::<arrow::datatypes::Int32Type>::new();

crates/duckdb/src/vtab/data_chunk.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::{
22
logical_type::LogicalType,
3-
vector::{FlatVector, ListVector, StructVector},
3+
vector::{ArrayVector, FlatVector, ListVector, StructVector},
44
};
55
use crate::ffi::{
66
duckdb_create_data_chunk, duckdb_data_chunk, duckdb_data_chunk_get_column_count, duckdb_data_chunk_get_size,
@@ -35,6 +35,11 @@ impl DataChunk {
3535
ListVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
3636
}
3737

38+
/// Get a array vector from the column index.
39+
pub fn array_vector(&self, idx: usize) -> ArrayVector {
40+
ArrayVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })
41+
}
42+
3843
/// Get struct vector at the column index: `idx`.
3944
pub fn struct_vector(&self, idx: usize) -> StructVector {
4045
StructVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) })

crates/duckdb/src/vtab/logical_type.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,15 @@ impl LogicalType {
182182
}
183183
}
184184

185+
/// Creates an array type from its child type.
186+
pub fn array(child_type: &LogicalType, array_size: u64) -> Self {
187+
unsafe {
188+
Self {
189+
ptr: duckdb_create_array_type(child_type.ptr, array_size),
190+
}
191+
}
192+
}
193+
185194
/// Creates a decimal type from its `width` and `scale`.
186195
pub fn decimal(width: u8, scale: u8) -> Self {
187196
unsafe {

crates/duckdb/src/vtab/vector.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use std::{any::Any, ffi::CString, slice};
22

3+
use libduckdb_sys::{duckdb_array_type_array_size, duckdb_array_vector_get_child};
4+
35
use super::LogicalType;
46
use crate::ffi::{
57
duckdb_list_entry, duckdb_list_vector_get_child, duckdb_list_vector_get_size, duckdb_list_vector_reserve,
@@ -170,6 +172,42 @@ impl ListVector {
170172
}
171173
}
172174

175+
/// A array vector. (fixed-size list)
176+
pub struct ArrayVector {
177+
/// ArrayVector does not own the vector pointer.
178+
ptr: duckdb_vector,
179+
}
180+
181+
impl From<duckdb_vector> for ArrayVector {
182+
fn from(ptr: duckdb_vector) -> Self {
183+
Self { ptr }
184+
}
185+
}
186+
187+
impl ArrayVector {
188+
/// Get the logical type of this ArrayVector.
189+
pub fn logical_type(&self) -> LogicalType {
190+
LogicalType::from(unsafe { duckdb_vector_get_column_type(self.ptr) })
191+
}
192+
193+
pub fn get_array_size(&self) -> u64 {
194+
let ty = self.logical_type();
195+
unsafe { duckdb_array_type_array_size(ty.ptr) as u64 }
196+
}
197+
198+
/// Returns the child vector.
199+
/// capacity should be a multiple of the array size.
200+
// TODO: not ideal interface. Where should we keep count.
201+
pub fn child(&self, capacity: usize) -> FlatVector {
202+
FlatVector::with_capacity(unsafe { duckdb_array_vector_get_child(self.ptr) }, capacity)
203+
}
204+
205+
/// Set primitive data to the child node.
206+
pub fn set_child<T: Copy>(&self, data: &[T]) {
207+
self.child(data.len()).copy(data);
208+
}
209+
}
210+
173211
/// A struct vector.
174212
pub struct StructVector {
175213
/// ListVector does not own the vector pointer.
@@ -198,6 +236,11 @@ impl StructVector {
198236
ListVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) })
199237
}
200238

239+
/// Take the child as [ArrayVector].
240+
pub fn array_vector_child(&self, idx: usize) -> ArrayVector {
241+
ArrayVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) })
242+
}
243+
201244
/// Get the logical type of this struct vector.
202245
pub fn logical_type(&self) -> LogicalType {
203246
LogicalType::from(unsafe { duckdb_vector_get_column_type(self.ptr) })

0 commit comments

Comments
 (0)