Skip to content

Commit 02a0f3e

Browse files
y-f-upeasee
andauthored
Set nulls correctly for all type of arrays/vectors (#344)
* Set nulls for all possible arrays * set nulls for all possible array to vectors * add more set nulls * wip * only change flat vector * Revert "only change flat vector" This reverts commit 90c9d75. * add list vector nulls * add tests to cover set_nulls * fix test * fix clippy * clippy --------- Co-authored-by: peasee <[email protected]>
1 parent 44e0ff1 commit 02a0f3e

File tree

4 files changed

+137
-24
lines changed

4 files changed

+137
-24
lines changed

crates/duckdb/src/core/data_chunk.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ impl Drop for DataChunkHandle {
2626
}
2727

2828
impl DataChunkHandle {
29+
#[allow(dead_code)]
2930
pub(crate) unsafe fn new_unowned(ptr: duckdb_data_chunk) -> Self {
3031
Self { ptr, owned: false }
3132
}

crates/duckdb/src/core/vector.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,15 @@ impl ListVector {
173173
self.entries.as_mut_slice::<duckdb_list_entry>()[idx].length = length as u64;
174174
}
175175

176+
/// Set row as null
177+
pub fn set_null(&mut self, row: usize) {
178+
unsafe {
179+
duckdb_vector_ensure_validity_writable(self.entries.ptr);
180+
let idx = duckdb_vector_get_validity(self.entries.ptr);
181+
duckdb_validity_set_row_invalid(idx, row as u64);
182+
}
183+
}
184+
176185
/// Reserve the capacity for its child node.
177186
fn reserve(&self, capacity: usize) {
178187
unsafe {
@@ -190,7 +199,6 @@ impl ListVector {
190199

191200
/// A array vector. (fixed-size list)
192201
pub struct ArrayVector {
193-
/// ArrayVector does not own the vector pointer.
194202
ptr: duckdb_vector,
195203
}
196204

@@ -223,11 +231,19 @@ impl ArrayVector {
223231
pub fn set_child<T: Copy>(&self, data: &[T]) {
224232
self.child(data.len()).copy(data);
225233
}
234+
235+
/// Set row as null
236+
pub fn set_null(&mut self, row: usize) {
237+
unsafe {
238+
duckdb_vector_ensure_validity_writable(self.ptr);
239+
let idx = duckdb_vector_get_validity(self.ptr);
240+
duckdb_validity_set_row_invalid(idx, row as u64);
241+
}
242+
}
226243
}
227244

228245
/// A struct vector.
229246
pub struct StructVector {
230-
/// ListVector does not own the vector pointer.
231247
ptr: duckdb_vector,
232248
}
233249

@@ -277,4 +293,13 @@ impl StructVector {
277293
let logical_type = self.logical_type();
278294
unsafe { duckdb_struct_type_child_count(logical_type.ptr) as usize }
279295
}
296+
297+
/// Set row as null
298+
pub fn set_null(&mut self, row: usize) {
299+
unsafe {
300+
duckdb_vector_ensure_validity_writable(self.ptr);
301+
let idx = duckdb_vector_get_validity(self.ptr);
302+
duckdb_validity_set_row_invalid(idx, row as u64);
303+
}
304+
}
280305
}

crates/duckdb/src/vtab/arrow.rs

Lines changed: 108 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -268,13 +268,7 @@ pub fn record_batch_to_duckdb_data_chunk(
268268
fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<T>, out_vector: &mut FlatVector) {
269269
// assert!(array.len() <= out_vector.capacity());
270270
out_vector.copy::<T::Native>(array.values());
271-
if let Some(nulls) = array.nulls() {
272-
for (i, null) in nulls.into_iter().enumerate() {
273-
if !null {
274-
out_vector.set_null(i);
275-
}
276-
}
277-
}
271+
set_nulls_in_flat_vector(array, out_vector);
278272
}
279273

280274
fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
@@ -285,13 +279,7 @@ fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
285279
let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap();
286280
let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap();
287281
out_vector.copy::<T::Native>(array.as_primitive::<T>().values());
288-
if let Some(nulls) = array.nulls() {
289-
for (i, null) in nulls.iter().enumerate() {
290-
if !null {
291-
out_vector.set_null(i);
292-
}
293-
}
294-
}
282+
set_nulls_in_flat_vector(&array, out_vector);
295283
}
296284

297285
fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<(), Box<dyn std::error::Error>> {
@@ -441,13 +429,7 @@ fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector, width:
441429
}
442430

443431
// Set nulls
444-
if let Some(nulls) = array.nulls() {
445-
for (i, null) in nulls.into_iter().enumerate() {
446-
if !null {
447-
out.set_null(i);
448-
}
449-
}
450-
}
432+
set_nulls_in_flat_vector(array, out);
451433
}
452434

453435
/// Convert Arrow [BooleanArray] to a duckdb vector.
@@ -457,6 +439,7 @@ fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) {
457439
for i in 0..array.len() {
458440
out.as_mut_slice()[i] = array.value(i);
459441
}
442+
set_nulls_in_flat_vector(array, out);
460443
}
461444

462445
fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out: &mut FlatVector) {
@@ -467,6 +450,7 @@ fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out
467450
let s = array.value(i);
468451
out.insert(i, s);
469452
}
453+
set_nulls_in_flat_vector(array, out);
470454
}
471455

472456
fn binary_array_to_vector(array: &BinaryArray, out: &mut FlatVector) {
@@ -476,6 +460,7 @@ fn binary_array_to_vector(array: &BinaryArray, out: &mut FlatVector) {
476460
let s = array.value(i);
477461
out.insert(i, s);
478462
}
463+
set_nulls_in_flat_vector(array, out);
479464
}
480465

481466
fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
@@ -504,6 +489,8 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
504489
let length = array.value_length(i);
505490
out.set_entry(i, offset.as_(), length.as_());
506491
}
492+
set_nulls_in_list_vector(array, out);
493+
507494
Ok(())
508495
}
509496

@@ -528,6 +515,8 @@ fn fixed_size_list_array_to_vector(
528515
}
529516
}
530517

518+
set_nulls_in_array_vector(array, out);
519+
531520
Ok(())
532521
}
533522

@@ -575,6 +564,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result
575564
}
576565
}
577566
}
567+
set_nulls_in_struct_vector(array, out);
578568
Ok(())
579569
}
580570

@@ -611,6 +601,46 @@ pub fn arrow_ffi_to_query_params(array: FFI_ArrowArray, schema: FFI_ArrowSchema)
611601
[arr as *mut _ as usize, sch as *mut _ as usize]
612602
}
613603

604+
fn set_nulls_in_flat_vector(array: &dyn Array, out_vector: &mut FlatVector) {
605+
if let Some(nulls) = array.nulls() {
606+
for (i, null) in nulls.into_iter().enumerate() {
607+
if !null {
608+
out_vector.set_null(i);
609+
}
610+
}
611+
}
612+
}
613+
614+
fn set_nulls_in_struct_vector(array: &dyn Array, out_vector: &mut StructVector) {
615+
if let Some(nulls) = array.nulls() {
616+
for (i, null) in nulls.into_iter().enumerate() {
617+
if !null {
618+
out_vector.set_null(i);
619+
}
620+
}
621+
}
622+
}
623+
624+
fn set_nulls_in_array_vector(array: &dyn Array, out_vector: &mut ArrayVector) {
625+
if let Some(nulls) = array.nulls() {
626+
for (i, null) in nulls.into_iter().enumerate() {
627+
if !null {
628+
out_vector.set_null(i);
629+
}
630+
}
631+
}
632+
}
633+
634+
fn set_nulls_in_list_vector(array: &dyn Array, out_vector: &mut ListVector) {
635+
if let Some(nulls) = array.nulls() {
636+
for (i, null) in nulls.into_iter().enumerate() {
637+
if !null {
638+
out_vector.set_null(i);
639+
}
640+
}
641+
}
642+
}
643+
614644
#[cfg(test)]
615645
mod test {
616646
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
@@ -705,6 +735,44 @@ mod test {
705735
Ok(())
706736
}
707737

738+
#[test]
739+
fn test_append_struct_contains_null() -> Result<(), Box<dyn Error>> {
740+
let db = Connection::open_in_memory()?;
741+
db.execute_batch("CREATE TABLE t1 (s STRUCT(v VARCHAR, i INTEGER))")?;
742+
{
743+
let struct_array = StructArray::try_new(
744+
vec![
745+
Arc::new(Field::new("v", DataType::Utf8, true)),
746+
Arc::new(Field::new("i", DataType::Int32, true)),
747+
]
748+
.into(),
749+
vec![
750+
Arc::new(StringArray::from(vec![Some("foo"), Some("bar")])) as ArrayRef,
751+
Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef,
752+
],
753+
Some(vec![true, false].into()),
754+
)?;
755+
756+
let schema = Schema::new(vec![Field::new(
757+
"s",
758+
DataType::Struct(Fields::from(vec![
759+
Field::new("v", DataType::Utf8, true),
760+
Field::new("i", DataType::Int32, true),
761+
])),
762+
true,
763+
)]);
764+
765+
let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?;
766+
let mut app = db.appender("t1")?;
767+
app.append_record_batch(record_batch)?;
768+
}
769+
let mut stmt = db.prepare("SELECT s FROM t1 where s IS NOT NULL")?;
770+
let rbs: Vec<RecordBatch> = stmt.query_arrow([])?.collect();
771+
assert_eq!(rbs.iter().map(|op| op.num_rows()).sum::<usize>(), 1);
772+
773+
Ok(())
774+
}
775+
708776
fn check_rust_primitive_array_roundtrip<T1, T2>(
709777
input_array: PrimitiveArray<T1>,
710778
expected_array: PrimitiveArray<T2>,
@@ -762,7 +830,7 @@ mod test {
762830
db.register_table_function::<ArrowVTab>("arrow")?;
763831

764832
// Roundtrip a record batch from Rust to DuckDB and back to Rust
765-
let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), false)]);
833+
let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), true)]);
766834

767835
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry.clone())])?;
768836
let param = arrow_recordbatch_to_query_params(rb);
@@ -910,6 +978,24 @@ mod test {
910978
Ok(())
911979
}
912980

981+
#[test]
982+
fn test_check_generic_array_roundtrip_contains_null() -> Result<(), Box<dyn Error>> {
983+
check_generic_array_roundtrip(ListArray::new(
984+
Arc::new(Field::new("item", DataType::Utf8, true)),
985+
OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 5])),
986+
Arc::new(StringArray::from(vec![
987+
Some("foo"),
988+
Some("baz"),
989+
Some("bar"),
990+
Some("foo"),
991+
Some("baz"),
992+
])),
993+
Some(vec![true, false, true].into()),
994+
))?;
995+
996+
Ok(())
997+
}
998+
913999
#[test]
9141000
fn test_utf8_roundtrip() -> Result<(), Box<dyn Error>> {
9151001
check_generic_byte_roundtrip(

crates/libduckdb-sys/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ buildtime_bindgen = ["bindgen", "pkg-config", "vcpkg"]
2222
json = ["bundled"]
2323
parquet = ["bundled"]
2424
extensions-full = ["json", "parquet"]
25+
winduckdb = []
2526

2627
[dependencies]
2728

0 commit comments

Comments
 (0)