Skip to content

Commit 46afdd8

Browse files
authored
Preserve null dictionary values in interleave and concat kernels (#7144)
* fix(select): preserve null values in `merge_dictionary_values`

  This function internally computes value masks describing which values from the input dictionaries should remain in the output. Values never referenced by any key are considered redundant. Null values used to be treated as redundant as well; as of this commit they are preserved.

  This change is necessary because keys can reference null values. Before this commit, the entries of `MergedDictionaries::key_mappings` corresponding to null values were left unset. This caused `concat` and `interleave` to remap all elements referencing them to whatever value was at index 0, producing an erroneous result.

* test(select): add test case `concat::test_string_dictionary_array_nulls_in_values`

  This test case passes dictionary arrays containing null values (but no null keys) to `concat`.

* test(select): add test case `interleave::test_interleave_dictionary_nulls`

  This test case passes two dictionary arrays, each containing null values or null keys, to `interleave`.

* refactor(select): add type alias for `Interner` bucket

  Addresses `clippy::type-complexity`.
1 parent d0a2301 commit 46afdd8

File tree

3 files changed

+71
-12
lines changed

3 files changed

+71
-12
lines changed

arrow-select/src/concat.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,24 @@ mod tests {
739739
)
740740
}
741741

742+
#[test]
743+
fn test_string_dictionary_array_nulls_in_values() {
744+
let input_1_keys = Int32Array::from_iter_values([0, 2, 1, 3]);
745+
let input_1_values = StringArray::from(vec![Some("foo"), None, Some("bar"), Some("fiz")]);
746+
let input_1 = DictionaryArray::new(input_1_keys, Arc::new(input_1_values));
747+
748+
let input_2_keys = Int32Array::from_iter_values([0]);
749+
let input_2_values = StringArray::from(vec![None, Some("hello")]);
750+
let input_2 = DictionaryArray::new(input_2_keys, Arc::new(input_2_values));
751+
752+
let expected = vec![Some("foo"), Some("bar"), None, Some("fiz"), None];
753+
754+
let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
755+
let dictionary = concat.as_dictionary::<Int32Type>();
756+
let actual = collect_string_dictionary(dictionary);
757+
assert_eq!(actual, expected);
758+
}
759+
742760
#[test]
743761
fn test_string_dictionary_merge() {
744762
let mut builder = StringDictionaryBuilder::<Int32Type>::new();

arrow-select/src/dictionary.rs

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ use arrow_schema::{ArrowError, DataType};
3232
/// Hash collisions will result in replacement
3333
struct Interner<'a, V> {
3434
state: RandomState,
35-
buckets: Vec<Option<(&'a [u8], V)>>,
35+
buckets: Vec<Option<InternerBucket<'a, V>>>,
3636
shift: u32,
3737
}
3838

39+
/// A single bucket in [`Interner`].
40+
type InternerBucket<'a, V> = (Option<&'a [u8]>, V);
41+
3942
impl<'a, V> Interner<'a, V> {
4043
/// Capacity controls the number of unique buckets allocated within the Interner
4144
///
@@ -54,7 +57,11 @@ impl<'a, V> Interner<'a, V> {
5457
}
5558
}
5659

57-
fn intern<F: FnOnce() -> Result<V, E>, E>(&mut self, new: &'a [u8], f: F) -> Result<&V, E> {
60+
fn intern<F: FnOnce() -> Result<V, E>, E>(
61+
&mut self,
62+
new: Option<&'a [u8]>,
63+
f: F,
64+
) -> Result<&V, E> {
5865
let hash = self.state.hash_one(new);
5966
let bucket_idx = hash >> self.shift;
6067
Ok(match &mut self.buckets[bucket_idx as usize] {
@@ -151,15 +158,19 @@ pub fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
151158

152159
for (idx, dictionary) in dictionaries.iter().enumerate() {
153160
let mask = masks.and_then(|m| m.get(idx));
154-
let key_mask = match (dictionary.logical_nulls(), mask) {
155-
(Some(n), None) => Some(n.into_inner()),
156-
(None, Some(n)) => Some(n.clone()),
157-
(Some(n), Some(m)) => Some(n.inner() & m),
161+
let key_mask_owned;
162+
let key_mask = match (dictionary.nulls(), mask) {
163+
(Some(n), None) => Some(n.inner()),
164+
(None, Some(n)) => Some(n),
165+
(Some(n), Some(m)) => {
166+
key_mask_owned = n.inner() & m;
167+
Some(&key_mask_owned)
168+
}
158169
(None, None) => None,
159170
};
160171
let keys = dictionary.keys().values();
161172
let values = dictionary.values().as_ref();
162-
let values_mask = compute_values_mask(keys, key_mask.as_ref(), values.len());
173+
let values_mask = compute_values_mask(keys, key_mask, values.len());
163174

164175
let masked_values = get_masked_values(values, &values_mask);
165176
num_values += masked_values.len();
@@ -223,7 +234,10 @@ fn compute_values_mask<K: ArrowNativeType>(
223234
}
224235

225236
/// Return a Vec containing for each set index in `mask`, the index and byte value of that index
226-
fn get_masked_values<'a>(array: &'a dyn Array, mask: &BooleanBuffer) -> Vec<(usize, &'a [u8])> {
237+
fn get_masked_values<'a>(
238+
array: &'a dyn Array,
239+
mask: &BooleanBuffer,
240+
) -> Vec<(usize, Option<&'a [u8]>)> {
227241
match array.data_type() {
228242
DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
229243
DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
@@ -239,10 +253,13 @@ fn get_masked_values<'a>(array: &'a dyn Array, mask: &BooleanBuffer) -> Vec<(usi
239253
fn masked_bytes<'a, T: ByteArrayType>(
240254
array: &'a GenericByteArray<T>,
241255
mask: &BooleanBuffer,
242-
) -> Vec<(usize, &'a [u8])> {
256+
) -> Vec<(usize, Option<&'a [u8]>)> {
243257
let mut out = Vec::with_capacity(mask.count_set_bits());
244258
for idx in mask.set_indices() {
245-
out.push((idx, array.value(idx).as_ref()))
259+
out.push((
260+
idx,
261+
array.is_valid(idx).then_some(array.value(idx).as_ref()),
262+
))
246263
}
247264
out
248265
}
@@ -311,10 +328,10 @@ mod tests {
311328
let b = DictionaryArray::new(Int32Array::new_null(10), Arc::new(StringArray::new_null(0)));
312329

313330
let merged = merge_dictionary_values(&[&a, &b], None).unwrap();
314-
let expected = StringArray::from(vec!["bingo", "hello"]);
331+
let expected = StringArray::from(vec![None, Some("bingo"), Some("hello")]);
315332
assert_eq!(merged.values.as_ref(), &expected);
316333
assert_eq!(merged.key_mappings.len(), 2);
317-
assert_eq!(&merged.key_mappings[0], &[0, 0, 0, 1, 0]);
334+
assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 2, 0]);
318335
assert_eq!(&merged.key_mappings[1], &[] as &[i32; 0]);
319336
}
320337

arrow-select/src/interleave.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,30 @@ mod tests {
441441
assert_eq!(&collected, &["c", "c", "c"]);
442442
}
443443

444+
#[test]
445+
fn test_interleave_dictionary_nulls() {
446+
let input_1_keys = Int32Array::from_iter_values([0, 2, 1, 3]);
447+
let input_1_values = StringArray::from(vec![Some("foo"), None, Some("bar"), Some("fiz")]);
448+
let input_1 = DictionaryArray::new(input_1_keys, Arc::new(input_1_values));
449+
let input_2: DictionaryArray<Int32Type> = vec![None].into_iter().collect();
450+
451+
let expected = vec![Some("fiz"), None, None, Some("foo")];
452+
453+
let values = interleave(
454+
&[&input_1 as _, &input_2 as _],
455+
&[(0, 3), (0, 2), (1, 0), (0, 0)],
456+
)
457+
.unwrap();
458+
let dictionary = values.as_dictionary::<Int32Type>();
459+
let actual: Vec<Option<&str>> = dictionary
460+
.downcast_dict::<StringArray>()
461+
.unwrap()
462+
.into_iter()
463+
.collect();
464+
465+
assert_eq!(actual, expected);
466+
}
467+
444468
#[test]
445469
fn test_lists() {
446470
// [[1, 2], null, [3]]

0 commit comments

Comments
 (0)