Skip to content

Commit f1fb2b1

Browse files
authored
fix take kernel null handling on structs (apache#531)
This closes apache#530. Co-authored-by: Ben Chambers <[email protected]>
1 parent 6538fe5 commit f1fb2b1

File tree

1 file changed

+88
-63
lines changed

1 file changed

+88
-63
lines changed

arrow/src/compute/kernels/take.rs

Lines changed: 88 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,22 @@ where
231231
.map(|a| take_impl(a.as_ref(), indices, Some(options.clone())))
232232
.collect();
233233
let arrays = arrays?;
234-
let pairs: Vec<(Field, ArrayRef)> =
234+
let fields: Vec<(Field, ArrayRef)> =
235235
fields.clone().into_iter().zip(arrays).collect();
236-
Ok(Arc::new(StructArray::from(pairs)) as ArrayRef)
236+
237+
// Create the null bit buffer.
238+
let is_valid: Buffer = indices
239+
.iter()
240+
.map(|index| {
241+
if let Some(index) = index {
242+
struct_.is_valid(ArrowNativeType::to_usize(&index).unwrap())
243+
} else {
244+
false
245+
}
246+
})
247+
.collect();
248+
249+
Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
237250
}
238251
DataType::Dictionary(key_type, _) => match key_type.as_ref() {
239252
DataType::Int8 => downcast_dict_take!(Int8Type, values, indices),
@@ -848,20 +861,34 @@ mod tests {
848861
}
849862

850863
// create a simple struct for testing purposes
851-
fn create_test_struct() -> StructArray {
852-
let boolean_data = BooleanArray::from(vec![true, false, false, true])
853-
.data()
854-
.clone();
855-
let int_data = Int32Array::from(vec![42, 28, 19, 31]).data().clone();
856-
let mut field_types = vec![];
857-
field_types.push(Field::new("a", DataType::Boolean, true));
858-
field_types.push(Field::new("b", DataType::Int32, true));
859-
let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
860-
.len(4)
861-
.add_child_data(boolean_data)
862-
.add_child_data(int_data)
863-
.build();
864-
StructArray::from(struct_array_data)
864+
fn create_test_struct(
865+
values: Vec<Option<(Option<bool>, Option<i32>)>>,
866+
) -> StructArray {
867+
let mut struct_builder = StructBuilder::new(
868+
vec![
869+
Field::new("a", DataType::Boolean, true),
870+
Field::new("b", DataType::Int32, true),
871+
],
872+
vec![
873+
Box::new(BooleanBuilder::new(values.len())),
874+
Box::new(Int32Builder::new(values.len())),
875+
],
876+
);
877+
878+
for value in values {
879+
struct_builder
880+
.field_builder::<BooleanBuilder>(0)
881+
.unwrap()
882+
.append_option(value.and_then(|v| v.0))
883+
.unwrap();
884+
struct_builder
885+
.field_builder::<Int32Builder>(1)
886+
.unwrap()
887+
.append_option(value.and_then(|v| v.1))
888+
.unwrap();
889+
struct_builder.append(value.is_some()).unwrap();
890+
}
891+
struct_builder.finish()
865892
}
866893

867894
#[test]
@@ -1576,61 +1603,59 @@ mod tests {
15761603

15771604
#[test]
15781605
fn test_take_struct() {
1579-
let array = create_test_struct();
1580-
1581-
let index = UInt32Array::from(vec![0, 3, 1, 0, 2]);
1582-
let a = take(&array, &index, None).unwrap();
1583-
let a: &StructArray = a.as_any().downcast_ref::<StructArray>().unwrap();
1584-
assert_eq!(index.len(), a.len());
1585-
assert_eq!(0, a.null_count());
1606+
let array = create_test_struct(vec![
1607+
Some((Some(true), Some(42))),
1608+
Some((Some(false), Some(28))),
1609+
Some((Some(false), Some(19))),
1610+
Some((Some(true), Some(31))),
1611+
None,
1612+
]);
15861613

1587-
let expected_bool_data = BooleanArray::from(vec![true, true, false, true, false])
1588-
.data()
1589-
.clone();
1590-
let expected_int_data = Int32Array::from(vec![42, 31, 28, 42, 19]).data().clone();
1591-
let mut field_types = vec![];
1592-
field_types.push(Field::new("a", DataType::Boolean, true));
1593-
field_types.push(Field::new("b", DataType::Int32, true));
1594-
let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
1595-
.len(5)
1596-
.add_child_data(expected_bool_data)
1597-
.add_child_data(expected_int_data)
1598-
.build();
1599-
let struct_array = StructArray::from(struct_array_data);
1614+
let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
1615+
let actual = take(&array, &index, None).unwrap();
1616+
let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1617+
assert_eq!(index.len(), actual.len());
1618+
assert_eq!(1, actual.null_count());
1619+
1620+
let expected = create_test_struct(vec![
1621+
Some((Some(true), Some(42))),
1622+
Some((Some(true), Some(31))),
1623+
Some((Some(false), Some(28))),
1624+
Some((Some(true), Some(42))),
1625+
Some((Some(false), Some(19))),
1626+
None,
1627+
]);
16001628

1601-
assert_eq!(a, &struct_array);
1629+
assert_eq!(&expected, actual);
16021630
}
16031631

16041632
#[test]
1605-
fn test_take_struct_with_nulls() {
1606-
let array = create_test_struct();
1633+
fn test_take_struct_with_null_indices() {
1634+
let array = create_test_struct(vec![
1635+
Some((Some(true), Some(42))),
1636+
Some((Some(false), Some(28))),
1637+
Some((Some(false), Some(19))),
1638+
Some((Some(true), Some(31))),
1639+
None,
1640+
]);
16071641

1608-
let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0)]);
1609-
let a = take(&array, &index, None).unwrap();
1610-
let a: &StructArray = a.as_any().downcast_ref::<StructArray>().unwrap();
1611-
assert_eq!(index.len(), a.len());
1612-
assert_eq!(0, a.null_count());
1642+
let index =
1643+
UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
1644+
let actual = take(&array, &index, None).unwrap();
1645+
let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1646+
assert_eq!(index.len(), actual.len());
1647+
assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
16131648

1614-
let expected_bool_data =
1615-
BooleanArray::from(vec![None, Some(true), Some(false), None, Some(true)])
1616-
.data()
1617-
.clone();
1618-
let expected_int_data =
1619-
Int32Array::from(vec![None, Some(31), Some(28), None, Some(42)])
1620-
.data()
1621-
.clone();
1649+
let expected = create_test_struct(vec![
1650+
None,
1651+
Some((Some(true), Some(31))),
1652+
Some((Some(false), Some(28))),
1653+
None,
1654+
Some((Some(true), Some(42))),
1655+
None,
1656+
]);
16221657

1623-
let mut field_types = vec![];
1624-
field_types.push(Field::new("a", DataType::Boolean, true));
1625-
field_types.push(Field::new("b", DataType::Int32, true));
1626-
let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
1627-
.len(5)
1628-
// TODO: see https://issues.apache.org/jira/browse/ARROW-5408 for why count != 2
1629-
.add_child_data(expected_bool_data)
1630-
.add_child_data(expected_int_data)
1631-
.build();
1632-
let struct_array = StructArray::from(struct_array_data);
1633-
assert_eq!(a, &struct_array);
1658+
assert_eq!(&expected, actual);
16341659
}
16351660

16361661
#[test]

0 commit comments

Comments
 (0)