-
Notifications
You must be signed in to change notification settings - Fork 2k
Make translate emit Utf8View for Utf8View input #20624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,8 @@ use std::any::Any; | |
| use std::sync::Arc; | ||
|
|
||
| use arrow::array::{ | ||
| ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, | ||
| ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringBuilder, OffsetSizeTrait, | ||
| StringViewBuilder, | ||
| }; | ||
| use arrow::datatypes::DataType; | ||
| use datafusion_common::HashMap; | ||
|
|
@@ -93,7 +94,11 @@ impl ScalarUDFImpl for TranslateFunc { | |
| } | ||
|
|
||
| fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
| utf8_to_str_type(&arg_types[0], "translate") | ||
| if arg_types[0] == DataType::Utf8View { | ||
| Ok(DataType::Utf8View) | ||
| } else { | ||
| utf8_to_str_type(&arg_types[0], "translate") | ||
| } | ||
| } | ||
|
|
||
| fn invoke_with_args( | ||
|
|
@@ -116,33 +121,42 @@ impl ScalarUDFImpl for TranslateFunc { | |
| let ascii_table = build_ascii_translate_table(from_str, to_str); | ||
|
|
||
| let string_array = args.args[0].to_array_of_size(args.number_rows)?; | ||
| let len = string_array.len(); | ||
|
|
||
| let result = match string_array.data_type() { | ||
| DataType::Utf8View => { | ||
| let arr = string_array.as_string_view(); | ||
| translate_with_map::<i32, _>( | ||
| let builder = StringViewBuilder::with_capacity(len); | ||
| translate_with_map( | ||
| arr, | ||
| &from_map, | ||
| &to_graphemes, | ||
| ascii_table.as_ref(), | ||
| builder, | ||
| ) | ||
| } | ||
| DataType::Utf8 => { | ||
| let arr = string_array.as_string::<i32>(); | ||
| translate_with_map::<i32, _>( | ||
| let builder = | ||
| GenericStringBuilder::<i32>::with_capacity(len, len * 4); | ||
|
||
| translate_with_map( | ||
| arr, | ||
| &from_map, | ||
| &to_graphemes, | ||
| ascii_table.as_ref(), | ||
| builder, | ||
| ) | ||
| } | ||
| DataType::LargeUtf8 => { | ||
| let arr = string_array.as_string::<i64>(); | ||
| translate_with_map::<i64, _>( | ||
| let builder = | ||
| GenericStringBuilder::<i64>::with_capacity(len, len * 4); | ||
| translate_with_map( | ||
| arr, | ||
| &from_map, | ||
| &to_graphemes, | ||
| ascii_table.as_ref(), | ||
| builder, | ||
| ) | ||
| } | ||
| other => { | ||
|
|
@@ -172,41 +186,83 @@ fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> { | |
| } | ||
|
|
||
| fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
| let len = args[0].len(); | ||
| match args[0].data_type() { | ||
| DataType::Utf8View => { | ||
| let string_array = args[0].as_string_view(); | ||
| let from_array = args[1].as_string::<i32>(); | ||
| let to_array = args[2].as_string::<i32>(); | ||
| translate::<i32, _, _>(string_array, from_array, to_array) | ||
| let builder = StringViewBuilder::with_capacity(len); | ||
| translate(string_array, from_array, to_array, builder) | ||
| } | ||
| DataType::Utf8 => { | ||
| let string_array = args[0].as_string::<i32>(); | ||
| let from_array = args[1].as_string::<i32>(); | ||
| let to_array = args[2].as_string::<i32>(); | ||
| translate::<i32, _, _>(string_array, from_array, to_array) | ||
| let builder = GenericStringBuilder::<i32>::with_capacity(len, len * 4); | ||
| translate(string_array, from_array, to_array, builder) | ||
| } | ||
| DataType::LargeUtf8 => { | ||
| let string_array = args[0].as_string::<i64>(); | ||
| let from_array = args[1].as_string::<i32>(); | ||
| let to_array = args[2].as_string::<i32>(); | ||
| translate::<i64, _, _>(string_array, from_array, to_array) | ||
| let builder = GenericStringBuilder::<i64>::with_capacity(len, len * 4); | ||
| translate(string_array, from_array, to_array, builder) | ||
| } | ||
| other => { | ||
| exec_err!("Unsupported data type {other:?} for function translate") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /// Helper trait to abstract over different string builder types so `translate` | ||
| /// and `translate_with_map` can produce the correct output array type. | ||
| trait TranslateOutput { | ||
|
||
| fn append_value(&mut self, value: &str); | ||
| fn append_null(&mut self); | ||
| fn finish(self) -> ArrayRef; | ||
| } | ||
|
|
||
| impl<T: OffsetSizeTrait> TranslateOutput for GenericStringBuilder<T> { | ||
| fn append_value(&mut self, value: &str) { | ||
| self.append_value(value); | ||
| } | ||
|
|
||
| fn append_null(&mut self) { | ||
| self.append_null(); | ||
| } | ||
|
|
||
| fn finish(mut self) -> ArrayRef { | ||
| Arc::new(GenericStringBuilder::finish(&mut self)) as ArrayRef | ||
| } | ||
| } | ||
|
|
||
| impl TranslateOutput for StringViewBuilder { | ||
| fn append_value(&mut self, value: &str) { | ||
| self.append_value(value); | ||
| } | ||
|
|
||
| fn append_null(&mut self) { | ||
| self.append_null(); | ||
| } | ||
|
|
||
| fn finish(mut self) -> ArrayRef { | ||
| Arc::new(StringViewBuilder::finish(&mut self)) as ArrayRef | ||
| } | ||
| } | ||
|
|
||
| /// Replaces each character in `string` that matches a character in the `from` set with the corresponding character in the `to` set. If `from` is longer than `to`, occurrences of the extra characters in `from` are deleted. | ||
| /// translate('12345', '143', 'ax') = 'a2x5' | ||
| fn translate<'a, T: OffsetSizeTrait, V, B>( | ||
| fn translate<'a, V, B, O>( | ||
| string_array: V, | ||
| from_array: B, | ||
| to_array: B, | ||
| mut builder: O, | ||
| ) -> Result<ArrayRef> | ||
| where | ||
| V: ArrayAccessor<Item = &'a str>, | ||
| B: ArrayAccessor<Item = &'a str>, | ||
| O: TranslateOutput, | ||
| { | ||
| let string_array_iter = ArrayIter::new(string_array); | ||
| let from_array_iter = ArrayIter::new(from_array); | ||
|
|
@@ -219,10 +275,10 @@ where | |
| let mut string_graphemes: Vec<&str> = Vec::new(); | ||
| let mut result_graphemes: Vec<&str> = Vec::new(); | ||
|
|
||
| let result = string_array_iter | ||
| .zip(from_array_iter) | ||
| .zip(to_array_iter) | ||
| .map(|((string, from), to)| match (string, from, to) { | ||
| for ((string, from), to) in | ||
| string_array_iter.zip(from_array_iter).zip(to_array_iter) | ||
| { | ||
| match (string, from, to) { | ||
| (Some(string), Some(from), Some(to)) => { | ||
| // Clear and reuse buffers | ||
| from_map.clear(); | ||
|
|
@@ -254,13 +310,13 @@ where | |
| } | ||
| } | ||
|
|
||
| Some(result_graphemes.concat()) | ||
| builder.append_value(&result_graphemes.concat()); | ||
| } | ||
| _ => None, | ||
| }) | ||
| .collect::<GenericStringArray<T>>(); | ||
| _ => builder.append_null(), | ||
| } | ||
| } | ||
|
|
||
| Ok(Arc::new(result) as ArrayRef) | ||
| Ok(builder.finish()) | ||
| } | ||
|
|
||
| /// Sentinel value in the ASCII translate table indicating the character should | ||
|
|
@@ -300,21 +356,23 @@ fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> { | |
| /// translation map instead of rebuilding it for every row. When an ASCII byte | ||
| /// lookup table is provided, ASCII input rows use the lookup table; non-ASCII | ||
| /// inputs fall back to using the map. | ||
| fn translate_with_map<'a, T: OffsetSizeTrait, V>( | ||
| fn translate_with_map<'a, V, O>( | ||
| string_array: V, | ||
| from_map: &HashMap<&str, usize>, | ||
| to_graphemes: &[&str], | ||
| ascii_table: Option<&[u8; 128]>, | ||
| mut builder: O, | ||
| ) -> Result<ArrayRef> | ||
| where | ||
| V: ArrayAccessor<Item = &'a str>, | ||
| O: TranslateOutput, | ||
| { | ||
| let mut result_graphemes: Vec<&str> = Vec::new(); | ||
| let mut ascii_buf: Vec<u8> = Vec::new(); | ||
|
|
||
| let result = ArrayIter::new(string_array) | ||
| .map(|string| { | ||
| string.map(|s| { | ||
| for string in ArrayIter::new(string_array) { | ||
| match string { | ||
| Some(s) => { | ||
| // Fast path: byte-level table lookup for ASCII strings | ||
| if let Some(table) = ascii_table | ||
| && s.is_ascii() | ||
|
|
@@ -327,37 +385,38 @@ where | |
| } | ||
| } | ||
| // SAFETY: all bytes are ASCII, hence valid UTF-8. | ||
| return unsafe { | ||
| std::str::from_utf8_unchecked(&ascii_buf).to_owned() | ||
| }; | ||
| } | ||
|
|
||
| // Slow path: grapheme-based translation | ||
| result_graphemes.clear(); | ||
|
|
||
| for c in s.graphemes(true) { | ||
| match from_map.get(c) { | ||
| Some(n) => { | ||
| if let Some(replacement) = to_graphemes.get(*n) { | ||
| result_graphemes.push(*replacement); | ||
| builder.append_value(unsafe { | ||
| std::str::from_utf8_unchecked(&ascii_buf) | ||
| }); | ||
| } else { | ||
| // Slow path: grapheme-based translation | ||
| result_graphemes.clear(); | ||
|
|
||
| for c in s.graphemes(true) { | ||
| match from_map.get(c) { | ||
| Some(n) => { | ||
| if let Some(replacement) = to_graphemes.get(*n) { | ||
| result_graphemes.push(*replacement); | ||
| } | ||
| } | ||
| None => result_graphemes.push(c), | ||
| } | ||
| None => result_graphemes.push(c), | ||
| } | ||
| } | ||
|
|
||
| result_graphemes.concat() | ||
| }) | ||
| }) | ||
| .collect::<GenericStringArray<T>>(); | ||
| builder.append_value(&result_graphemes.concat()); | ||
| } | ||
| } | ||
| None => builder.append_null(), | ||
| } | ||
| } | ||
|
|
||
| Ok(Arc::new(result) as ArrayRef) | ||
| Ok(builder.finish()) | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use arrow::array::{Array, StringArray}; | ||
| use arrow::datatypes::DataType::Utf8; | ||
| use arrow::array::{Array, StringArray, StringViewArray}; | ||
| use arrow::datatypes::DataType::{Utf8, Utf8View}; | ||
|
|
||
| use datafusion_common::{Result, ScalarValue}; | ||
| use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; | ||
|
|
@@ -453,6 +512,45 @@ mod tests { | |
| Utf8, | ||
| StringArray | ||
| ); | ||
| // Utf8View input should produce Utf8View output | ||
| test_function!( | ||
| TranslateFunc::new(), | ||
| vec![ | ||
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some("12345".into()))), | ||
| ColumnarValue::Scalar(ScalarValue::from("143")), | ||
| ColumnarValue::Scalar(ScalarValue::from("ax")) | ||
| ], | ||
| Ok(Some("a2x5")), | ||
| &str, | ||
| Utf8View, | ||
| StringViewArray | ||
| ); | ||
| // Null Utf8View input | ||
| test_function!( | ||
| TranslateFunc::new(), | ||
| vec![ | ||
| ColumnarValue::Scalar(ScalarValue::Utf8View(None)), | ||
| ColumnarValue::Scalar(ScalarValue::from("143")), | ||
| ColumnarValue::Scalar(ScalarValue::from("ax")) | ||
| ], | ||
| Ok(None), | ||
| &str, | ||
| Utf8View, | ||
| StringViewArray | ||
| ); | ||
| // Non-ASCII Utf8View input | ||
| test_function!( | ||
| TranslateFunc::new(), | ||
| vec![ | ||
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some("é2íñ5".into()))), | ||
| ColumnarValue::Scalar(ScalarValue::from("éñí")), | ||
| ColumnarValue::Scalar(ScalarValue::from("óü")) | ||
| ], | ||
| Ok(Some("ó2ü5")), | ||
| &str, | ||
| Utf8View, | ||
| StringViewArray | ||
| ); | ||
|
|
||
| #[cfg(not(feature = "unicode_expressions"))] | ||
| test_function!( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simpler
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, simplified.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
... and what if arg_types[0] were Int32Type, etc.? I am not sure this is the correct approach here.