Skip to content

Commit 2f2e705

Browse files
albertlockett and alamb authored
feat: add constructor to help efficiently upgrade key for GenericByteDictionaryBuilder (#7611)
# Which issue does this PR close?

- Closes #7610

# Rationale for this change

I'm adding this because I would like a more efficient way to upgrade the key type of a dictionary builder when my dictionary keys have overflowed.

# What changes are included in this PR?

This adds a method called `try_new_from_builder` to `GenericByteDictionaryBuilder`. It constructs a new builder from the passed argument with the same values and internal state, but with a keys array builder of a different type (the motivation being that the new key type can hold more values).

# Are there any user-facing changes?

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 71ee9d9 commit 2f2e705

File tree

1 file changed

+161
-2
lines changed

1 file changed

+161
-2
lines changed

arrow-array/src/builder/generic_bytes_dictionary_builder.rs

Lines changed: 161 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
use crate::builder::{ArrayBuilder, GenericByteBuilder, PrimitiveBuilder};
1919
use crate::types::{ArrowDictionaryKeyType, ByteArrayType, GenericBinaryType, GenericStringType};
20-
use crate::{Array, ArrayRef, DictionaryArray, GenericByteArray, TypedDictionaryArray};
20+
use crate::{
21+
Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray, TypedDictionaryArray,
22+
};
2123
use arrow_buffer::ArrowNativeType;
2224
use arrow_schema::{ArrowError, DataType};
2325
use hashbrown::HashTable;
26+
use num::NumCast;
2427
use std::any::Any;
2528
use std::sync::Arc;
2629

@@ -152,6 +155,71 @@ where
152155
values_builder,
153156
})
154157
}
158+
159+
/// Creates a new `GenericByteDictionaryBuilder` from the existing builder with the same
160+
/// keys and values, but with a new data type for the keys.
161+
///
162+
/// # Example
163+
/// ```
164+
/// #
165+
/// # use arrow_array::builder::StringDictionaryBuilder;
166+
/// # use arrow_array::types::{UInt8Type, UInt16Type};
167+
/// # use arrow_array::UInt16Array;
168+
/// # use arrow_schema::ArrowError;
169+
///
170+
/// let mut u8_keyed_builder = StringDictionaryBuilder::<UInt8Type>::new();
171+
///
172+
/// // appending too many values causes the dictionary to overflow
173+
/// for i in 0..256 {
174+
/// u8_keyed_builder.append_value(format!("{}", i));
175+
/// }
176+
/// let result = u8_keyed_builder.append("256");
177+
/// assert!(matches!(result, Err(ArrowError::DictionaryKeyOverflowError{})));
178+
///
179+
/// // we need to upgrade to a larger key type
180+
/// let mut u16_keyed_builder = StringDictionaryBuilder::<UInt16Type>::try_new_from_builder(u8_keyed_builder).unwrap();
181+
/// let dictionary_array = u16_keyed_builder.finish();
182+
/// let keys = dictionary_array.keys();
183+
///
184+
/// assert_eq!(keys, &UInt16Array::from_iter(0..256));
185+
/// ```
186+
pub fn try_new_from_builder<K2>(
187+
mut source: GenericByteDictionaryBuilder<K2, T>,
188+
) -> Result<Self, ArrowError>
189+
where
190+
K::Native: NumCast,
191+
K2: ArrowDictionaryKeyType,
192+
K2::Native: NumCast,
193+
{
194+
let state = source.state;
195+
let dedup = source.dedup;
196+
let values_builder = source.values_builder;
197+
198+
let source_keys = source.keys_builder.finish();
199+
let new_keys: PrimitiveArray<K> = source_keys.try_unary(|value| {
200+
num::cast::cast::<K2::Native, K::Native>(value).ok_or_else(|| {
201+
ArrowError::CastError(format!(
202+
"Can't cast dictionary keys from source type {:?} to type {:?}",
203+
K2::DATA_TYPE,
204+
K::DATA_TYPE
205+
))
206+
})
207+
})?;
208+
209+
// drop source key here because currently source_keys and new_keys are holding reference to
210+
// the same underlying null_buffer. Below we want to call new_keys.into_builder() it must
211+
// be the only reference holder.
212+
drop(source_keys);
213+
214+
Ok(Self {
215+
state,
216+
dedup,
217+
keys_builder: new_keys
218+
.into_builder()
219+
.expect("underlying buffer has no references"),
220+
values_builder,
221+
})
222+
}
155223
}
156224

157225
impl<K, T> ArrayBuilder for GenericByteDictionaryBuilder<K, T>
@@ -503,7 +571,7 @@ mod tests {
503571

504572
use crate::array::Int8Array;
505573
use crate::cast::AsArray;
506-
use crate::types::{Int16Type, Int32Type, Int8Type, Utf8Type};
574+
use crate::types::{Int16Type, Int32Type, Int8Type, UInt16Type, UInt8Type, Utf8Type};
507575
use crate::{ArrowPrimitiveType, BinaryArray, StringArray};
508576

509577
fn test_bytes_dictionary_builder<T>(values: Vec<&T::Native>)
@@ -614,6 +682,97 @@ mod tests {
614682
]);
615683
}
616684

685+
fn _test_try_new_from_builder_generic_for_key_types<K1, K2, T>(values: Vec<&T::Native>)
686+
where
687+
K1: ArrowDictionaryKeyType,
688+
K1::Native: NumCast,
689+
K2: ArrowDictionaryKeyType,
690+
K2::Native: NumCast + From<u8>,
691+
T: ByteArrayType,
692+
<T as ByteArrayType>::Native: PartialEq + AsRef<<T as ByteArrayType>::Native>,
693+
{
694+
let mut source = GenericByteDictionaryBuilder::<K1, T>::new();
695+
source.append(values[0]).unwrap();
696+
source.append(values[1]).unwrap();
697+
source.append_null();
698+
source.append(values[2]).unwrap();
699+
700+
let mut result =
701+
GenericByteDictionaryBuilder::<K2, T>::try_new_from_builder(source).unwrap();
702+
let array = result.finish();
703+
704+
let mut expected_keys_builder = PrimitiveBuilder::<K2>::new();
705+
expected_keys_builder
706+
.append_value(<<K2 as ArrowPrimitiveType>::Native as From<u8>>::from(0u8));
707+
expected_keys_builder
708+
.append_value(<<K2 as ArrowPrimitiveType>::Native as From<u8>>::from(1u8));
709+
expected_keys_builder.append_null();
710+
expected_keys_builder
711+
.append_value(<<K2 as ArrowPrimitiveType>::Native as From<u8>>::from(2u8));
712+
let expected_keys = expected_keys_builder.finish();
713+
assert_eq!(array.keys(), &expected_keys);
714+
715+
let av = array.values();
716+
let ava: &GenericByteArray<T> = av.as_any().downcast_ref::<GenericByteArray<T>>().unwrap();
717+
assert_eq!(ava.value(0), values[0]);
718+
assert_eq!(ava.value(1), values[1]);
719+
assert_eq!(ava.value(2), values[2]);
720+
}
721+
722+
fn test_try_new_from_builder<T>(values: Vec<&T::Native>)
723+
where
724+
T: ByteArrayType,
725+
<T as ByteArrayType>::Native: PartialEq + AsRef<<T as ByteArrayType>::Native>,
726+
{
727+
// test cast to bigger size unsigned
728+
_test_try_new_from_builder_generic_for_key_types::<UInt8Type, UInt16Type, T>(
729+
values.clone(),
730+
);
731+
// test cast going to smaller size unsigned
732+
_test_try_new_from_builder_generic_for_key_types::<UInt16Type, UInt8Type, T>(
733+
values.clone(),
734+
);
735+
// test cast going to bigger size signed
736+
_test_try_new_from_builder_generic_for_key_types::<Int8Type, Int16Type, T>(values.clone());
737+
// test cast going to smaller size signed
738+
_test_try_new_from_builder_generic_for_key_types::<Int32Type, Int16Type, T>(values.clone());
739+
// test going from signed to signed for different size changes
740+
_test_try_new_from_builder_generic_for_key_types::<UInt8Type, Int16Type, T>(values.clone());
741+
_test_try_new_from_builder_generic_for_key_types::<Int8Type, UInt8Type, T>(values.clone());
742+
_test_try_new_from_builder_generic_for_key_types::<Int8Type, UInt16Type, T>(values.clone());
743+
_test_try_new_from_builder_generic_for_key_types::<Int32Type, Int16Type, T>(values.clone());
744+
}
745+
746+
/// Key-type conversion for string dictionaries with `i32` offsets.
#[test]
fn test_string_dictionary_builder_try_new_from_builder() {
    let values = vec!["abc", "def", "ghi"];
    test_try_new_from_builder::<GenericStringType<i32>>(values);
}
750+
751+
/// Key-type conversion for binary dictionaries with `i32` offsets.
#[test]
fn test_binary_dictionary_builder_try_new_from_builder() {
    // Annotated so the fixed-size byte-string literals coerce to plain byte slices.
    let values: Vec<&[u8]> = vec![b"abc", b"def", b"ghi"];
    test_try_new_from_builder::<GenericBinaryType<i32>>(values);
}
755+
756+
/// Converting to a narrower key type must fail with a cast error when the source
/// builder already holds keys that do not fit in the target type.
#[test]
fn test_try_new_from_builder_cast_fails() {
    // 257 distinct values yield keys 0..=256, and 256 is out of range for a u8 key.
    let mut source_builder = StringDictionaryBuilder::<UInt16Type>::new();
    for index in 0..257 {
        source_builder.append_value(format!("val{}", index));
    }

    let result = StringDictionaryBuilder::<UInt8Type>::try_new_from_builder(source_builder);
    assert!(result.is_err());

    if let Err(error) = result {
        // Check both the error variant and the rendered message.
        assert!(matches!(error, ArrowError::CastError(_)));
        let message = error.to_string();
        assert_eq!(
            message,
            "Cast error: Can't cast dictionary keys from source type UInt16 to type UInt8"
        );
    }
}
775+
617776
fn test_bytes_dictionary_builder_with_existing_dictionary<T>(
618777
dictionary: GenericByteArray<T>,
619778
values: Vec<&T::Native>,

0 commit comments

Comments
 (0)