Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 110 additions & 54 deletions datafusion/functions/src/string/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::sync::Arc;
use crate::utils::utf8_to_str_type;
use arrow::array::{
Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
OffsetSizeTrait, StringArrayType, StringViewArray,
OffsetSizeTrait, StringArrayType, StringViewArray, StringViewBuilder,
};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
Expand Down Expand Up @@ -96,23 +96,27 @@ impl ScalarUDFImpl for RepeatFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
utf8_to_str_type(&arg_types[0], "repeat")
if arg_types[0] == Utf8View {
Ok(Utf8View)
} else {
utf8_to_str_type(&arg_types[0], "repeat")
}
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let return_type = args.return_field.data_type().clone();
let return_type = args.return_field.data_type();
let [string_arg, count_arg] = take_function_args(self.name(), args.args)?;

// Early return if either argument is a scalar null
if let ColumnarValue::Scalar(s) = &string_arg
&& s.is_null()
{
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
}
if let ColumnarValue::Scalar(c) = &count_arg
&& c.is_null()
{
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
}

match (&string_arg, &count_arg) {
Expand All @@ -131,13 +135,12 @@ impl ScalarUDFImpl for RepeatFunc {
};

let result = match string_scalar {
ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => {
ScalarValue::Utf8(Some(compute_repeat(
s,
count,
i32::MAX as usize,
)?))
}
ScalarValue::Utf8(Some(s)) => ScalarValue::Utf8(Some(
compute_repeat(s, count, i32::MAX as usize)?,
)),
ScalarValue::Utf8View(Some(s)) => ScalarValue::Utf8View(Some(
compute_repeat(s, count, i32::MAX as usize)?,
)),
ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some(
compute_repeat(s, count, i64::MAX as usize)?,
)),
Expand Down Expand Up @@ -188,11 +191,7 @@ fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
match string_array.data_type() {
Utf8View => {
let string_view_array = string_array.as_string_view();
repeat_impl::<i32, &StringViewArray>(
&string_view_array,
number_array,
i32::MAX as usize,
)
repeat_view_impl(string_view_array, number_array, i32::MAX as usize)
}
Utf8 => {
let string_arr = string_array.as_string::<i32>();
Expand All @@ -217,6 +216,22 @@ fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
}
}

#[inline]
fn repeat_to_buffer(buffer: &mut Vec<u8>, string: &str, count: usize) {
buffer.clear();
if !string.is_empty() {
let src = string.as_bytes();
// Initial copy
buffer.extend_from_slice(src);
// Doubling strategy: copy what we have so far until we reach the target
while buffer.len() < src.len() * count {
let copy_len = buffer.len().min(src.len() * count - buffer.len());
// SAFETY: we're copying valid UTF-8 bytes that we already verified
buffer.extend_from_within(..copy_len);
}
}
}

fn repeat_impl<'a, T, S>(
string_array: &S,
number_array: &Int64Array,
Expand All @@ -230,20 +245,19 @@ where
let mut max_item_capacity = 0;
string_array.iter().zip(number_array.iter()).try_for_each(
|(string, number)| -> Result<(), DataFusionError> {
match (string, number) {
(Some(string), Some(number)) if number >= 0 => {
let item_capacity = string.len() * number as usize;
if item_capacity > max_str_len {
return exec_err!(
"string size overflow on repeat, max size is {}, but got {}",
max_str_len,
number as usize * string.len()
);
}
total_capacity += item_capacity;
max_item_capacity = max_item_capacity.max(item_capacity);
if let (Some(string), Some(number)) = (string, number)
&& number >= 0
{
let item_capacity = string.len() * number as usize;
if item_capacity > max_str_len {
return exec_err!(
"string size overflow on repeat, max size is {}, but got {}",
max_str_len,
number as usize * string.len()
);
}
_ => (),
total_capacity += item_capacity;
max_item_capacity = max_item_capacity.max(item_capacity);
}
Ok(())
},
Expand All @@ -255,25 +269,68 @@ where
// Reusable buffer to avoid allocations in string.repeat()
let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);

// Helper function to repeat a string into a buffer using doubling strategy
// count must be > 0
#[inline]
fn repeat_to_buffer(buffer: &mut Vec<u8>, string: &str, count: usize) {
buffer.clear();
if !string.is_empty() {
let src = string.as_bytes();
// Initial copy
buffer.extend_from_slice(src);
// Doubling strategy: copy what we have so far until we reach the target
while buffer.len() < src.len() * count {
let copy_len = buffer.len().min(src.len() * count - buffer.len());
// SAFETY: we're copying valid UTF-8 bytes that we already verified
buffer.extend_from_within(..copy_len);
// Fast path: no nulls in either array
if string_array.null_count() == 0 && number_array.null_count() == 0 {
for i in 0..string_array.len() {
// SAFETY: i is within bounds (0..len) and null_count() == 0 guarantees valid value
let string = unsafe { string_array.value_unchecked(i) };
let count = number_array.value(i);
if count > 0 {
repeat_to_buffer(&mut buffer, string, count as usize);
// SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
} else {
builder.append_value("");
}
}
} else {
// Slow path: handle nulls
for (string, number) in string_array.iter().zip(number_array.iter()) {
match (string, number) {
(Some(string), Some(count)) if count > 0 => {
repeat_to_buffer(&mut buffer, string, count as usize);
// SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
builder
.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
}
(Some(_), Some(_)) => builder.append_value(""),
_ => builder.append_null(),
}
}
}

// Fast path: no nulls in either array
Ok(Arc::new(builder.finish()) as ArrayRef)
}

fn repeat_view_impl(
string_array: &StringViewArray,
number_array: &Int64Array,
max_str_len: usize,
) -> Result<ArrayRef> {
let mut total_capacity = 0;
let mut max_item_capacity = 0;
string_array.iter().zip(number_array.iter()).try_for_each(
|(string, number)| -> Result<(), DataFusionError> {
if let (Some(string), Some(number)) = (string, number)
&& number >= 0
{
let item_capacity = string.len() * number as usize;
if item_capacity > max_str_len {
return exec_err!(
"string size overflow on repeat, max size is {}, but got {}",
max_str_len,
number as usize * string.len()
);
}
total_capacity += item_capacity;
max_item_capacity = max_item_capacity.max(item_capacity);
}
Ok(())
},
)?;
let mut builder = StringViewBuilder::with_capacity(total_capacity);
let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);

if string_array.null_count() == 0 && number_array.null_count() == 0 {
for i in 0..string_array.len() {
// SAFETY: i is within bounds (0..len) and null_count() == 0 guarantees valid value
Expand All @@ -288,7 +345,6 @@ where
}
}
} else {
// Slow path: handle nulls
for (string, number) in string_array.iter().zip(number_array.iter()) {
match (string, number) {
(Some(string), Some(count)) if count > 0 => {
Expand All @@ -308,8 +364,8 @@ where

#[cfg(test)]
mod tests {
use arrow::array::{Array, StringArray};
use arrow::datatypes::DataType::Utf8;
use arrow::array::{Array, StringArray, StringViewArray};
use arrow::datatypes::DataType::{Utf8, Utf8View};

use datafusion_common::ScalarValue;
use datafusion_common::{Result, exec_err};
Expand Down Expand Up @@ -362,8 +418,8 @@ mod tests {
],
Ok(Some("PgPgPgPg")),
&str,
Utf8,
StringArray
Utf8View,
StringViewArray
);
test_function!(
RepeatFunc::new(),
Expand All @@ -373,8 +429,8 @@ mod tests {
],
Ok(None),
&str,
Utf8,
StringArray
Utf8View,
StringViewArray
);
test_function!(
RepeatFunc::new(),
Expand All @@ -384,8 +440,8 @@ mod tests {
],
Ok(None),
&str,
Utf8,
StringArray
Utf8View,
StringViewArray
);
test_function!(
RepeatFunc::new(),
Expand Down
15 changes: 15 additions & 0 deletions datafusion/sqllogictest/test_files/string/string_literal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,21 @@ SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3)
----
foofoofoo

query T
SELECT arrow_typeof(repeat('foo', 3))
----
Utf8

query T
SELECT arrow_typeof(repeat(arrow_cast('foo', 'LargeUtf8'), 3))
----
LargeUtf8

query T
SELECT arrow_typeof(repeat(arrow_cast('foo', 'Utf8View'), 3))
----
Utf8View


query T
SELECT replace('foobar', 'bar', 'hello')
Expand Down