-
Notifications
You must be signed in to change notification settings - Fork 2k
perf: optimize spark_hex dictionary path by avoiding dictionary expansion
#19832
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
3b465c3
8b5087e
62d2912
ba8d1ee
0946546
dcde0f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,7 @@ use std::any::Any; | |
| use std::str::from_utf8_unchecked; | ||
| use std::sync::Arc; | ||
|
|
||
| use arrow::array::{Array, BinaryArray, Int64Array, StringArray, StringBuilder}; | ||
| use arrow::array::{Array, ArrayRef, StringBuilder}; | ||
| use arrow::datatypes::DataType; | ||
| use arrow::{ | ||
| array::{as_dictionary_array, as_largestring_array, as_string_array}, | ||
|
|
@@ -92,11 +92,13 @@ impl ScalarUDFImpl for SparkHex { | |
| &self.signature | ||
| } | ||
|
|
||
| fn return_type( | ||
| &self, | ||
| _arg_types: &[DataType], | ||
| ) -> datafusion_common::Result<DataType> { | ||
| Ok(DataType::Utf8) | ||
| fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> { | ||
| Ok(match &arg_types[0] { | ||
| DataType::Dictionary(key_type, _) => { | ||
| DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8)) | ||
| } | ||
| _ => DataType::Utf8, | ||
| }) | ||
| } | ||
|
|
||
| fn invoke_with_args( | ||
|
|
@@ -241,29 +243,38 @@ pub fn compute_hex( | |
| let array = as_fixed_size_binary_array(array)?; | ||
| hex_encode_bytes(array.iter(), lowercase, array.len()) | ||
| } | ||
| DataType::Dictionary(_, value_type) => { | ||
| DataType::Dictionary(_, _) => { | ||
| let dict = as_dictionary_array::<Int32Type>(&array); | ||
| let dict_values = dict.values(); | ||
|
|
||
| match **value_type { | ||
| let encoded_values: ColumnarValue = match dict_values.data_type() { | ||
|
||
| DataType::Int64 => { | ||
| let arr = dict.downcast_dict::<Int64Array>().unwrap(); | ||
| hex_encode_int64(arr.into_iter(), dict.len()) | ||
| let arr = as_int64_array(dict_values)?; | ||
| hex_encode_int64(arr.iter(), arr.len())? | ||
| } | ||
| DataType::Utf8 => { | ||
| let arr = dict.downcast_dict::<StringArray>().unwrap(); | ||
| hex_encode_bytes(arr.into_iter(), lowercase, dict.len()) | ||
| let arr = as_string_array(dict_values); | ||
| hex_encode_bytes(arr.iter(), lowercase, arr.len())? | ||
| } | ||
| DataType::Binary => { | ||
| let arr = dict.downcast_dict::<BinaryArray>().unwrap(); | ||
| hex_encode_bytes(arr.into_iter(), lowercase, dict.len()) | ||
| let arr = as_binary_array(dict_values)?; | ||
| hex_encode_bytes(arr.iter(), lowercase, arr.len())? | ||
| } | ||
| _ => { | ||
| exec_err!( | ||
| return exec_err!( | ||
| "hex got an unexpected argument type: {}", | ||
| array.data_type() | ||
| ) | ||
| dict_values.data_type() | ||
| ); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| let encoded_values_array: ArrayRef = match encoded_values { | ||
| ColumnarValue::Array(a) => a, | ||
| ColumnarValue::Scalar(s) => Arc::new(s.to_array()?), | ||
| }; | ||
|
||
|
|
||
| let new_dict = dict.with_values(encoded_values_array); | ||
| Ok(ColumnarValue::Array(Arc::new(new_dict))) | ||
| } | ||
| _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), | ||
| }, | ||
|
|
@@ -279,11 +290,12 @@ mod test { | |
| use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray}; | ||
| use arrow::{ | ||
| array::{ | ||
| BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder, | ||
| StringDictionaryBuilder, as_string_array, | ||
| BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder, | ||
| as_string_array, | ||
| }, | ||
| datatypes::{Int32Type, Int64Type}, | ||
| }; | ||
| use datafusion_common::cast::as_dictionary_array; | ||
| use datafusion_expr::ColumnarValue; | ||
|
|
||
| #[test] | ||
|
|
@@ -295,12 +307,12 @@ mod test { | |
| input_builder.append_value("rust"); | ||
| let input = input_builder.finish(); | ||
|
|
||
| let mut string_builder = StringBuilder::new(); | ||
| string_builder.append_value("6869"); | ||
| string_builder.append_value("627965"); | ||
| string_builder.append_null(); | ||
| string_builder.append_value("72757374"); | ||
| let expected = string_builder.finish(); | ||
| let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new(); | ||
| expected_builder.append_value("6869"); | ||
| expected_builder.append_value("627965"); | ||
| expected_builder.append_null(); | ||
| expected_builder.append_value("72757374"); | ||
| let expected = expected_builder.finish(); | ||
|
|
||
| let columnar_value = ColumnarValue::Array(Arc::new(input)); | ||
| let result = super::spark_hex(&[columnar_value]).unwrap(); | ||
|
|
@@ -310,7 +322,7 @@ mod test { | |
| _ => panic!("Expected array"), | ||
| }; | ||
|
|
||
| let result = as_string_array(&result); | ||
| let result = as_dictionary_array(&result).unwrap(); | ||
|
|
||
| assert_eq!(result, &expected); | ||
| } | ||
|
|
@@ -324,12 +336,12 @@ mod test { | |
| input_builder.append_value(3); | ||
| let input = input_builder.finish(); | ||
|
|
||
| let mut string_builder = StringBuilder::new(); | ||
| string_builder.append_value("1"); | ||
| string_builder.append_value("2"); | ||
| string_builder.append_null(); | ||
| string_builder.append_value("3"); | ||
| let expected = string_builder.finish(); | ||
| let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new(); | ||
| expected_builder.append_value("1"); | ||
| expected_builder.append_value("2"); | ||
| expected_builder.append_null(); | ||
| expected_builder.append_value("3"); | ||
| let expected = expected_builder.finish(); | ||
|
|
||
| let columnar_value = ColumnarValue::Array(Arc::new(input)); | ||
| let result = super::spark_hex(&[columnar_value]).unwrap(); | ||
|
|
@@ -339,7 +351,7 @@ mod test { | |
| _ => panic!("Expected array"), | ||
| }; | ||
|
|
||
| let result = as_string_array(&result); | ||
| let result = as_dictionary_array(&result).unwrap(); | ||
|
|
||
| assert_eq!(result, &expected); | ||
| } | ||
|
|
@@ -353,7 +365,7 @@ mod test { | |
| input_builder.append_value("3"); | ||
| let input = input_builder.finish(); | ||
|
|
||
| let mut expected_builder = StringBuilder::new(); | ||
| let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new(); | ||
| expected_builder.append_value("31"); | ||
| expected_builder.append_value("6A"); | ||
| expected_builder.append_null(); | ||
|
|
@@ -368,7 +380,7 @@ mod test { | |
| _ => panic!("Expected array"), | ||
| }; | ||
|
|
||
| let result = as_string_array(&result); | ||
| let result = as_dictionary_array(&result).unwrap(); | ||
|
|
||
| assert_eq!(result, &expected); | ||
| } | ||
|
|
@@ -425,8 +437,11 @@ mod test { | |
| _ => panic!("Expected array"), | ||
| }; | ||
|
|
||
| let result = as_string_array(&result); | ||
| let expected = StringArray::from(vec![Some("20"), None, None]); | ||
| let result = as_dictionary_array(&result).unwrap(); | ||
|
|
||
| let keys = Int32Array::from(vec![Some(0), None, Some(1)]); | ||
| let vals = StringArray::from(vec![Some("20"), None]); | ||
| let expected = DictionaryArray::new(keys, Arc::new(vals)); | ||
|
|
||
| assert_eq!(&expected, result); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,3 +63,18 @@ query T | |
| SELECT hex(arrow_cast('test', 'LargeBinary')) as lar_b; | ||
| ---- | ||
| 74657374 | ||
|
|
||
| statement ok | ||
| CREATE TABLE t_dict_utf8 AS | ||
| SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') as dict_col | ||
| FROM VALUES ('foo'), ('bar'), ('foo'), (NULL), ('baz'), ('bar'); | ||
|
|
||
| query T | ||
| SELECT hex(dict_col) FROM t_dict_utf8; | ||
|
||
| ---- | ||
| 666F6F | ||
| 626172 | ||
| 666F6F | ||
| NULL | ||
| 62617A | ||
| 626172 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
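nit: `as_dictionary_array::<Int32Type>(&array)` is called for every `DataType::Dictionary(_, _)` input, but it panics if the dictionary's key type is not `Int32`. We should check the key type first (or dispatch on it) and return an `exec_err!` for unsupported key types instead of panicking, especially since `return_type` now passes through whatever key type it is given.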