|
1 | 1 | use super::{
|
2 |
| - vector::{FlatVector, ListVector, Vector}, |
| 2 | + vector::{ArrayVector, FlatVector, ListVector, Vector}, |
3 | 3 | BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab,
|
4 | 4 | };
|
5 | 5 | use std::ptr::null_mut;
|
@@ -196,8 +196,11 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<d
|
196 | 196 | Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
|
197 | 197 | } else if let DataType::LargeList(child) = data_type {
|
198 | 198 | Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
|
199 |
| - } else if let DataType::FixedSizeList(child, _) = data_type { |
200 |
| - Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?)) |
| 199 | + } else if let DataType::FixedSizeList(child, array_size) = data_type { |
| 200 | + Ok(LogicalType::array( |
| 201 | + &to_duckdb_logical_type(child.data_type())?, |
| 202 | + *array_size as u64, |
| 203 | + )) |
201 | 204 | } else {
|
202 | 205 | Err(
|
203 | 206 | format!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs")
|
@@ -234,7 +237,7 @@ pub fn record_batch_to_duckdb_data_chunk(
|
234 | 237 | list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
|
235 | 238 | }
|
236 | 239 | DataType::FixedSizeList(_, _) => {
|
237 |
| - fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i))?; |
| 240 | + fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.array_vector(i))?; |
238 | 241 | }
|
239 | 242 | DataType::Struct(_) => {
|
240 | 243 | let struct_array = as_struct_array(col.as_ref());
|
@@ -455,33 +458,21 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
|
455 | 458 |
|
456 | 459 | fn fixed_size_list_array_to_vector(
|
457 | 460 | array: &FixedSizeListArray,
|
458 |
| - out: &mut ListVector, |
| 461 | + out: &mut ArrayVector, |
459 | 462 | ) -> Result<(), Box<dyn std::error::Error>> {
|
460 | 463 | let value_array = array.values();
|
461 | 464 | let mut child = out.child(value_array.len());
|
462 | 465 | match value_array.data_type() {
|
463 | 466 | dt if dt.is_primitive() => {
|
464 | 467 | primitive_array_to_vector(value_array.as_ref(), &mut child)?;
|
465 |
| - for i in 0..array.len() { |
466 |
| - let offset = array.value_offset(i); |
467 |
| - let length = array.value_length(); |
468 |
| - out.set_entry(i, offset as usize, length as usize); |
469 |
| - } |
470 |
| - out.set_len(value_array.len()); |
471 | 468 | }
|
472 | 469 | DataType::Utf8 => {
|
473 | 470 | string_array_to_vector(as_string_array(value_array.as_ref()), &mut child);
|
474 | 471 | }
|
475 | 472 | _ => {
|
476 |
| - return Err("Nested list is not supported yet.".into()); |
| 473 | + return Err("Nested array is not supported yet.".into()); |
477 | 474 | }
|
478 | 475 | }
|
479 |
| - for i in 0..array.len() { |
480 |
| - let offset = array.value_offset(i); |
481 |
| - let length = array.value_length(); |
482 |
| - out.set_entry(i, offset as usize, length as usize); |
483 |
| - } |
484 |
| - out.set_len(value_array.len()); |
485 | 476 |
|
486 | 477 | Ok(())
|
487 | 478 | }
|
@@ -511,7 +502,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result
|
511 | 502 | DataType::FixedSizeList(_, _) => {
|
512 | 503 | fixed_size_list_array_to_vector(
|
513 | 504 | as_fixed_size_list_array(column.as_ref()),
|
514 |
| - &mut out.list_vector_child(i), |
| 505 | + &mut out.array_vector_child(i), |
515 | 506 | )?;
|
516 | 507 | }
|
517 | 508 | DataType::Struct(_) => {
|
@@ -569,10 +560,10 @@ mod test {
|
569 | 560 | use crate::{Connection, Result};
|
570 | 561 | use arrow::{
|
571 | 562 | array::{
|
572 |
| - Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, GenericListArray, |
573 |
| - Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, |
574 |
| - Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, |
575 |
| - TimestampSecondArray, |
| 563 | + Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, FixedSizeListArray, Float64Array, |
| 564 | + GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, |
| 565 | + Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, |
| 566 | + TimestampNanosecondArray, TimestampSecondArray, |
576 | 567 | },
|
577 | 568 | buffer::{OffsetBuffer, ScalarBuffer},
|
578 | 569 | datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
|
@@ -760,6 +751,50 @@ mod test {
|
760 | 751 | Ok(())
|
761 | 752 | }
|
762 | 753 |
|
| 754 | + //field: FieldRef, size: i32, values: ArrayRef, nulls: Option<NullBuffer> |
| 755 | + #[test] |
| 756 | + fn test_fixed_array_roundtrip() -> Result<(), Box<dyn Error>> { |
| 757 | + let array = FixedSizeListArray::new( |
| 758 | + Arc::new(Field::new("item", DataType::Int32, true)), |
| 759 | + 2, |
| 760 | + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)])), |
| 761 | + None, |
| 762 | + ); |
| 763 | + |
| 764 | + let expected_output_array = array.clone(); |
| 765 | + |
| 766 | + let db = Connection::open_in_memory()?; |
| 767 | + db.register_table_function::<ArrowVTab>("arrow")?; |
| 768 | + |
| 769 | + // Roundtrip a record batch from Rust to DuckDB and back to Rust |
| 770 | + let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]); |
| 771 | + |
| 772 | + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())])?; |
| 773 | + let param = arrow_recordbatch_to_query_params(rb); |
| 774 | + let mut stmt = db.prepare("select a from arrow(?, ?)")?; |
| 775 | + let rb = stmt.query_arrow(param)?.next().expect("no record batch"); |
| 776 | + |
| 777 | + let output_any_array = rb.column(0); |
| 778 | + assert!(output_any_array |
| 779 | + .data_type() |
| 780 | + .equals_datatype(expected_output_array.data_type())); |
| 781 | + |
| 782 | + match output_any_array.as_fixed_size_list_opt() { |
| 783 | + Some(output_array) => { |
| 784 | + assert_eq!(output_array.len(), expected_output_array.len()); |
| 785 | + for i in 0..output_array.len() { |
| 786 | + assert_eq!(output_array.is_valid(i), expected_output_array.is_valid(i)); |
| 787 | + if output_array.is_valid(i) { |
| 788 | + assert!(expected_output_array.value(i).eq(&output_array.value(i))); |
| 789 | + } |
| 790 | + } |
| 791 | + } |
| 792 | + None => panic!("Expected FixedSizeListArray"), |
| 793 | + } |
| 794 | + |
| 795 | + Ok(()) |
| 796 | + } |
| 797 | + |
763 | 798 | #[test]
|
764 | 799 | fn test_primitive_roundtrip_contains_nulls() -> Result<(), Box<dyn Error>> {
|
765 | 800 | let mut builder = arrow::array::PrimitiveBuilder::<arrow::datatypes::Int32Type>::new();
|
|
0 commit comments