diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs index 134324f45f5bc..5d88d8684d499 100644 --- a/datafusion/spark/src/function/math/hex.rs +++ b/datafusion/spark/src/function/math/hex.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::str::from_utf8_unchecked; use std::sync::Arc; -use arrow::array::{Array, BinaryArray, Int64Array, StringArray, StringBuilder}; +use arrow::array::{Array, ArrayRef, StringBuilder}; use arrow::datatypes::DataType; use arrow::{ array::{as_dictionary_array, as_largestring_array, as_string_array}, @@ -92,11 +92,13 @@ impl ScalarUDFImpl for SparkHex { &self.signature } - fn return_type( - &self, - _arg_types: &[DataType], - ) -> datafusion_common::Result { - Ok(DataType::Utf8) + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(match &arg_types[0] { + DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8)) + } + _ => DataType::Utf8, + }) } fn invoke_with_args( @@ -241,29 +243,38 @@ pub fn compute_hex( let array = as_fixed_size_binary_array(array)?; hex_encode_bytes(array.iter(), lowercase, array.len()) } - DataType::Dictionary(_, value_type) => { + DataType::Dictionary(_, _) => { let dict = as_dictionary_array::(&array); + let dict_values = dict.values(); - match **value_type { + let encoded_values: ColumnarValue = match dict_values.data_type() { DataType::Int64 => { - let arr = dict.downcast_dict::().unwrap(); - hex_encode_int64(arr.into_iter(), dict.len()) + let arr = as_int64_array(dict_values)?; + hex_encode_int64(arr.iter(), arr.len())? } DataType::Utf8 => { - let arr = dict.downcast_dict::().unwrap(); - hex_encode_bytes(arr.into_iter(), lowercase, dict.len()) + let arr = as_string_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? } DataType::Binary => { - let arr = dict.downcast_dict::().unwrap(); - hex_encode_bytes(arr.into_iter(), lowercase, dict.len()) + let arr = as_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? } _ => { - exec_err!( + return exec_err!( "hex got an unexpected argument type: {}", - array.data_type() - ) + dict_values.data_type() + ); } - } + }; + + let encoded_values_array: ArrayRef = match encoded_values { + ColumnarValue::Array(a) => a, + ColumnarValue::Scalar(s) => Arc::new(s.to_array()?), + }; + + let new_dict = dict.with_values(encoded_values_array); + Ok(ColumnarValue::Array(Arc::new(new_dict))) } _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), }, @@ -279,11 +290,12 @@ mod test { use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray}; use arrow::{ array::{ - BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder, - StringDictionaryBuilder, as_string_array, + BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder, + as_string_array, }, datatypes::{Int32Type, Int64Type}, }; + use datafusion_common::cast::as_dictionary_array; use datafusion_expr::ColumnarValue; #[test] @@ -295,12 +307,12 @@ mod test { input_builder.append_value("rust"); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("6869"); - string_builder.append_value("627965"); - string_builder.append_null(); - string_builder.append_value("72757374"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("6869"); + expected_builder.append_value("627965"); + expected_builder.append_null(); + expected_builder.append_value("72757374"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -310,7 +322,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -324,12 +336,12 @@ mod test { input_builder.append_value(3); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("1"); - string_builder.append_value("2"); - string_builder.append_null(); - string_builder.append_value("3"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("1"); + expected_builder.append_value("2"); + expected_builder.append_null(); + expected_builder.append_value("3"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -339,7 +351,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -353,7 +365,7 @@ mod test { input_builder.append_value("3"); let input = input_builder.finish(); - let mut expected_builder = StringBuilder::new(); + let mut expected_builder = StringDictionaryBuilder::::new(); expected_builder.append_value("31"); expected_builder.append_value("6A"); expected_builder.append_null(); @@ -368,7 +380,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -425,8 +437,11 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); - let expected = StringArray::from(vec![Some("20"), None, None]); + let result = as_dictionary_array(&result).unwrap(); + + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let vals = StringArray::from(vec![Some("20"), None]); + let expected = DictionaryArray::new(keys, Arc::new(vals)); assert_eq!(&expected, result); } diff --git a/datafusion/sqllogictest/test_files/spark/math/hex.slt b/datafusion/sqllogictest/test_files/spark/math/hex.slt index 05c9fb3f31b28..756088a26d6ca 100644 --- a/datafusion/sqllogictest/test_files/spark/math/hex.slt +++ b/datafusion/sqllogictest/test_files/spark/math/hex.slt @@ -63,3 +63,18 @@ query T SELECT hex(arrow_cast('test', 'LargeBinary')) as lar_b; ---- 74657374 + +statement ok +CREATE TABLE t_dict_utf8 AS +SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') as dict_col +FROM VALUES ('foo'), ('bar'), ('foo'), (NULL), ('baz'), ('bar'); + +query T +SELECT hex(dict_col) FROM t_dict_utf8; +---- +666F6F +626172 +666F6F +NULL +62617A +626172