Skip to content

Commit d803d25

Browse files
committed
fix[vortex-array]: fix take on varbinviews with NULL indices
Index values at NULL positions were used. In the case that these were garbage, an OOB panic would occur. Signed-off-by: Alfonso Subiotto Marques <[email protected]>
1 parent 8a20090 commit d803d25

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

vortex-array/src/arrays/varbinview/compute/take.rs

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::iter;
45
use std::ops::Deref;
56

67
use num_traits::AsPrimitive;
78
use vortex_buffer::Buffer;
89
use vortex_dtype::match_each_integer_ptype;
910
use vortex_error::VortexResult;
11+
use vortex_mask::AllOr;
12+
use vortex_mask::Mask;
1013
use vortex_vector::binaryview::BinaryView;
1114

1215
use crate::arrays::{VarBinViewArray, VarBinViewVTable};
@@ -17,16 +20,16 @@ use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
1720
/// Take involves creating a new array that references the old array, just with the given set of views.
1821
impl TakeKernel for VarBinViewVTable {
1922
fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
20-
// Compute the new validity
21-
22-
// This is valid since all elements (of all arrays) even null values must be inside
23-
// min-max valid range.
23+
// Compute the new validity.
2424
let validity = array.validity().take(indices)?;
2525
let indices = indices.to_primitive();
2626

2727
let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
28-
// This is valid since all elements even null values are inside the min-max valid range.
29-
take_views(array.views(), indices.as_slice::<I>())
28+
take_views(
29+
array.views(),
30+
indices.as_slice::<I>(),
31+
&indices.validity_mask(),
32+
)
3033
});
3134

3235
// SAFETY: taking all components at same indices maintains invariants
@@ -49,15 +52,36 @@ register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
4952
fn take_views<I: AsPrimitive<usize>>(
5053
views: &Buffer<BinaryView>,
5154
indices: &[I],
55+
mask: &Mask,
5256
) -> Buffer<BinaryView> {
5357
// NOTE(ngates): this deref is not actually trivial, so we run it once.
5458
let views_ref = views.deref();
55-
Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
59+
// We do not use iter_bools directly, since the resulting dyn iterator cannot
60+
// implement TrustedLen.
61+
match mask.bit_buffer() {
62+
AllOr::All => {
63+
Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
64+
}
65+
AllOr::None => Buffer::<BinaryView>::from_trusted_len_iter(iter::repeat_n(
66+
BinaryView::default(),
67+
indices.len(),
68+
)),
69+
AllOr::Some(buffer) => Buffer::<BinaryView>::from_trusted_len_iter(
70+
buffer.iter().zip(indices.iter()).map(|(valid, idx)| {
71+
if valid {
72+
views_ref[idx.as_()]
73+
} else {
74+
BinaryView::default()
75+
}
76+
}),
77+
),
78+
}
5679
}
5780

5881
#[cfg(test)]
5982
mod tests {
6083
use rstest::rstest;
84+
use vortex_buffer::BitBuffer;
6185
use vortex_buffer::buffer;
6286
use vortex_dtype::DType;
6387
use vortex_dtype::Nullability::NonNullable;
@@ -69,6 +93,7 @@ mod tests {
6993
use crate::canonical::ToCanonical;
7094
use crate::compute::conformance::take::test_take_conformance;
7195
use crate::compute::take;
96+
use crate::validity::Validity;
7297

7398
#[test]
7499
fn take_nullable() {
@@ -96,11 +121,13 @@ mod tests {
96121
fn take_nullable_indices() {
97122
let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));
98123

99-
let taken = take(
100-
arr.as_ref(),
101-
PrimitiveArray::from_option_iter(vec![Some(1), None]).as_ref(),
102-
)
103-
.unwrap();
124+
let indices = PrimitiveArray::new(
125+
// Verify that garbage values at NULL indices are ignored.
126+
buffer![1u64, 999],
127+
Validity::from(BitBuffer::from(vec![true, false])),
128+
);
129+
130+
let taken = take(arr.as_ref(), indices.as_ref()).unwrap();
104131

105132
assert!(taken.dtype().is_nullable());
106133
assert_eq!(

vortex-buffer/src/trusted_len.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ where
161161
{
162162
}
163163

164+
unsafe impl<T: Clone> TrustedLen for std::iter::RepeatN<T> {}
165+
164166
// Arrow bit iterators
165167
unsafe impl<'a> TrustedLen for crate::bit::BitChunkIterator<'a> {}
166168
unsafe impl<'a> TrustedLen for crate::bit::UnalignedBitChunkIterator<'a> {}
169+
unsafe impl<'a> TrustedLen for crate::bit::BitIterator<'a> {}

0 commit comments

Comments
 (0)