Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 53 additions & 38 deletions datafusion/spark/src/function/math/hex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::any::Any;
use std::str::from_utf8_unchecked;
use std::sync::Arc;

use arrow::array::{Array, BinaryArray, Int64Array, StringArray, StringBuilder};
use arrow::array::{Array, ArrayRef, StringBuilder};
use arrow::datatypes::DataType;
use arrow::{
array::{as_dictionary_array, as_largestring_array, as_string_array},
Expand Down Expand Up @@ -92,11 +92,13 @@ impl ScalarUDFImpl for SparkHex {
&self.signature
}

fn return_type(
&self,
_arg_types: &[DataType],
) -> datafusion_common::Result<DataType> {
Ok(DataType::Utf8)
fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(match &arg_types[0] {
DataType::Dictionary(key_type, _) => {
DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8))
}
_ => DataType::Utf8,
})
}

fn invoke_with_args(
Expand Down Expand Up @@ -241,29 +243,38 @@ pub fn compute_hex(
let array = as_fixed_size_binary_array(array)?;
hex_encode_bytes(array.iter(), lowercase, array.len())
}
DataType::Dictionary(_, value_type) => {
DataType::Dictionary(_, _) => {
let dict = as_dictionary_array::<Int32Type>(&array);
let dict_values = dict.values();

match **value_type {
let encoded_values: ColumnarValue = match dict_values.data_type() {
DataType::Int64 => {
let arr = dict.downcast_dict::<Int64Array>().unwrap();
hex_encode_int64(arr.into_iter(), dict.len())
let arr = as_int64_array(dict_values)?;
hex_encode_int64(arr.iter(), arr.len())?
}
DataType::Utf8 => {
let arr = dict.downcast_dict::<StringArray>().unwrap();
hex_encode_bytes(arr.into_iter(), lowercase, dict.len())
let arr = as_string_array(dict_values);
hex_encode_bytes(arr.iter(), lowercase, arr.len())?
}
DataType::Binary => {
let arr = dict.downcast_dict::<BinaryArray>().unwrap();
hex_encode_bytes(arr.into_iter(), lowercase, dict.len())
let arr = as_binary_array(dict_values)?;
hex_encode_bytes(arr.iter(), lowercase, arr.len())?
}
_ => {
exec_err!(
return exec_err!(
"hex got an unexpected argument type: {}",
array.data_type()
)
dict_values.data_type()
);
}
}
};

let encoded_values_array: ArrayRef = match encoded_values {
ColumnarValue::Array(a) => a,
ColumnarValue::Scalar(s) => Arc::new(s.to_array()?),
};

let new_dict = dict.with_values(encoded_values_array);
Ok(ColumnarValue::Array(Arc::new(new_dict)))
}
_ => exec_err!("hex got an unexpected argument type: {}", array.data_type()),
},
Expand All @@ -279,11 +290,12 @@ mod test {
use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray};
use arrow::{
array::{
BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder,
StringDictionaryBuilder, as_string_array,
BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder,
as_string_array,
},
datatypes::{Int32Type, Int64Type},
};
use datafusion_common::cast::as_dictionary_array;
use datafusion_expr::ColumnarValue;

#[test]
Expand All @@ -295,12 +307,12 @@ mod test {
input_builder.append_value("rust");
let input = input_builder.finish();

let mut string_builder = StringBuilder::new();
string_builder.append_value("6869");
string_builder.append_value("627965");
string_builder.append_null();
string_builder.append_value("72757374");
let expected = string_builder.finish();
let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
expected_builder.append_value("6869");
expected_builder.append_value("627965");
expected_builder.append_null();
expected_builder.append_value("72757374");
let expected = expected_builder.finish();

let columnar_value = ColumnarValue::Array(Arc::new(input));
let result = super::spark_hex(&[columnar_value]).unwrap();
Expand All @@ -310,7 +322,7 @@ mod test {
_ => panic!("Expected array"),
};

let result = as_string_array(&result);
let result = as_dictionary_array(&result).unwrap();

assert_eq!(result, &expected);
}
Expand All @@ -324,12 +336,12 @@ mod test {
input_builder.append_value(3);
let input = input_builder.finish();

let mut string_builder = StringBuilder::new();
string_builder.append_value("1");
string_builder.append_value("2");
string_builder.append_null();
string_builder.append_value("3");
let expected = string_builder.finish();
let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
expected_builder.append_value("1");
expected_builder.append_value("2");
expected_builder.append_null();
expected_builder.append_value("3");
let expected = expected_builder.finish();

let columnar_value = ColumnarValue::Array(Arc::new(input));
let result = super::spark_hex(&[columnar_value]).unwrap();
Expand All @@ -339,7 +351,7 @@ mod test {
_ => panic!("Expected array"),
};

let result = as_string_array(&result);
let result = as_dictionary_array(&result).unwrap();

assert_eq!(result, &expected);
}
Expand All @@ -353,7 +365,7 @@ mod test {
input_builder.append_value("3");
let input = input_builder.finish();

let mut expected_builder = StringBuilder::new();
let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
expected_builder.append_value("31");
expected_builder.append_value("6A");
expected_builder.append_null();
Expand All @@ -368,7 +380,7 @@ mod test {
_ => panic!("Expected array"),
};

let result = as_string_array(&result);
let result = as_dictionary_array(&result).unwrap();

assert_eq!(result, &expected);
}
Expand Down Expand Up @@ -425,8 +437,11 @@ mod test {
_ => panic!("Expected array"),
};

let result = as_string_array(&result);
let expected = StringArray::from(vec![Some("20"), None, None]);
let result = as_dictionary_array(&result).unwrap();

let keys = Int32Array::from(vec![Some(0), None, Some(1)]);
let vals = StringArray::from(vec![Some("20"), None]);
let expected = DictionaryArray::new(keys, Arc::new(vals));

assert_eq!(&expected, result);
}
Expand Down
15 changes: 15 additions & 0 deletions datafusion/sqllogictest/test_files/spark/math/hex.slt
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,18 @@ query T
SELECT hex(arrow_cast('test', 'LargeBinary')) as lar_b;
----
74657374

statement ok
CREATE TABLE t_dict_utf8 AS
SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') as dict_col
FROM VALUES ('foo'), ('bar'), ('foo'), (NULL), ('baz'), ('bar');

query T
SELECT hex(dict_col) FROM t_dict_utf8;
----
666F6F
626172
666F6F
NULL
62617A
626172