Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
use arrow::array::builder::GenericStringBuilder;
use arrow::array::cast::as_dictionary_array;
use arrow::array::types::Int32Type;
use arrow::array::{make_array, Array, DictionaryArray};
use arrow::array::{make_array, Array, AsArray, DictionaryArray};
use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
use datafusion::physical_plan::ColumnarValue;
use std::fmt::Write;
use std::sync::Arc;

/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
Expand All @@ -43,17 +42,31 @@ fn spark_read_side_padding2(
match args {
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
match array.data_type() {
DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length, truncate),
DataType::LargeUtf8 => {
spark_read_side_padding_internal::<i64>(array, *length, truncate)
}
DataType::Utf8 => spark_read_side_padding_internal::<i32>(
array,
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
),
// Dictionary support required for SPARK-48498
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apache/spark#46832

This seems related to padding. How does this affect dictionary encoded columns?

DataType::Dictionary(_, value_type) => {
let dict = as_dictionary_array::<Int32Type>(array);
let col = if value_type.as_ref() == &DataType::Utf8 {
spark_read_side_padding_internal::<i32>(dict.values(), *length, truncate)?
spark_read_side_padding_internal::<i32>(
dict.values(),
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
)?
} else {
spark_read_side_padding_internal::<i64>(dict.values(), *length, truncate)?
spark_read_side_padding_internal::<i64>(
dict.values(),
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
)?
};
// col consists of an array, so arg of to_array() is not used. Can be anything
let values = col.to_array(0)?;
Expand All @@ -65,6 +78,21 @@ fn spark_read_side_padding2(
))),
}
}
[ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => match array.data_type() {
DataType::Utf8 => spark_read_side_padding_internal::<i32>(
array,
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function rpad/read_side_padding",
))),
},
other => Err(DataFusionError::Internal(format!(
"Unsupported arguments {other:?} for function rpad/read_side_padding",
))),
Expand All @@ -73,42 +101,81 @@ fn spark_read_side_padding2(

fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
array: &ArrayRef,
length: i32,
truncate: bool,
pad_type: ColumnarValue,
) -> Result<ColumnarValue, DataFusionError> {
let string_array = as_generic_string_array::<T>(array)?;
let length = 0.max(length) as usize;
let space_string = " ".repeat(length);
match pad_type {
ColumnarValue::Array(array_int) => {
let int_pad_array = array_int.as_primitive::<Int32Type>();

let mut builder =
GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
let mut builder = GenericStringBuilder::<T>::with_capacity(
string_array.len(),
string_array.len() * int_pad_array.len(),
);

for string in string_array.iter() {
match string {
Some(string) => {
// It looks Spark's UTF8String is closer to chars rather than graphemes
// https://stackoverflow.com/a/46290728
let char_len = string.chars().count();
if length <= char_len {
if truncate {
let idx = string
.char_indices()
.nth(length)
.map(|(i, _)| i)
.unwrap_or(string.len());
builder.append_value(&string[..idx]);
} else {
builder.append_value(string);
}
} else {
// write_str updates only the value buffer, not null nor offset buffer
// This is convenient for concatenating str(s)
builder.write_str(string)?;
builder.append_value(&space_string[char_len..]);
for (string, length) in string_array.iter().zip(int_pad_array) {
match string {
Some(string) => builder.append_value(add_padding_string(
string.parse().unwrap(),
length.unwrap() as usize,
truncate,
)?),
_ => builder.append_null(),
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
ColumnarValue::Scalar(const_pad_length) => {
let length = 0.max(i32::try_from(const_pad_length)?) as usize;

let mut builder = GenericStringBuilder::<T>::with_capacity(
string_array.len(),
string_array.len() * length,
);

for string in string_array.iter() {
match string {
Some(string) => builder.append_value(add_padding_string(
string.parse().unwrap(),
length,
truncate,
)?),
_ => builder.append_null(),
}
}
_ => builder.append_null(),
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
}

fn add_padding_string(
string: String,
length: usize,
truncate: bool,
) -> Result<String, DataFusionError> {
// It looks Spark's UTF8String is closer to chars rather than graphemes
// https://stackoverflow.com/a/46290728
let space_string = " ".repeat(length);
let char_len = string.chars().count();
if length <= char_len {
if truncate {
let idx = string
.char_indices()
.nth(length)
.map(|(i, _)| i)
.unwrap_or(string.len());
match string[..idx].parse() {
Ok(string) => Ok(string),
Err(err) => Err(DataFusionError::Internal(format!(
"Failed adding padding string {} error {:}",
string, err
))),
}
} else {
Ok(string)
}
} else {
Ok(string + &space_string[char_len..])
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,13 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}
test("Verify rpad expr support for second arg instead of just literal") {
  // One ASCII string and one multi-byte (Telugu) string, each paired with
  // its own pad length, to exercise rpad driven by a column argument.
  val rows = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
  withParquetTable(rows, "t1") {
    // Column-based and literal pad lengths in the same projection.
    val df = sql("select rpad(_1,_2) , rpad(_1,2) from t1 order by _1")
    checkSparkAnswerAndOperator(df)
  }
}

test("dictionary arithmetic") {
// TODO: test ANSI mode
Expand Down
Loading