Skip to content

Commit 2e638f1

Browse files
authored
Added struct type support in arrow feature (#279)
1 parent 8eaf5de commit 2e638f1

File tree

1 file changed

+95
-66
lines changed

1 file changed

+95
-66
lines changed

src/vtab/arrow.rs

Lines changed: 95 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use super::{
22
vector::{FlatVector, ListVector, Vector},
3-
BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, VTab,
3+
BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab,
44
};
55

66
use crate::vtab::vector::Inserter;
77
use arrow::array::{
8-
as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array, ArrayData,
9-
AsArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray,
10-
StringArray, StructArray,
8+
as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, as_struct_array, Array,
9+
ArrayData, AsArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait,
10+
PrimitiveArray, StringArray, StructArray,
1111
};
1212

1313
use arrow::{
@@ -181,24 +181,22 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<d
181181
Ok(LogicalType::new(to_duckdb_type_id(data_type)?))
182182
} else if let DataType::Dictionary(_, value_type) = data_type {
183183
to_duckdb_logical_type(value_type)
184-
// } else if let DataType::Struct(fields) = data_type {
185-
// let mut shape = vec![];
186-
// for field in fields.iter() {
187-
// shape.push((
188-
// field.name().as_str(),
189-
// to_duckdb_logical_type(field.data_type())?,
190-
// ));
191-
// }
192-
// Ok(LogicalType::struct_type(shape.as_slice()))
184+
} else if let DataType::Struct(fields) = data_type {
185+
let mut shape = vec![];
186+
for field in fields.iter() {
187+
shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?));
188+
}
189+
Ok(LogicalType::struct_type(shape.as_slice()))
193190
} else if let DataType::List(child) = data_type {
194191
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
195192
} else if let DataType::LargeList(child) = data_type {
196193
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
197194
} else if let DataType::FixedSizeList(child, _) = data_type {
198195
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
199196
} else {
200-
println!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs");
201-
todo!()
197+
unimplemented!(
198+
"Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs"
199+
)
202200
}
203201
}
204202

@@ -232,17 +230,16 @@ pub fn record_batch_to_duckdb_data_chunk(
232230
DataType::FixedSizeList(_, _) => {
233231
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i));
234232
}
235-
// DataType::Struct(_) => {
236-
// let struct_array = as_struct_array(col.as_ref());
237-
// let mut struct_vector = chunk.struct_vector(i);
238-
// struct_array_to_vector(struct_array, &mut struct_vector);
239-
// }
233+
DataType::Struct(_) => {
234+
let struct_array = as_struct_array(col.as_ref());
235+
let mut struct_vector = chunk.struct_vector(i);
236+
struct_array_to_vector(struct_array, &mut struct_vector);
237+
}
240238
_ => {
241-
println!(
239+
unimplemented!(
242240
"column {} is not supported yet, please file an issue https://github.com/wangfenjin/duckdb-rs",
243241
batch.schema().field(i)
244242
);
245-
todo!()
246243
}
247244
}
248245
}
@@ -458,46 +455,42 @@ fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray {
458455
arr.as_any().downcast_ref::<FixedSizeListArray>().unwrap()
459456
}
460457

461-
// fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) {
462-
// for i in 0..array.num_columns() {
463-
// let column = array.column(i);
464-
// match column.data_type() {
465-
// dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => {
466-
// primitive_array_to_vector(column, &mut out.child(i));
467-
// }
468-
// DataType::Utf8 => {
469-
// string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i));
470-
// }
471-
// DataType::List(_) => {
472-
// list_array_to_vector(
473-
// as_list_array(column.as_ref()),
474-
// &mut out.list_vector_child(i),
475-
// );
476-
// }
477-
// DataType::LargeList(_) => {
478-
// list_array_to_vector(
479-
// as_large_list_array(column.as_ref()),
480-
// &mut out.list_vector_child(i),
481-
// );
482-
// }
483-
// DataType::FixedSizeList(_, _) => {
484-
// fixed_size_list_array_to_vector(
485-
// as_fixed_size_list_array(column.as_ref()),
486-
// &mut out.list_vector_child(i),
487-
// );
488-
// }
489-
// DataType::Struct(_) => {
490-
// let struct_array = as_struct_array(column.as_ref());
491-
// let mut struct_vector = out.struct_vector_child(i);
492-
// struct_array_to_vector(struct_array, &mut struct_vector);
493-
// }
494-
// _ => {
495-
// println!("Unsupported data type: {}, please file an issue https://github.com/wangfenjin/duckdb-rs", column.data_type());
496-
// todo!()
497-
// }
498-
// }
499-
// }
500-
// }
458+
fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) {
459+
for i in 0..array.num_columns() {
460+
let column = array.column(i);
461+
match column.data_type() {
462+
dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => {
463+
primitive_array_to_vector(column, &mut out.child(i));
464+
}
465+
DataType::Utf8 => {
466+
string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i));
467+
}
468+
DataType::List(_) => {
469+
list_array_to_vector(as_list_array(column.as_ref()), &mut out.list_vector_child(i));
470+
}
471+
DataType::LargeList(_) => {
472+
list_array_to_vector(as_large_list_array(column.as_ref()), &mut out.list_vector_child(i));
473+
}
474+
DataType::FixedSizeList(_, _) => {
475+
fixed_size_list_array_to_vector(
476+
as_fixed_size_list_array(column.as_ref()),
477+
&mut out.list_vector_child(i),
478+
);
479+
}
480+
DataType::Struct(_) => {
481+
let struct_array = as_struct_array(column.as_ref());
482+
let mut struct_vector = out.struct_vector_child(i);
483+
struct_array_to_vector(struct_array, &mut struct_vector);
484+
}
485+
_ => {
486+
unimplemented!(
487+
"Unsupported data type: {}, please file an issue https://github.com/wangfenjin/duckdb-rs",
488+
column.data_type()
489+
);
490+
}
491+
}
492+
}
493+
}
501494

502495
/// Pass RecordBatch to duckdb.
503496
///
@@ -538,11 +531,11 @@ mod test {
538531
use crate::{Connection, Result};
539532
use arrow::{
540533
array::{
541-
Array, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray,
542-
Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
543-
TimestampNanosecondArray, TimestampSecondArray,
534+
Array, ArrayRef, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray,
535+
StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
536+
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
544537
},
545-
datatypes::{ArrowPrimitiveType, DataType, Field, Schema},
538+
datatypes::{ArrowPrimitiveType, DataType, Field, Fields, Schema},
546539
record_batch::RecordBatch,
547540
};
548541
use std::{error::Error, sync::Arc};
@@ -588,6 +581,42 @@ mod test {
588581
Ok(())
589582
}
590583

584+
#[test]
585+
fn test_append_struct() -> Result<(), Box<dyn Error>> {
586+
let db = Connection::open_in_memory()?;
587+
db.execute_batch("CREATE TABLE t1 (s STRUCT(v VARCHAR, i INTEGER))")?;
588+
{
589+
let struct_array = StructArray::from(vec![
590+
(
591+
Arc::new(Field::new("v", DataType::Utf8, true)),
592+
Arc::new(StringArray::from(vec![Some("foo"), Some("bar")])) as ArrayRef,
593+
),
594+
(
595+
Arc::new(Field::new("i", DataType::Int32, true)),
596+
Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef,
597+
),
598+
]);
599+
600+
let schema = Schema::new(vec![Field::new(
601+
"s",
602+
DataType::Struct(Fields::from(vec![
603+
Field::new("v", DataType::Utf8, true),
604+
Field::new("i", DataType::Int32, true),
605+
])),
606+
true,
607+
)]);
608+
609+
let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?;
610+
let mut app = db.appender("t1")?;
611+
app.append_record_batch(record_batch)?;
612+
}
613+
let mut stmt = db.prepare("SELECT s FROM t1")?;
614+
let rbs: Vec<RecordBatch> = stmt.query_arrow([])?.collect();
615+
assert_eq!(rbs.iter().map(|op| op.num_rows()).sum::<usize>(), 2);
616+
617+
Ok(())
618+
}
619+
591620
fn check_rust_primitive_array_roundtrip<T1, T2>(
592621
input_array: PrimitiveArray<T1>,
593622
expected_array: PrimitiveArray<T2>,

0 commit comments

Comments
 (0)