Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 142 additions & 44 deletions datafusion/functions/src/unicode/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use std::any::Any;
use std::sync::Arc;

use arrow::array::{
ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait,
ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringBuilder, OffsetSizeTrait,
StringViewBuilder,
};
use arrow::datatypes::DataType;
use datafusion_common::HashMap;
Expand Down Expand Up @@ -93,7 +94,11 @@ impl ScalarUDFImpl for TranslateFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
utf8_to_str_type(&arg_types[0], "translate")
if arg_types[0] == DataType::Utf8View {
Ok(DataType::Utf8View)
} else {
utf8_to_str_type(&arg_types[0], "translate")
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if arg_types[0] == DataType::Utf8View {
Ok(DataType::Utf8View)
} else {
utf8_to_str_type(&arg_types[0], "translate")
}
Ok(arg_types[0].clone())

Simpler

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, simplified.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... and what if arg_types[0] were Int32Type, etc.? I am not sure this is the correct approach here.

}

fn invoke_with_args(
Expand All @@ -116,33 +121,42 @@ impl ScalarUDFImpl for TranslateFunc {
let ascii_table = build_ascii_translate_table(from_str, to_str);

let string_array = args.args[0].to_array_of_size(args.number_rows)?;
let len = string_array.len();

let result = match string_array.data_type() {
DataType::Utf8View => {
let arr = string_array.as_string_view();
translate_with_map::<i32, _>(
let builder = StringViewBuilder::with_capacity(len);
translate_with_map(
arr,
&from_map,
&to_graphemes,
ascii_table.as_ref(),
builder,
)
}
DataType::Utf8 => {
let arr = string_array.as_string::<i32>();
translate_with_map::<i32, _>(
let builder =
GenericStringBuilder::<i32>::with_capacity(len, len * 4);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why `* 4`? It seems this might overestimate compared to getting the byte size from the input array.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to use arr.value_data().len() at all call sites.

translate_with_map(
arr,
&from_map,
&to_graphemes,
ascii_table.as_ref(),
builder,
)
}
DataType::LargeUtf8 => {
let arr = string_array.as_string::<i64>();
translate_with_map::<i64, _>(
let builder =
GenericStringBuilder::<i64>::with_capacity(len, len * 4);
translate_with_map(
arr,
&from_map,
&to_graphemes,
ascii_table.as_ref(),
builder,
)
}
other => {
Expand Down Expand Up @@ -172,41 +186,83 @@ fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> {
}

fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
let len = args[0].len();
match args[0].data_type() {
DataType::Utf8View => {
let string_array = args[0].as_string_view();
let from_array = args[1].as_string::<i32>();
let to_array = args[2].as_string::<i32>();
translate::<i32, _, _>(string_array, from_array, to_array)
let builder = StringViewBuilder::with_capacity(len);
translate(string_array, from_array, to_array, builder)
}
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
let from_array = args[1].as_string::<i32>();
let to_array = args[2].as_string::<i32>();
translate::<i32, _, _>(string_array, from_array, to_array)
let builder = GenericStringBuilder::<i32>::with_capacity(len, len * 4);
translate(string_array, from_array, to_array, builder)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
let from_array = args[1].as_string::<i32>();
let to_array = args[2].as_string::<i32>();
translate::<i64, _, _>(string_array, from_array, to_array)
let builder = GenericStringBuilder::<i64>::with_capacity(len, len * 4);
translate(string_array, from_array, to_array, builder)
}
other => {
exec_err!("Unsupported data type {other:?} for function translate")
}
}
}

/// Helper trait to abstract over different string builder types so `translate`
/// and `translate_with_map` can produce the correct output array type.
trait TranslateOutput {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a lot of PRs making changes to other string functions; I wonder if having this trait specific only to translate is the best move? Can we take a step back and see if there is an easier way for all string UDFs to benefit from common code changes required?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, replaced with Arrow's StringLikeArrayBuilder.

fn append_value(&mut self, value: &str);
fn append_null(&mut self);
fn finish(self) -> ArrayRef;
}

/// Lets a `GenericStringBuilder` of either offset width serve as the output
/// sink for `translate` / `translate_with_map`.
///
/// Every method simply forwards to the builder's own method of the same name.
/// The calls are written in path form so it is explicit that the builder's
/// inherent method is the target (inherent methods take precedence over trait
/// methods during resolution, so these do not recurse).
impl<T: OffsetSizeTrait> TranslateOutput for GenericStringBuilder<T> {
    /// Appends one non-null string value to the builder.
    fn append_value(&mut self, value: &str) {
        GenericStringBuilder::append_value(self, value);
    }

    /// Appends one null slot to the builder.
    fn append_null(&mut self) {
        GenericStringBuilder::append_null(self);
    }

    /// Consumes the builder and returns the finished array as an `ArrayRef`.
    fn finish(mut self) -> ArrayRef {
        let array = GenericStringBuilder::finish(&mut self);
        Arc::new(array) as ArrayRef
    }
}

/// Lets a `StringViewBuilder` serve as the output sink for `translate` /
/// `translate_with_map`, so Utf8View input can yield Utf8View output.
///
/// Every method simply forwards to the builder's own method of the same name.
/// The calls are written in path form so it is explicit that the builder's
/// inherent method is the target (inherent methods take precedence over trait
/// methods during resolution, so these do not recurse).
impl TranslateOutput for StringViewBuilder {
    /// Appends one non-null string value to the builder.
    fn append_value(&mut self, value: &str) {
        StringViewBuilder::append_value(self, value);
    }

    /// Appends one null slot to the builder.
    fn append_null(&mut self) {
        StringViewBuilder::append_null(self);
    }

    /// Consumes the builder and returns the finished array as an `ArrayRef`.
    fn finish(mut self) -> ArrayRef {
        let array = StringViewBuilder::finish(&mut self);
        Arc::new(array) as ArrayRef
    }
}

/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.
/// translate('12345', '143', 'ax') = 'a2x5'
fn translate<'a, T: OffsetSizeTrait, V, B>(
fn translate<'a, V, B, O>(
string_array: V,
from_array: B,
to_array: B,
mut builder: O,
) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
B: ArrayAccessor<Item = &'a str>,
O: TranslateOutput,
{
let string_array_iter = ArrayIter::new(string_array);
let from_array_iter = ArrayIter::new(from_array);
Expand All @@ -219,10 +275,10 @@ where
let mut string_graphemes: Vec<&str> = Vec::new();
let mut result_graphemes: Vec<&str> = Vec::new();

let result = string_array_iter
.zip(from_array_iter)
.zip(to_array_iter)
.map(|((string, from), to)| match (string, from, to) {
for ((string, from), to) in
string_array_iter.zip(from_array_iter).zip(to_array_iter)
{
match (string, from, to) {
(Some(string), Some(from), Some(to)) => {
// Clear and reuse buffers
from_map.clear();
Expand Down Expand Up @@ -254,13 +310,13 @@ where
}
}

Some(result_graphemes.concat())
builder.append_value(&result_graphemes.concat());
}
_ => None,
})
.collect::<GenericStringArray<T>>();
_ => builder.append_null(),
}
}

Ok(Arc::new(result) as ArrayRef)
Ok(builder.finish())
}

/// Sentinel value in the ASCII translate table indicating the character should
Expand Down Expand Up @@ -300,21 +356,23 @@ fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> {
/// translation map instead of rebuilding it for every row. When an ASCII byte
/// lookup table is provided, ASCII input rows use the lookup table; non-ASCII
/// inputs fall back to using the map.
fn translate_with_map<'a, T: OffsetSizeTrait, V>(
fn translate_with_map<'a, V, O>(
string_array: V,
from_map: &HashMap<&str, usize>,
to_graphemes: &[&str],
ascii_table: Option<&[u8; 128]>,
mut builder: O,
) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
O: TranslateOutput,
{
let mut result_graphemes: Vec<&str> = Vec::new();
let mut ascii_buf: Vec<u8> = Vec::new();

let result = ArrayIter::new(string_array)
.map(|string| {
string.map(|s| {
for string in ArrayIter::new(string_array) {
match string {
Some(s) => {
// Fast path: byte-level table lookup for ASCII strings
if let Some(table) = ascii_table
&& s.is_ascii()
Expand All @@ -327,37 +385,38 @@ where
}
}
// SAFETY: all bytes are ASCII, hence valid UTF-8.
return unsafe {
std::str::from_utf8_unchecked(&ascii_buf).to_owned()
};
}

// Slow path: grapheme-based translation
result_graphemes.clear();

for c in s.graphemes(true) {
match from_map.get(c) {
Some(n) => {
if let Some(replacement) = to_graphemes.get(*n) {
result_graphemes.push(*replacement);
builder.append_value(unsafe {
std::str::from_utf8_unchecked(&ascii_buf)
});
} else {
// Slow path: grapheme-based translation
result_graphemes.clear();

for c in s.graphemes(true) {
match from_map.get(c) {
Some(n) => {
if let Some(replacement) = to_graphemes.get(*n) {
result_graphemes.push(*replacement);
}
}
None => result_graphemes.push(c),
}
None => result_graphemes.push(c),
}
}

result_graphemes.concat()
})
})
.collect::<GenericStringArray<T>>();
builder.append_value(&result_graphemes.concat());
}
}
None => builder.append_null(),
}
}

Ok(Arc::new(result) as ArrayRef)
Ok(builder.finish())
}

#[cfg(test)]
mod tests {
use arrow::array::{Array, StringArray};
use arrow::datatypes::DataType::Utf8;
use arrow::array::{Array, StringArray, StringViewArray};
use arrow::datatypes::DataType::{Utf8, Utf8View};

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
Expand Down Expand Up @@ -453,6 +512,45 @@ mod tests {
Utf8,
StringArray
);
// Utf8View input should produce Utf8View output
test_function!(
TranslateFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("12345".into()))),
ColumnarValue::Scalar(ScalarValue::from("143")),
ColumnarValue::Scalar(ScalarValue::from("ax"))
],
Ok(Some("a2x5")),
&str,
Utf8View,
StringViewArray
);
// Null Utf8View input
test_function!(
TranslateFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
ColumnarValue::Scalar(ScalarValue::from("143")),
ColumnarValue::Scalar(ScalarValue::from("ax"))
],
Ok(None),
&str,
Utf8View,
StringViewArray
);
// Non-ASCII Utf8View input
test_function!(
TranslateFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("é2íñ5".into()))),
ColumnarValue::Scalar(ScalarValue::from("éñí")),
ColumnarValue::Scalar(ScalarValue::from("óü"))
],
Ok(Some("ó2ü5")),
&str,
Utf8View,
StringViewArray
);

#[cfg(not(feature = "unicode_expressions"))]
test_function!(
Expand Down
22 changes: 22 additions & 0 deletions datafusion/sqllogictest/test_files/string/string_literal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1768,3 +1768,25 @@ SELECT
;
----
48 176 32 40

# translate preserves input string type

query T
SELECT translate(arrow_cast('12345', 'Utf8View'), '143', 'ax')
----
a2x5

query T
SELECT arrow_typeof(translate('12345', '143', 'ax'))
----
Utf8

query T
SELECT arrow_typeof(translate(arrow_cast('12345', 'LargeUtf8'), '143', 'ax'))
----
LargeUtf8

query T
SELECT arrow_typeof(translate(arrow_cast('12345', 'Utf8View'), '143', 'ax'))
----
Utf8View
Loading