11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use std:: iter;
45use std:: ops:: Deref ;
56
67use num_traits:: AsPrimitive ;
78use vortex_buffer:: Buffer ;
89use vortex_dtype:: match_each_integer_ptype;
910use vortex_error:: VortexResult ;
11+ use vortex_mask:: AllOr ;
12+ use vortex_mask:: Mask ;
1013use vortex_vector:: binaryview:: BinaryView ;
1114
1215use crate :: arrays:: { VarBinViewArray , VarBinViewVTable } ;
@@ -17,16 +20,16 @@ use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
1720/// Take involves creating a new array that references the old array, just with the given set of views.
1821impl TakeKernel for VarBinViewVTable {
1922 fn take ( & self , array : & VarBinViewArray , indices : & dyn Array ) -> VortexResult < ArrayRef > {
20- // Compute the new validity
21-
22- // This is valid since all elements (of all arrays) even null values must be inside
23- // min-max valid range.
23+ // Compute the new validity.
2424 let validity = array. validity ( ) . take ( indices) ?;
2525 let indices = indices. to_primitive ( ) ;
2626
2727 let views_buffer = match_each_integer_ptype ! ( indices. ptype( ) , |I | {
28- // This is valid since all elements even null values are inside the min-max valid range.
29- take_views( array. views( ) , indices. as_slice:: <I >( ) )
28+ take_views(
29+ array. views( ) ,
30+ indices. as_slice:: <I >( ) ,
31+ & indices. validity_mask( ) ,
32+ )
3033 } ) ;
3134
3235 // SAFETY: taking all components at same indices maintains invariants
@@ -49,15 +52,36 @@ register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
4952fn take_views < I : AsPrimitive < usize > > (
5053 views : & Buffer < BinaryView > ,
5154 indices : & [ I ] ,
55+ mask : & Mask ,
5256) -> Buffer < BinaryView > {
5357 // NOTE(ngates): this deref is not actually trivial, so we run it once.
5458 let views_ref = views. deref ( ) ;
55- Buffer :: < BinaryView > :: from_trusted_len_iter ( indices. iter ( ) . map ( |i| views_ref[ i. as_ ( ) ] ) )
59+ // We do not use iter_bools directly, since the resulting dyn iterator cannot
60+ // implement TrustedLen.
61+ match mask. bit_buffer ( ) {
62+ AllOr :: All => {
63+ Buffer :: < BinaryView > :: from_trusted_len_iter ( indices. iter ( ) . map ( |i| views_ref[ i. as_ ( ) ] ) )
64+ }
65+ AllOr :: None => Buffer :: < BinaryView > :: from_trusted_len_iter ( iter:: repeat_n (
66+ BinaryView :: default ( ) ,
67+ indices. len ( ) ,
68+ ) ) ,
69+ AllOr :: Some ( buffer) => Buffer :: < BinaryView > :: from_trusted_len_iter (
70+ buffer. iter ( ) . zip ( indices. iter ( ) ) . map ( |( valid, idx) | {
71+ if valid {
72+ views_ref[ idx. as_ ( ) ]
73+ } else {
74+ BinaryView :: default ( )
75+ }
76+ } ) ,
77+ ) ,
78+ }
5679}
5780
5881#[ cfg( test) ]
5982mod tests {
6083 use rstest:: rstest;
84+ use vortex_buffer:: BitBuffer ;
6185 use vortex_buffer:: buffer;
6286 use vortex_dtype:: DType ;
6387 use vortex_dtype:: Nullability :: NonNullable ;
@@ -69,6 +93,7 @@ mod tests {
6993 use crate :: canonical:: ToCanonical ;
7094 use crate :: compute:: conformance:: take:: test_take_conformance;
7195 use crate :: compute:: take;
96+ use crate :: validity:: Validity ;
7297
7398 #[ test]
7499 fn take_nullable ( ) {
@@ -96,11 +121,13 @@ mod tests {
96121 fn take_nullable_indices ( ) {
97122 let arr = VarBinViewArray :: from_iter ( [ "one" , "two" ] . map ( Some ) , DType :: Utf8 ( NonNullable ) ) ;
98123
99- let taken = take (
100- arr. as_ref ( ) ,
101- PrimitiveArray :: from_option_iter ( vec ! [ Some ( 1 ) , None ] ) . as_ref ( ) ,
102- )
103- . unwrap ( ) ;
124+ let indices = PrimitiveArray :: new (
125+ // Verify that garbage values at NULL indices are ignored.
126+ buffer ! [ 1u64 , 999 ] ,
127+ Validity :: from ( BitBuffer :: from ( vec ! [ true , false ] ) ) ,
128+ ) ;
129+
130+ let taken = take ( arr. as_ref ( ) , indices. as_ref ( ) ) . unwrap ( ) ;
104131
105132 assert ! ( taken. dtype( ) . is_nullable( ) ) ;
106133 assert_eq ! (
0 commit comments