Skip to content

Commit d110924

Browse files
authored
fix: arrow vtab panic (#293)
1 parent 2e638f1 commit d110924

File tree

2 files changed

+77
-37
lines changed

2 files changed

+77
-37
lines changed

libduckdb-sys/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ serde_json = { version = "1.0" }
3939
tar = "0.4.38"
4040

4141
[dev-dependencies]
42-
arrow = { version = "49", default-features = false, features = ["ffi"] }
42+
arrow = { version = "51", default-features = false, features = ["ffi"] }

src/vtab/arrow.rs

Lines changed: 76 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use super::{
22
vector::{FlatVector, ListVector, Vector},
33
BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab,
44
};
5+
use std::ptr::null_mut;
56

67
use crate::vtab::vector::Inserter;
78
use arrow::array::{
@@ -74,8 +75,11 @@ impl VTab for ArrowVTab {
7475
type InitData = ArrowInitData;
7576

7677
unsafe fn bind(bind: &BindInfo, data: *mut ArrowBindData) -> Result<(), Box<dyn std::error::Error>> {
78+
(*data).rb = null_mut();
7779
let param_count = bind.get_parameter_count();
78-
assert!(param_count == 2);
80+
if param_count != 2 {
81+
return Err(format!("Bad param count: {param_count}, expected 2").into());
82+
}
7983
let array = bind.get_parameter(0).to_int64();
8084
let schema = bind.get_parameter(1).to_int64();
8185
unsafe {
@@ -106,6 +110,7 @@ impl VTab for ArrowVTab {
106110
output.set_len(0);
107111
} else {
108112
let rb = Box::from_raw((*bind_info).rb);
113+
(*bind_info).rb = null_mut(); // erase ref in case of failure in record_batch_to_duckdb_data_chunk
109114
record_batch_to_duckdb_data_chunk(&rb, output)?;
110115
(*bind_info).rb = Box::into_raw(rb);
111116
(*init_info).done = true;
@@ -156,7 +161,7 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
156161
DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => List,
157162
DataType::Struct(_) => Struct,
158163
DataType::Union(_, _) => Union,
159-
DataType::Dictionary(_, _) => todo!(),
164+
// DataType::Dictionary(_, _) => todo!(),
160165
// duckdb/src/main/capi/helper-c.cpp does not support decimal
161166
// DataType::Decimal128(_, _) => Decimal,
162167
// DataType::Decimal256(_, _) => Decimal,
@@ -194,8 +199,9 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<d
194199
} else if let DataType::FixedSizeList(child, _) = data_type {
195200
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
196201
} else {
197-
unimplemented!(
198-
"Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs"
202+
Err(
203+
format!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs")
204+
.into(),
199205
)
200206
}
201207
}
@@ -216,30 +222,31 @@ pub fn record_batch_to_duckdb_data_chunk(
216222
let col = batch.column(i);
217223
match col.data_type() {
218224
dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => {
219-
primitive_array_to_vector(col, &mut chunk.flat_vector(i));
225+
primitive_array_to_vector(col, &mut chunk.flat_vector(i))?;
220226
}
221227
DataType::Utf8 => {
222228
string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i));
223229
}
224230
DataType::List(_) => {
225-
list_array_to_vector(as_list_array(col.as_ref()), &mut chunk.list_vector(i));
231+
list_array_to_vector(as_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
226232
}
227233
DataType::LargeList(_) => {
228-
list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i));
234+
list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
229235
}
230236
DataType::FixedSizeList(_, _) => {
231-
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i));
237+
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i))?;
232238
}
233239
DataType::Struct(_) => {
234240
let struct_array = as_struct_array(col.as_ref());
235241
let mut struct_vector = chunk.struct_vector(i);
236-
struct_array_to_vector(struct_array, &mut struct_vector);
242+
struct_array_to_vector(struct_array, &mut struct_vector)?;
237243
}
238244
_ => {
239-
unimplemented!(
245+
return Err(format!(
240246
"column {} is not supported yet, please file an issue https://github.com/wangfenjin/duckdb-rs",
241247
batch.schema().field(i)
242-
);
248+
)
249+
.into());
243250
}
244251
}
245252
}
@@ -262,7 +269,7 @@ fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
262269
out_vector.copy::<T::Native>(array.as_primitive::<T>().values());
263270
}
264271

265-
fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
272+
fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<(), Box<dyn std::error::Error>> {
266273
match array.data_type() {
267274
DataType::Boolean => {
268275
boolean_array_to_vector(as_boolean_array(array), out.as_mut_any().downcast_mut().unwrap());
@@ -315,7 +322,6 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
315322
out.as_mut_any().downcast_mut().unwrap(),
316323
);
317324
}
318-
DataType::Float16 => todo!("Float16 is not supported yet"),
319325
DataType::Float32 => {
320326
primitive_array_to_flat_vector::<Float32Type>(
321327
as_primitive_array(array),
@@ -337,7 +343,6 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
337343
out.as_mut_any().downcast_mut().unwrap(),
338344
);
339345
}
340-
DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"),
341346

342347
// DuckDB only supports timestamp_tz in microsecond precision
343348
DataType::Timestamp(_, Some(tz)) => primitive_array_to_flat_vector_cast::<TimestampMicrosecondType>(
@@ -376,11 +381,9 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
376381
DataType::Time64(_) => {
377382
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
378383
}
379-
_ => todo!(
380-
"Converting '{dtype:#?}' to primitive flat vector is not supported",
381-
dtype = array.data_type()
382-
),
384+
datatype => return Err(format!("Data type \"{datatype}\" not yet supported by ArrowVTab").into()),
383385
}
386+
Ok(())
384387
}
385388

386389
/// Convert Arrow [Decimal128Array] to a duckdb vector.
@@ -410,31 +413,38 @@ fn string_array_to_vector(array: &StringArray, out: &mut FlatVector) {
410413
}
411414
}
412415

413-
fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(array: &GenericListArray<O>, out: &mut ListVector) {
416+
fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
417+
array: &GenericListArray<O>,
418+
out: &mut ListVector,
419+
) -> Result<(), Box<dyn std::error::Error>> {
414420
let value_array = array.values();
415421
let mut child = out.child(value_array.len());
416422
match value_array.data_type() {
417423
dt if dt.is_primitive() => {
418-
primitive_array_to_vector(value_array.as_ref(), &mut child);
424+
primitive_array_to_vector(value_array.as_ref(), &mut child)?;
419425
for i in 0..array.len() {
420426
let offset = array.value_offsets()[i];
421427
let length = array.value_length(i);
422428
out.set_entry(i, offset.as_(), length.as_());
423429
}
424430
}
425431
_ => {
426-
println!("Nested list is not supported yet.");
427-
todo!()
432+
return Err("Nested list is not supported yet.".into());
428433
}
429434
}
435+
436+
Ok(())
430437
}
431438

432-
fn fixed_size_list_array_to_vector(array: &FixedSizeListArray, out: &mut ListVector) {
439+
fn fixed_size_list_array_to_vector(
440+
array: &FixedSizeListArray,
441+
out: &mut ListVector,
442+
) -> Result<(), Box<dyn std::error::Error>> {
433443
let value_array = array.values();
434444
let mut child = out.child(value_array.len());
435445
match value_array.data_type() {
436446
dt if dt.is_primitive() => {
437-
primitive_array_to_vector(value_array.as_ref(), &mut child);
447+
primitive_array_to_vector(value_array.as_ref(), &mut child)?;
438448
for i in 0..array.len() {
439449
let offset = array.value_offset(i);
440450
let length = array.value_length();
@@ -443,10 +453,11 @@ fn fixed_size_list_array_to_vector(array: &FixedSizeListArray, out: &mut ListVec
443453
out.set_len(value_array.len());
444454
}
445455
_ => {
446-
println!("Nested list is not supported yet.");
447-
todo!()
456+
return Err("Nested list is not supported yet.".into());
448457
}
449458
}
459+
460+
Ok(())
450461
}
451462

452463
/// Force downcast of an [`Array`], such as an [`ArrayRef`], to
@@ -455,32 +466,32 @@ fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray {
455466
arr.as_any().downcast_ref::<FixedSizeListArray>().unwrap()
456467
}
457468

458-
fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) {
469+
fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result<(), Box<dyn std::error::Error>> {
459470
for i in 0..array.num_columns() {
460471
let column = array.column(i);
461472
match column.data_type() {
462473
dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => {
463-
primitive_array_to_vector(column, &mut out.child(i));
474+
primitive_array_to_vector(column, &mut out.child(i))?;
464475
}
465476
DataType::Utf8 => {
466477
string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i));
467478
}
468479
DataType::List(_) => {
469-
list_array_to_vector(as_list_array(column.as_ref()), &mut out.list_vector_child(i));
480+
list_array_to_vector(as_list_array(column.as_ref()), &mut out.list_vector_child(i))?;
470481
}
471482
DataType::LargeList(_) => {
472-
list_array_to_vector(as_large_list_array(column.as_ref()), &mut out.list_vector_child(i));
483+
list_array_to_vector(as_large_list_array(column.as_ref()), &mut out.list_vector_child(i))?;
473484
}
474485
DataType::FixedSizeList(_, _) => {
475486
fixed_size_list_array_to_vector(
476487
as_fixed_size_list_array(column.as_ref()),
477488
&mut out.list_vector_child(i),
478-
);
489+
)?;
479490
}
480491
DataType::Struct(_) => {
481492
let struct_array = as_struct_array(column.as_ref());
482493
let mut struct_vector = out.struct_vector_child(i);
483-
struct_array_to_vector(struct_array, &mut struct_vector);
494+
struct_array_to_vector(struct_array, &mut struct_vector)?;
484495
}
485496
_ => {
486497
unimplemented!(
@@ -490,6 +501,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) {
490501
}
491502
}
492503
}
504+
Ok(())
493505
}
494506

495507
/// Pass RecordBatch to duckdb.
@@ -531,11 +543,11 @@ mod test {
531543
use crate::{Connection, Result};
532544
use arrow::{
533545
array::{
534-
Array, ArrayRef, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray,
535-
StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
536-
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
546+
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, Int32Array,
547+
PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
548+
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
537549
},
538-
datatypes::{ArrowPrimitiveType, DataType, Field, Fields, Schema},
550+
datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
539551
record_batch::RecordBatch,
540552
};
541553
use std::{error::Error, sync::Arc};
@@ -749,4 +761,32 @@ mod test {
749761
assert_eq!(column.value(0), "TIMESTAMP WITH TIME ZONE");
750762
Ok(())
751763
}
764+
765+
#[test]
766+
fn test_arrow_error() {
767+
let arc: ArrayRef = Arc::new(Decimal256Array::from(vec![i256::from(1), i256::from(2), i256::from(3)]));
768+
let batch = RecordBatch::try_from_iter(vec![("x", arc)]).unwrap();
769+
770+
let db = Connection::open_in_memory().unwrap();
771+
db.register_table_function::<ArrowVTab>("arrow").unwrap();
772+
773+
let mut stmt = db.prepare("SELECT * FROM arrow(?, ?)").unwrap();
774+
775+
let res = match stmt.execute(arrow_recordbatch_to_query_params(batch)) {
776+
Ok(..) => None,
777+
Err(e) => Some(e),
778+
}
779+
.unwrap();
780+
781+
assert_eq!(
782+
res,
783+
crate::error::Error::DuckDBFailure(
784+
crate::ffi::Error {
785+
code: crate::ffi::ErrorCode::Unknown,
786+
extended_code: 1
787+
},
788+
Some("Invalid Input Error: Data type \"Decimal256(76, 10)\" not yet supported by ArrowVTab".to_owned())
789+
)
790+
);
791+
}
752792
}

0 commit comments

Comments
 (0)