|
17 | 17 |
|
18 | 18 | use crate::builder::{ArrayBuilder, GenericByteBuilder, PrimitiveBuilder}; |
19 | 19 | use crate::types::{ArrowDictionaryKeyType, ByteArrayType, GenericBinaryType, GenericStringType}; |
20 | | -use crate::{Array, ArrayRef, DictionaryArray, GenericByteArray, TypedDictionaryArray}; |
| 20 | +use crate::{ |
| 21 | + Array, ArrayRef, DictionaryArray, GenericByteArray, PrimitiveArray, TypedDictionaryArray, |
| 22 | +}; |
21 | 23 | use arrow_buffer::ArrowNativeType; |
22 | 24 | use arrow_schema::{ArrowError, DataType}; |
23 | 25 | use hashbrown::HashTable; |
| 26 | +use num::NumCast; |
24 | 27 | use std::any::Any; |
25 | 28 | use std::sync::Arc; |
26 | 29 |
|
@@ -152,6 +155,71 @@ where |
152 | 155 | values_builder, |
153 | 156 | }) |
154 | 157 | } |
| 158 | + |
| 159 | + /// Creates a new `GenericByteDictionaryBuilder` from the existing builder with the same |
| 160 | + /// keys and values, but with a new data type for the keys. |
| 161 | + /// |
| 162 | + /// # Example |
| 163 | + /// ``` |
| 164 | + /// # |
| 165 | + /// # use arrow_array::builder::StringDictionaryBuilder; |
| 166 | + /// # use arrow_array::types::{UInt8Type, UInt16Type}; |
| 167 | + /// # use arrow_array::UInt16Array; |
| 168 | + /// # use arrow_schema::ArrowError; |
| 169 | + /// |
| 170 | + /// let mut u8_keyed_builder = StringDictionaryBuilder::<UInt8Type>::new(); |
| 171 | + /// |
| 172 | + /// // appending too many values causes the dictionary to overflow |
| 173 | + /// for i in 0..256 { |
| 174 | + /// u8_keyed_builder.append_value(format!("{}", i)); |
| 175 | + /// } |
| 176 | + /// let result = u8_keyed_builder.append("256"); |
| 177 | + /// assert!(matches!(result, Err(ArrowError::DictionaryKeyOverflowError{}))); |
| 178 | + /// |
| 179 | + /// // we need to upgrade to a larger key type |
| 180 | + /// let mut u16_keyed_builder = StringDictionaryBuilder::<UInt16Type>::try_new_from_builder(u8_keyed_builder).unwrap(); |
| 181 | + /// let dictionary_array = u16_keyed_builder.finish(); |
| 182 | + /// let keys = dictionary_array.keys(); |
| 183 | + /// |
| 184 | + /// assert_eq!(keys, &UInt16Array::from_iter(0..256)); |
| 185 | + /// ``` |
| 186 | + pub fn try_new_from_builder<K2>( |
| 187 | + mut source: GenericByteDictionaryBuilder<K2, T>, |
| 188 | + ) -> Result<Self, ArrowError> |
| 189 | + where |
| 190 | + K::Native: NumCast, |
| 191 | + K2: ArrowDictionaryKeyType, |
| 192 | + K2::Native: NumCast, |
| 193 | + { |
| 194 | + let state = source.state; |
| 195 | + let dedup = source.dedup; |
| 196 | + let values_builder = source.values_builder; |
| 197 | + |
| 198 | + let source_keys = source.keys_builder.finish(); |
| 199 | + let new_keys: PrimitiveArray<K> = source_keys.try_unary(|value| { |
| 200 | + num::cast::cast::<K2::Native, K::Native>(value).ok_or_else(|| { |
| 201 | + ArrowError::CastError(format!( |
| 202 | + "Can't cast dictionary keys from source type {:?} to type {:?}", |
| 203 | + K2::DATA_TYPE, |
| 204 | + K::DATA_TYPE |
| 205 | + )) |
| 206 | + }) |
| 207 | + })?; |
| 208 | + |
| 209 | + // drop source key here because currently source_keys and new_keys are holding reference to |
| 210 | + // the same underlying null_buffer. Below we want to call new_keys.into_builder() it must |
| 211 | + // be the only reference holder. |
| 212 | + drop(source_keys); |
| 213 | + |
| 214 | + Ok(Self { |
| 215 | + state, |
| 216 | + dedup, |
| 217 | + keys_builder: new_keys |
| 218 | + .into_builder() |
| 219 | + .expect("underlying buffer has no references"), |
| 220 | + values_builder, |
| 221 | + }) |
| 222 | + } |
155 | 223 | } |
156 | 224 |
|
157 | 225 | impl<K, T> ArrayBuilder for GenericByteDictionaryBuilder<K, T> |
@@ -503,7 +571,7 @@ mod tests { |
503 | 571 |
|
504 | 572 | use crate::array::Int8Array; |
505 | 573 | use crate::cast::AsArray; |
506 | | - use crate::types::{Int16Type, Int32Type, Int8Type, Utf8Type}; |
| 574 | + use crate::types::{Int16Type, Int32Type, Int8Type, UInt16Type, UInt8Type, Utf8Type}; |
507 | 575 | use crate::{ArrowPrimitiveType, BinaryArray, StringArray}; |
508 | 576 |
|
509 | 577 | fn test_bytes_dictionary_builder<T>(values: Vec<&T::Native>) |
@@ -614,6 +682,97 @@ mod tests { |
614 | 682 | ]); |
615 | 683 | } |
616 | 684 |
|
| 685 | + fn _test_try_new_from_builder_generic_for_key_types<K1, K2, T>(values: Vec<&T::Native>) |
| 686 | + where |
| 687 | + K1: ArrowDictionaryKeyType, |
| 688 | + K1::Native: NumCast, |
| 689 | + K2: ArrowDictionaryKeyType, |
| 690 | + K2::Native: NumCast + From<u8>, |
| 691 | + T: ByteArrayType, |
| 692 | + <T as ByteArrayType>::Native: PartialEq + AsRef<<T as ByteArrayType>::Native>, |
| 693 | + { |
| 694 | + let mut source = GenericByteDictionaryBuilder::<K1, T>::new(); |
| 695 | + source.append(values[0]).unwrap(); |
| 696 | + source.append(values[1]).unwrap(); |
| 697 | + source.append_null(); |
| 698 | + source.append(values[2]).unwrap(); |
| 699 | + |
| 700 | + let mut result = |
| 701 | + GenericByteDictionaryBuilder::<K2, T>::try_new_from_builder(source).unwrap(); |
| 702 | + let array = result.finish(); |
| 703 | + |
| 704 | + let mut expected_keys_builder = PrimitiveBuilder::<K2>::new(); |
| 705 | + expected_keys_builder |
| 706 | + .append_value(<<K2 as ArrowPrimitiveType>::Native as From<u8>>::from(0u8)); |
| 707 | + expected_keys_builder |
| 708 | + .append_value(<<K2 as ArrowPrimitiveType>::Native as From<u8>>::from(1u8)); |
| 709 | + expected_keys_builder.append_null(); |
| 710 | + expected_keys_builder |
| 711 | + .append_value(<<K2 as ArrowPrimitiveType>::Native as From<u8>>::from(2u8)); |
| 712 | + let expected_keys = expected_keys_builder.finish(); |
| 713 | + assert_eq!(array.keys(), &expected_keys); |
| 714 | + |
| 715 | + let av = array.values(); |
| 716 | + let ava: &GenericByteArray<T> = av.as_any().downcast_ref::<GenericByteArray<T>>().unwrap(); |
| 717 | + assert_eq!(ava.value(0), values[0]); |
| 718 | + assert_eq!(ava.value(1), values[1]); |
| 719 | + assert_eq!(ava.value(2), values[2]); |
| 720 | + } |
| 721 | + |
| 722 | + fn test_try_new_from_builder<T>(values: Vec<&T::Native>) |
| 723 | + where |
| 724 | + T: ByteArrayType, |
| 725 | + <T as ByteArrayType>::Native: PartialEq + AsRef<<T as ByteArrayType>::Native>, |
| 726 | + { |
| 727 | + // test cast to bigger size unsigned |
| 728 | + _test_try_new_from_builder_generic_for_key_types::<UInt8Type, UInt16Type, T>( |
| 729 | + values.clone(), |
| 730 | + ); |
| 731 | + // test cast going to smaller size unsigned |
| 732 | + _test_try_new_from_builder_generic_for_key_types::<UInt16Type, UInt8Type, T>( |
| 733 | + values.clone(), |
| 734 | + ); |
| 735 | + // test cast going to bigger size signed |
| 736 | + _test_try_new_from_builder_generic_for_key_types::<Int8Type, Int16Type, T>(values.clone()); |
| 737 | + // test cast going to smaller size signed |
| 738 | + _test_try_new_from_builder_generic_for_key_types::<Int32Type, Int16Type, T>(values.clone()); |
| 739 | + // test going from signed to signed for different size changes |
| 740 | + _test_try_new_from_builder_generic_for_key_types::<UInt8Type, Int16Type, T>(values.clone()); |
| 741 | + _test_try_new_from_builder_generic_for_key_types::<Int8Type, UInt8Type, T>(values.clone()); |
| 742 | + _test_try_new_from_builder_generic_for_key_types::<Int8Type, UInt16Type, T>(values.clone()); |
| 743 | + _test_try_new_from_builder_generic_for_key_types::<Int32Type, Int16Type, T>(values.clone()); |
| 744 | + } |
| 745 | + |
| 746 | + #[test] |
| 747 | + fn test_string_dictionary_builder_try_new_from_builder() { |
| 748 | + test_try_new_from_builder::<GenericStringType<i32>>(vec!["abc", "def", "ghi"]); |
| 749 | + } |
| 750 | + |
| 751 | + #[test] |
| 752 | + fn test_binary_dictionary_builder_try_new_from_builder() { |
| 753 | + test_try_new_from_builder::<GenericBinaryType<i32>>(vec![b"abc", b"def", b"ghi"]); |
| 754 | + } |
| 755 | + |
| 756 | + #[test] |
| 757 | + fn test_try_new_from_builder_cast_fails() { |
| 758 | + let mut source_builder = StringDictionaryBuilder::<UInt16Type>::new(); |
| 759 | + for i in 0..257 { |
| 760 | + source_builder.append_value(format!("val{}", i)); |
| 761 | + } |
| 762 | + |
| 763 | + // there should be too many values that we can't downcast to the underlying type |
| 764 | + // we have keys that wouldn't fit into UInt8Type |
| 765 | + let result = StringDictionaryBuilder::<UInt8Type>::try_new_from_builder(source_builder); |
| 766 | + assert!(result.is_err()); |
| 767 | + if let Err(e) = result { |
| 768 | + assert!(matches!(e, ArrowError::CastError(_))); |
| 769 | + assert_eq!( |
| 770 | + e.to_string(), |
| 771 | + "Cast error: Can't cast dictionary keys from source type UInt16 to type UInt8" |
| 772 | + ); |
| 773 | + } |
| 774 | + } |
| 775 | + |
617 | 776 | fn test_bytes_dictionary_builder_with_existing_dictionary<T>( |
618 | 777 | dictionary: GenericByteArray<T>, |
619 | 778 | values: Vec<&T::Native>, |
|
0 commit comments