Skip to content

Commit 6030134

Browse files
Add Utf8View & BinaryView to supported Arrow vtab types (#397)
* Add Utf8View & BinaryView to supported Arrow vtab types (#10) * Add Utf8View & BinaryView to supported Arrow vtab types * Fix implementation * Support lists * Add tests * Fix linting issues
1 parent 4fdffea commit 6030134

File tree

1 file changed

+187
-9
lines changed

1 file changed

+187
-9
lines changed

crates/duckdb/src/vtab/arrow.rs

Lines changed: 187 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ use crate::core::{ArrayVector, FlatVector, Inserter, ListVector, StructVector, V
55
use arrow::{
66
array::{
77
as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array,
8-
as_string_array, as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Decimal128Array,
9-
FixedSizeBinaryArray, FixedSizeListArray, GenericListArray, GenericStringArray, IntervalMonthDayNanoArray,
10-
LargeBinaryArray, LargeStringArray, OffsetSizeTrait, PrimitiveArray, StructArray,
8+
as_string_array, as_struct_array, Array, ArrayData, AsArray, BinaryArray, BinaryViewArray, BooleanArray,
9+
Decimal128Array, FixedSizeBinaryArray, FixedSizeListArray, GenericListArray, GenericStringArray,
10+
IntervalMonthDayNanoArray, LargeBinaryArray, LargeStringArray, OffsetSizeTrait, PrimitiveArray,
11+
StringViewArray, StructArray,
1112
},
1213
compute::cast,
1314
};
@@ -157,8 +158,8 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
157158
DataType::Time64(_) => Time,
158159
DataType::Duration(_) => Interval,
159160
DataType::Interval(_) => Interval,
160-
DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => Blob,
161-
DataType::Utf8 | DataType::LargeUtf8 => Varchar,
161+
DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) | DataType::BinaryView => Blob,
162+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Varchar,
162163
DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => List,
163164
DataType::Struct(_) => Struct,
164165
DataType::Union(_, _) => Union,
@@ -201,8 +202,10 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalTypeHandle,
201202
DataType::Boolean
202203
| DataType::Utf8
203204
| DataType::LargeUtf8
205+
| DataType::Utf8View
204206
| DataType::Binary
205207
| DataType::LargeBinary
208+
| DataType::BinaryView
206209
| DataType::FixedSizeBinary(_) => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type)?)),
207210
dtype if dtype.is_primitive() => Ok(LogicalTypeHandle::from(to_duckdb_type_id(data_type)?)),
208211
_ => Err(format!(
@@ -242,6 +245,15 @@ pub fn record_batch_to_duckdb_data_chunk(
242245
&mut chunk.flat_vector(i),
243246
);
244247
}
248+
DataType::Utf8View => {
249+
string_view_array_to_vector(
250+
col.as_ref()
251+
.as_any()
252+
.downcast_ref::<StringViewArray>()
253+
.ok_or_else(|| Box::<dyn std::error::Error>::from("Unable to downcast to StringViewArray"))?,
254+
&mut chunk.flat_vector(i),
255+
);
256+
}
245257
DataType::Binary => {
246258
binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector(i));
247259
}
@@ -257,6 +269,15 @@ pub fn record_batch_to_duckdb_data_chunk(
257269
&mut chunk.flat_vector(i),
258270
);
259271
}
272+
DataType::BinaryView => {
273+
binary_view_array_to_vector(
274+
col.as_ref()
275+
.as_any()
276+
.downcast_ref::<BinaryViewArray>()
277+
.ok_or_else(|| Box::<dyn std::error::Error>::from("Unable to downcast to BinaryViewArray"))?,
278+
&mut chunk.flat_vector(i),
279+
);
280+
}
260281
DataType::List(_) => {
261282
list_array_to_vector(as_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
262283
}
@@ -486,6 +507,16 @@ fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out
486507
set_nulls_in_flat_vector(array, out);
487508
}
488509

510+
fn string_view_array_to_vector(array: &StringViewArray, out: &mut FlatVector) {
511+
assert!(array.len() <= out.capacity());
512+
513+
for i in 0..array.len() {
514+
let s = array.value(i);
515+
out.insert(i, s);
516+
}
517+
set_nulls_in_flat_vector(array, out);
518+
}
519+
489520
fn binary_array_to_vector(array: &BinaryArray, out: &mut FlatVector) {
490521
assert!(array.len() <= out.capacity());
491522

@@ -496,6 +527,16 @@ fn binary_array_to_vector(array: &BinaryArray, out: &mut FlatVector) {
496527
set_nulls_in_flat_vector(array, out);
497528
}
498529

530+
fn binary_view_array_to_vector(array: &BinaryViewArray, out: &mut FlatVector) {
531+
assert!(array.len() <= out.capacity());
532+
533+
for i in 0..array.len() {
534+
let s = array.value(i);
535+
out.insert(i, s);
536+
}
537+
set_nulls_in_flat_vector(array, out);
538+
}
539+
499540
fn fixed_size_binary_array_to_vector(array: &FixedSizeBinaryArray, out: &mut FlatVector) {
500541
assert!(array.len() <= out.capacity());
501542

@@ -531,9 +572,29 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
531572
DataType::Utf8 => {
532573
string_array_to_vector(as_string_array(value_array.as_ref()), &mut child);
533574
}
575+
DataType::Utf8View => {
576+
string_view_array_to_vector(
577+
value_array
578+
.as_ref()
579+
.as_any()
580+
.downcast_ref::<StringViewArray>()
581+
.ok_or_else(|| Box::<dyn std::error::Error>::from("Unable to downcast to StringViewArray"))?,
582+
&mut child,
583+
);
584+
}
534585
DataType::Binary => {
535586
binary_array_to_vector(as_generic_binary_array(value_array.as_ref()), &mut child);
536587
}
588+
DataType::BinaryView => {
589+
binary_view_array_to_vector(
590+
value_array
591+
.as_ref()
592+
.as_any()
593+
.downcast_ref::<BinaryViewArray>()
594+
.ok_or_else(|| Box::<dyn std::error::Error>::from("Unable to downcast to BinaryViewArray"))?,
595+
&mut child,
596+
);
597+
}
537598
_ => {
538599
return Err("Nested list is not supported yet.".into());
539600
}
@@ -702,11 +763,12 @@ mod test {
702763
use crate::{Connection, Result};
703764
use arrow::{
704765
array::{
705-
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
706-
DurationSecondArray, FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array,
766+
Array, ArrayRef, AsArray, BinaryArray, BinaryViewArray, Date32Array, Date64Array, Decimal128Array,
767+
Decimal256Array, DurationSecondArray, FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array,
707768
IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeStringArray, ListArray,
708-
OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
709-
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
769+
OffsetSizeTrait, PrimitiveArray, StringArray, StringViewArray, StructArray, Time32SecondArray,
770+
Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
771+
TimestampSecondArray,
710772
},
711773
buffer::{OffsetBuffer, ScalarBuffer},
712774
datatypes::{
@@ -1264,4 +1326,120 @@ mod test {
12641326
assert_eq!(column.len(), 1);
12651327
assert_eq!(column.value(0), b"test");
12661328
}
1329+
1330+
#[test]
1331+
fn test_string_view_roundtrip() -> Result<(), Box<dyn Error>> {
1332+
let db = Connection::open_in_memory()?;
1333+
db.register_table_function::<ArrowVTab>("arrow")?;
1334+
1335+
let array = StringViewArray::from(vec![Some("foo"), Some("bar"), Some("baz")]);
1336+
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]);
1337+
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())])?;
1338+
1339+
let param = arrow_recordbatch_to_query_params(rb);
1340+
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
1341+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
1342+
1343+
let output_array = rb
1344+
.column(0)
1345+
.as_any()
1346+
.downcast_ref::<StringArray>()
1347+
.expect("Expected StringArray");
1348+
1349+
assert_eq!(output_array.len(), 3);
1350+
assert_eq!(output_array.value(0), "foo");
1351+
assert_eq!(output_array.value(1), "bar");
1352+
assert_eq!(output_array.value(2), "baz");
1353+
1354+
Ok(())
1355+
}
1356+
1357+
#[test]
1358+
fn test_binary_view_roundtrip() -> Result<(), Box<dyn Error>> {
1359+
let db = Connection::open_in_memory()?;
1360+
db.register_table_function::<ArrowVTab>("arrow")?;
1361+
1362+
let array = BinaryViewArray::from(vec![
1363+
Some(b"hello".as_ref()),
1364+
Some(b"world".as_ref()),
1365+
Some(b"!".as_ref()),
1366+
]);
1367+
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]);
1368+
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())])?;
1369+
1370+
let param = arrow_recordbatch_to_query_params(rb);
1371+
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
1372+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
1373+
1374+
let output_array = rb
1375+
.column(0)
1376+
.as_any()
1377+
.downcast_ref::<BinaryArray>()
1378+
.expect("Expected BinaryArray");
1379+
1380+
assert_eq!(output_array.len(), 3);
1381+
assert_eq!(output_array.value(0), b"hello");
1382+
assert_eq!(output_array.value(1), b"world");
1383+
assert_eq!(output_array.value(2), b"!");
1384+
1385+
Ok(())
1386+
}
1387+
1388+
#[test]
1389+
fn test_string_view_nulls_roundtrip() -> Result<(), Box<dyn Error>> {
1390+
let db = Connection::open_in_memory()?;
1391+
db.register_table_function::<ArrowVTab>("arrow")?;
1392+
1393+
let array = StringViewArray::from(vec![Some("foo"), None, Some("baz")]);
1394+
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), true)]);
1395+
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())])?;
1396+
1397+
let param = arrow_recordbatch_to_query_params(rb);
1398+
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
1399+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
1400+
1401+
let output_array = rb
1402+
.column(0)
1403+
.as_any()
1404+
.downcast_ref::<StringArray>()
1405+
.expect("Expected StringArray");
1406+
1407+
assert_eq!(output_array.len(), 3);
1408+
assert!(output_array.is_valid(0));
1409+
assert!(!output_array.is_valid(1));
1410+
assert!(output_array.is_valid(2));
1411+
assert_eq!(output_array.value(0), "foo");
1412+
assert_eq!(output_array.value(2), "baz");
1413+
1414+
Ok(())
1415+
}
1416+
1417+
#[test]
1418+
fn test_binary_view_nulls_roundtrip() -> Result<(), Box<dyn Error>> {
1419+
let db = Connection::open_in_memory()?;
1420+
db.register_table_function::<ArrowVTab>("arrow")?;
1421+
1422+
let array = BinaryViewArray::from(vec![Some(b"hello".as_ref()), None, Some(b"!".as_ref())]);
1423+
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), true)]);
1424+
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())])?;
1425+
1426+
let param = arrow_recordbatch_to_query_params(rb);
1427+
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
1428+
let rb = stmt.query_arrow(param)?.next().expect("no record batch");
1429+
1430+
let output_array = rb
1431+
.column(0)
1432+
.as_any()
1433+
.downcast_ref::<BinaryArray>()
1434+
.expect("Expected BinaryArray");
1435+
1436+
assert_eq!(output_array.len(), 3);
1437+
assert!(output_array.is_valid(0));
1438+
assert!(!output_array.is_valid(1));
1439+
assert!(output_array.is_valid(2));
1440+
assert_eq!(output_array.value(0), b"hello");
1441+
assert_eq!(output_array.value(2), b"!");
1442+
1443+
Ok(())
1444+
}
12671445
}

0 commit comments

Comments
 (0)