Skip to content

Commit 88dd455

Browse files
authored
Support Arrow type LargeUtf8. (#341)
* support LargeUtf8 * lint * fix tests * Fix tests check_generic_byte_roundtrip * fix test * fix clippy
1 parent 1c5e7cd commit 88dd455

File tree

3 files changed

+76
-13
lines changed

3 files changed

+76
-13
lines changed

crates/duckdb/src/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl From<::std::ffi::NulError> for Error {
122122
}
123123
}
124124

125-
const UNKNOWN_COLUMN: usize = std::usize::MAX;
125+
const UNKNOWN_COLUMN: usize = usize::MAX;
126126

127127
/// The conversion isn't precise, but it's convenient to have it
128128
/// to allow use of `get_raw(…).as_…()?` in callbacks that take `Error`.

crates/duckdb/src/types/mod.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,7 @@ impl fmt::Display for Type {
261261
mod test {
262262
use super::Value;
263263
use crate::{params, Connection, Error, Result, Statement};
264-
use std::{
265-
f64::EPSILON,
266-
os::raw::{c_double, c_int},
267-
};
264+
use std::os::raw::{c_double, c_int};
268265

269266
fn checked_memory_handle() -> Result<Connection> {
270267
let db = Connection::open_in_memory()?;
@@ -385,7 +382,7 @@ mod test {
385382
assert_eq!(vec![1, 2], row.get::<_, Vec<u8>>(0)?);
386383
assert_eq!("text", row.get::<_, String>(1)?);
387384
assert_eq!(1, row.get::<_, c_int>(2)?);
388-
assert!((1.5 - row.get::<_, c_double>(3)?).abs() < EPSILON);
385+
assert!((1.5 - row.get::<_, c_double>(3)?).abs() < f64::EPSILON);
389386
assert_eq!(row.get::<_, Option<c_int>>(4)?, None);
390387
assert_eq!(row.get::<_, Option<c_double>>(4)?, None);
391388
assert_eq!(row.get::<_, Option<String>>(4)?, None);
@@ -453,7 +450,7 @@ mod test {
453450
assert_eq!(Value::Text(String::from("text")), row.get::<_, Value>(1)?);
454451
assert_eq!(Value::Int(1), row.get::<_, Value>(2)?);
455452
match row.get::<_, Value>(3)? {
456-
Value::Float(val) => assert!((1.5 - val).abs() < EPSILON as f32),
453+
Value::Float(val) => assert!((1.5 - val).abs() < f32::EPSILON),
457454
x => panic!("Invalid Value {x:?}"),
458455
}
459456
assert_eq!(Value::Null, row.get::<_, Value>(4)?);

crates/duckdb/src/vtab/arrow.rs

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::vtab::vector::Inserter;
88
use arrow::array::{
99
as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array,
1010
as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Decimal128Array, FixedSizeListArray,
11-
GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray,
11+
GenericListArray, GenericStringArray, LargeStringArray, OffsetSizeTrait, PrimitiveArray, StructArray,
1212
};
1313

1414
use arrow::{
@@ -229,6 +229,15 @@ pub fn record_batch_to_duckdb_data_chunk(
229229
DataType::Utf8 => {
230230
string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i));
231231
}
232+
DataType::LargeUtf8 => {
233+
string_array_to_vector(
234+
col.as_ref()
235+
.as_any()
236+
.downcast_ref::<LargeStringArray>()
237+
.ok_or_else(|| Box::<dyn std::error::Error>::from("Unable to downcast to LargeStringArray"))?,
238+
&mut chunk.flat_vector(i),
239+
);
240+
}
232241
DataType::Binary => {
233242
binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector(i));
234243
}
@@ -453,7 +462,7 @@ fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) {
453462
}
454463
}
455464

456-
fn string_array_to_vector(array: &StringArray, out: &mut FlatVector) {
465+
fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out: &mut FlatVector) {
457466
assert!(array.len() <= out.capacity());
458467

459468
// TODO: zero copy assignment
@@ -612,12 +621,12 @@ mod test {
612621
use arrow::{
613622
array::{
614623
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
615-
FixedSizeListArray, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
616-
StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
617-
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
624+
FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array, LargeStringArray, ListArray,
625+
OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
626+
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
618627
},
619628
buffer::{OffsetBuffer, ScalarBuffer},
620-
datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
629+
datatypes::{i256, ArrowPrimitiveType, ByteArrayType, DataType, Field, Fields, Schema},
621630
record_batch::RecordBatch,
622631
};
623632
use std::{error::Error, sync::Arc};
@@ -784,6 +793,48 @@ mod test {
784793
Ok(())
785794
}
786795

796+
fn check_generic_byte_roundtrip<T1, T2>(
797+
arry_in: GenericByteArray<T1>,
798+
arry_out: GenericByteArray<T2>,
799+
) -> Result<(), Box<dyn Error>>
800+
where
801+
T1: ByteArrayType,
802+
T2: ByteArrayType,
803+
{
804+
let db = Connection::open_in_memory()?;
805+
db.register_table_function::<ArrowVTab>("arrow")?;
806+
807+
// Roundtrip a record batch from Rust to DuckDB and back to Rust
808+
let schema = Schema::new(vec![Field::new("a", arry_in.data_type().clone(), false)]);
809+
810+
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry_in.clone())])?;
811+
let param = arrow_recordbatch_to_query_params(rb);
812+
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
813+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
814+
815+
let output_any_array = rb.column(0);
816+
817+
assert!(
818+
output_any_array.data_type().equals_datatype(arry_out.data_type()),
819+
"{} != {}",
820+
output_any_array.data_type(),
821+
arry_out.data_type()
822+
);
823+
824+
match output_any_array.as_bytes_opt::<T2>() {
825+
Some(output_array) => {
826+
assert_eq!(output_array.len(), arry_out.len());
827+
for i in 0..output_array.len() {
828+
assert_eq!(output_array.is_valid(i), arry_out.is_valid(i));
829+
assert_eq!(output_array.value_data(), arry_out.value_data())
830+
}
831+
}
832+
None => panic!("Expected GenericByteArray"),
833+
}
834+
835+
Ok(())
836+
}
837+
787838
#[test]
788839
fn test_array_roundtrip() -> Result<(), Box<dyn Error>> {
789840
check_generic_array_roundtrip(ListArray::new(
@@ -862,6 +913,21 @@ mod test {
862913
Ok(())
863914
}
864915

916+
#[test]
917+
fn test_utf8_roundtrip() -> Result<(), Box<dyn Error>> {
918+
check_generic_byte_roundtrip(
919+
StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
920+
StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
921+
)?;
922+
923+
// [`LargeStringArray`] will be downcasted to [`StringArray`].
924+
check_generic_byte_roundtrip(
925+
LargeStringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
926+
StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
927+
)?;
928+
Ok(())
929+
}
930+
865931
#[test]
866932
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
867933
check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;

0 commit comments

Comments
 (0)