Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 104 additions & 46 deletions datafusion/functions/src/unicode/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@
// under the License.

use std::any::Any;
use std::sync::Arc;

use arrow::array::{
ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait,
ArrayAccessor, ArrayIter, ArrayRef, AsArray, LargeStringBuilder, StringBuilder,
StringLikeArrayBuilder, StringViewBuilder,
};
use arrow::datatypes::DataType;
use datafusion_common::HashMap;
use unicode_segmentation::UnicodeSegmentation;

use crate::utils::{make_scalar_function, utf8_to_str_type};
use crate::utils::make_scalar_function;
use datafusion_common::{Result, exec_err};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{
Expand Down Expand Up @@ -93,7 +93,7 @@ impl ScalarUDFImpl for TranslateFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
utf8_to_str_type(&arg_types[0], "translate")
Ok(arg_types[0].clone())
}

fn invoke_with_args(
Expand All @@ -116,33 +116,42 @@ impl ScalarUDFImpl for TranslateFunc {
let ascii_table = build_ascii_translate_table(from_str, to_str);

let string_array = args.args[0].to_array_of_size(args.number_rows)?;
let len = string_array.len();

let result = match string_array.data_type() {
DataType::Utf8View => {
let arr = string_array.as_string_view();
translate_with_map::<i32, _>(
let builder = StringViewBuilder::with_capacity(len);
translate_with_map(
arr,
&from_map,
&to_graphemes,
ascii_table.as_ref(),
builder,
)
}
DataType::Utf8 => {
let arr = string_array.as_string::<i32>();
translate_with_map::<i32, _>(
let builder =
StringBuilder::with_capacity(len, arr.value_data().len());
translate_with_map(
arr,
&from_map,
&to_graphemes,
ascii_table.as_ref(),
builder,
)
}
DataType::LargeUtf8 => {
let arr = string_array.as_string::<i64>();
translate_with_map::<i64, _>(
let builder =
LargeStringBuilder::with_capacity(len, arr.value_data().len());
translate_with_map(
arr,
&from_map,
&to_graphemes,
ascii_table.as_ref(),
builder,
)
}
other => {
Expand Down Expand Up @@ -172,24 +181,30 @@ fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> {
}

fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
let len = args[0].len();
match args[0].data_type() {
DataType::Utf8View => {
let string_array = args[0].as_string_view();
let from_array = args[1].as_string::<i32>();
let to_array = args[2].as_string::<i32>();
translate::<i32, _, _>(string_array, from_array, to_array)
let builder = StringViewBuilder::with_capacity(len);
translate(string_array, from_array, to_array, builder)
}
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
let from_array = args[1].as_string::<i32>();
let to_array = args[2].as_string::<i32>();
translate::<i32, _, _>(string_array, from_array, to_array)
let builder =
StringBuilder::with_capacity(len, string_array.value_data().len());
translate(string_array, from_array, to_array, builder)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
let from_array = args[1].as_string::<i32>();
let to_array = args[2].as_string::<i32>();
translate::<i64, _, _>(string_array, from_array, to_array)
let builder =
LargeStringBuilder::with_capacity(len, string_array.value_data().len());
translate(string_array, from_array, to_array, builder)
}
other => {
exec_err!("Unsupported data type {other:?} for function translate")
Expand All @@ -199,14 +214,16 @@ fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Replaces each character in `string` that matches a character in the `from` set with the corresponding character in the `to` set. If `from` is longer than `to`, occurrences of the extra characters in `from` are deleted.
/// translate('12345', '143', 'ax') = 'a2x5'
fn translate<'a, T: OffsetSizeTrait, V, B>(
fn translate<'a, V, B, O>(
string_array: V,
from_array: B,
to_array: B,
mut builder: O,
) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
B: ArrayAccessor<Item = &'a str>,
O: StringLikeArrayBuilder,
{
let string_array_iter = ArrayIter::new(string_array);
let from_array_iter = ArrayIter::new(from_array);
Expand All @@ -219,10 +236,9 @@ where
let mut string_graphemes: Vec<&str> = Vec::new();
let mut result_graphemes: Vec<&str> = Vec::new();

let result = string_array_iter
.zip(from_array_iter)
.zip(to_array_iter)
.map(|((string, from), to)| match (string, from, to) {
for ((string, from), to) in string_array_iter.zip(from_array_iter).zip(to_array_iter)
{
match (string, from, to) {
(Some(string), Some(from), Some(to)) => {
// Clear and reuse buffers
from_map.clear();
Expand Down Expand Up @@ -254,13 +270,13 @@ where
}
}

Some(result_graphemes.concat())
builder.append_value(&result_graphemes.concat());
}
_ => None,
})
.collect::<GenericStringArray<T>>();
_ => builder.append_null(),
}
}

Ok(Arc::new(result) as ArrayRef)
Ok(builder.finish())
}

/// Sentinel value in the ASCII translate table indicating the character should
Expand Down Expand Up @@ -300,21 +316,23 @@ fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> {
/// translation map instead of rebuilding it for every row. When an ASCII byte
/// lookup table is provided, ASCII input rows use the lookup table; non-ASCII
/// inputs fall back to using the map.
fn translate_with_map<'a, T: OffsetSizeTrait, V>(
fn translate_with_map<'a, V, O>(
string_array: V,
from_map: &HashMap<&str, usize>,
to_graphemes: &[&str],
ascii_table: Option<&[u8; 128]>,
mut builder: O,
) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
O: StringLikeArrayBuilder,
{
let mut result_graphemes: Vec<&str> = Vec::new();
let mut ascii_buf: Vec<u8> = Vec::new();

let result = ArrayIter::new(string_array)
.map(|string| {
string.map(|s| {
for string in ArrayIter::new(string_array) {
match string {
Some(s) => {
// Fast path: byte-level table lookup for ASCII strings
if let Some(table) = ascii_table
&& s.is_ascii()
Expand All @@ -327,37 +345,38 @@ where
}
}
// SAFETY: all bytes are ASCII, hence valid UTF-8.
return unsafe {
std::str::from_utf8_unchecked(&ascii_buf).to_owned()
};
}

// Slow path: grapheme-based translation
result_graphemes.clear();

for c in s.graphemes(true) {
match from_map.get(c) {
Some(n) => {
if let Some(replacement) = to_graphemes.get(*n) {
result_graphemes.push(*replacement);
builder.append_value(unsafe {
std::str::from_utf8_unchecked(&ascii_buf)
});
} else {
// Slow path: grapheme-based translation
result_graphemes.clear();

for c in s.graphemes(true) {
match from_map.get(c) {
Some(n) => {
if let Some(replacement) = to_graphemes.get(*n) {
result_graphemes.push(*replacement);
}
}
None => result_graphemes.push(c),
}
None => result_graphemes.push(c),
}
}

result_graphemes.concat()
})
})
.collect::<GenericStringArray<T>>();
builder.append_value(&result_graphemes.concat());
}
}
None => builder.append_null(),
}
}

Ok(Arc::new(result) as ArrayRef)
Ok(builder.finish())
}

#[cfg(test)]
mod tests {
use arrow::array::{Array, StringArray};
use arrow::datatypes::DataType::Utf8;
use arrow::array::{Array, StringArray, StringViewArray};
use arrow::datatypes::DataType::{Utf8, Utf8View};

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
Expand Down Expand Up @@ -453,6 +472,45 @@ mod tests {
Utf8,
StringArray
);
// Utf8View input should produce Utf8View output
test_function!(
TranslateFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("12345".into()))),
ColumnarValue::Scalar(ScalarValue::from("143")),
ColumnarValue::Scalar(ScalarValue::from("ax"))
],
Ok(Some("a2x5")),
&str,
Utf8View,
StringViewArray
);
// Null Utf8View input
test_function!(
TranslateFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
ColumnarValue::Scalar(ScalarValue::from("143")),
ColumnarValue::Scalar(ScalarValue::from("ax"))
],
Ok(None),
&str,
Utf8View,
StringViewArray
);
// Non-ASCII Utf8View input
test_function!(
TranslateFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("é2íñ5".into()))),
ColumnarValue::Scalar(ScalarValue::from("éñí")),
ColumnarValue::Scalar(ScalarValue::from("óü"))
],
Ok(Some("ó2ü5")),
&str,
Utf8View,
StringViewArray
);

#[cfg(not(feature = "unicode_expressions"))]
test_function!(
Expand Down
22 changes: 22 additions & 0 deletions datafusion/sqllogictest/test_files/string/string_literal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1768,3 +1768,25 @@ SELECT
;
----
48 176 32 40

# translate preserves input string type

query T
SELECT translate(arrow_cast('12345', 'Utf8View'), '143', 'ax')
----
a2x5

query T
SELECT arrow_typeof(translate('12345', '143', 'ax'))
----
Utf8

query T
SELECT arrow_typeof(translate(arrow_cast('12345', 'LargeUtf8'), '143', 'ax'))
----
LargeUtf8

query T
SELECT arrow_typeof(translate(arrow_cast('12345', 'Utf8View'), '143', 'ax'))
----
Utf8View