
Commit b83a670

coderfender authored and Steve Vaughan Jr committed
feat: rpad support column for second arg instead of just literal (apache#2099)
1 parent 1bd47b4 commit b83a670

File tree

2 files changed: +111 -37 lines

native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs

Lines changed: 104 additions & 37 deletions
@@ -18,12 +18,11 @@
 use arrow::array::builder::GenericStringBuilder;
 use arrow::array::cast::as_dictionary_array;
 use arrow::array::types::Int32Type;
-use arrow::array::{make_array, Array, DictionaryArray};
+use arrow::array::{make_array, Array, AsArray, DictionaryArray};
 use arrow::array::{ArrayRef, OffsetSizeTrait};
 use arrow::datatypes::DataType;
 use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
 use datafusion::physical_plan::ColumnarValue;
-use std::fmt::Write;
 use std::sync::Arc;
 
 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
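
The import changes track the refactor below: `AsArray` brings the `as_primitive` cast used to read per-row pad lengths from the new array argument, and `std::fmt::Write` goes away because the incremental `write_str` path is folded into the new `add_padding_string` helper.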
@@ -43,17 +42,31 @@ fn spark_read_side_padding2(
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
             match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length, truncate),
-                DataType::LargeUtf8 => {
-                    spark_read_side_padding_internal::<i64>(array, *length, truncate)
-                }
+                DataType::Utf8 => spark_read_side_padding_internal::<i32>(
+                    array,
+                    truncate,
+                    ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
+                ),
+                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
+                    array,
+                    truncate,
+                    ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
+                ),
                 // Dictionary support required for SPARK-48498
                 DataType::Dictionary(_, value_type) => {
                     let dict = as_dictionary_array::<Int32Type>(array);
                     let col = if value_type.as_ref() == &DataType::Utf8 {
-                        spark_read_side_padding_internal::<i32>(dict.values(), *length, truncate)?
+                        spark_read_side_padding_internal::<i32>(
+                            dict.values(),
+                            truncate,
+                            ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
+                        )?
                     } else {
-                        spark_read_side_padding_internal::<i64>(dict.values(), *length, truncate)?
+                        spark_read_side_padding_internal::<i64>(
+                            dict.values(),
+                            truncate,
+                            ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
+                        )?
                     };
                     // col consists of an array, so arg of to_array() is not used. Can be anything
                     let values = col.to_array(0)?;
@@ -65,6 +78,21 @@ fn spark_read_side_padding2(
                 ))),
             }
         }
+        [ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => match array.data_type() {
+            DataType::Utf8 => spark_read_side_padding_internal::<i32>(
+                array,
+                truncate,
+                ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
+            ),
+            DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
+                array,
+                truncate,
+                ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
+            ),
+            other => Err(DataFusionError::Internal(format!(
+                "Unsupported data type {other:?} for function rpad/read_side_padding",
+            ))),
+        },
         other => Err(DataFusionError::Internal(format!(
             "Unsupported arguments {other:?} for function rpad/read_side_padding",
         ))),
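
The hunks above extend `spark_read_side_padding2` so the second rpad argument may be either a literal or a column: both forms now funnel into `spark_read_side_padding_internal` as a `ColumnarValue`. A minimal standalone sketch of that dispatch shape, using hypothetical stand-in types rather than the real DataFusion `ColumnarValue` (the real code matches on `&[ColumnarValue]` slice patterns):

// Hypothetical stand-ins for the two shapes of the second argument.
enum Value {
    Scalar(i32),     // literal second argument, e.g. rpad(col, 5)
    Array(Vec<i32>), // column second argument, e.g. rpad(col, len_col)
}

// Resolve the effective pad length for each of `rows` rows.
fn pad_lengths(rows: usize, arg: &Value) -> Vec<i32> {
    match arg {
        // A single literal is broadcast to every row.
        Value::Scalar(len) => vec![*len; rows],
        // A column already holds one length per row.
        Value::Array(lens) => lens.clone(),
    }
}

fn main() {
    assert_eq!(pad_lengths(3, &Value::Scalar(4)), vec![4, 4, 4]);
    assert_eq!(pad_lengths(3, &Value::Array(vec![1, 2, 3])), vec![1, 2, 3]);
}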
@@ -73,42 +101,81 @@ fn spark_read_side_padding2(
 
 fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
-    length: i32,
     truncate: bool,
+    pad_type: ColumnarValue,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
-    let length = 0.max(length) as usize;
-    let space_string = " ".repeat(length);
+    match pad_type {
+        ColumnarValue::Array(array_int) => {
+            let int_pad_array = array_int.as_primitive::<Int32Type>();
 
-    let mut builder =
-        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
+            let mut builder = GenericStringBuilder::<T>::with_capacity(
+                string_array.len(),
+                string_array.len() * int_pad_array.len(),
+            );
 
-    for string in string_array.iter() {
-        match string {
-            Some(string) => {
-                // It looks Spark's UTF8String is closer to chars rather than graphemes
-                // https://stackoverflow.com/a/46290728
-                let char_len = string.chars().count();
-                if length <= char_len {
-                    if truncate {
-                        let idx = string
-                            .char_indices()
-                            .nth(length)
-                            .map(|(i, _)| i)
-                            .unwrap_or(string.len());
-                        builder.append_value(&string[..idx]);
-                    } else {
-                        builder.append_value(string);
-                    }
-                } else {
-                    // write_str updates only the value buffer, not null nor offset buffer
-                    // This is convenient for concatenating str(s)
-                    builder.write_str(string)?;
-                    builder.append_value(&space_string[char_len..]);
+            for (string, length) in string_array.iter().zip(int_pad_array) {
+                match string {
+                    Some(string) => builder.append_value(add_padding_string(
+                        string.parse().unwrap(),
+                        length.unwrap() as usize,
+                        truncate,
+                    )?),
+                    _ => builder.append_null(),
+                }
+            }
+            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+        }
+        ColumnarValue::Scalar(const_pad_length) => {
+            let length = 0.max(i32::try_from(const_pad_length)?) as usize;
+
+            let mut builder = GenericStringBuilder::<T>::with_capacity(
+                string_array.len(),
+                string_array.len() * length,
+            );
+
+            for string in string_array.iter() {
+                match string {
+                    Some(string) => builder.append_value(add_padding_string(
+                        string.parse().unwrap(),
+                        length,
+                        truncate,
+                    )?),
+                    _ => builder.append_null(),
                 }
             }
-            _ => builder.append_null(),
+            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+        }
+    }
+}
+
+fn add_padding_string(
+    string: String,
+    length: usize,
+    truncate: bool,
+) -> Result<String, DataFusionError> {
+    // It looks Spark's UTF8String is closer to chars rather than graphemes
+    // https://stackoverflow.com/a/46290728
+    let space_string = " ".repeat(length);
+    let char_len = string.chars().count();
+    if length <= char_len {
+        if truncate {
+            let idx = string
+                .char_indices()
+                .nth(length)
+                .map(|(i, _)| i)
+                .unwrap_or(string.len());
+            match string[..idx].parse() {
+                Ok(string) => Ok(string),
+                Err(err) => Err(DataFusionError::Internal(format!(
+                    "Failed adding padding string {} error {:}",
+                    string, err
+                ))),
+            }
+        } else {
+            Ok(string)
         }
+    } else {
+        Ok(string + &space_string[char_len..])
     }
-    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
 }
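
The refactor moves the per-string logic into the new `add_padding_string` helper, which counts chars rather than bytes (matching the Spark `UTF8String` behavior cited in the comment). A minimal sketch of that rule, assuming `truncate` distinguishes plain `rpad` (which cuts long input) from read-side padding (which leaves long input alone); `rpad_chars` is a hypothetical name, not the committed API:

// Char-based right padding: `length` counts chars, not bytes.
fn rpad_chars(s: &str, length: usize, truncate: bool) -> String {
    let char_len = s.chars().count();
    if length <= char_len {
        if truncate {
            // Cut at a char boundary, never mid-code-point.
            let idx = s
                .char_indices()
                .nth(length)
                .map(|(i, _)| i)
                .unwrap_or(s.len());
            s[..idx].to_string()
        } else {
            s.to_string()
        }
    } else {
        // Append spaces up to the requested char count.
        s.to_string() + &" ".repeat(length - char_len)
    }
}

fn main() {
    // "తెలుగు" is 6 chars but 18 bytes; char-based padding adds exactly 2 spaces.
    assert_eq!(rpad_chars("తెలుగు", 8, true), "తెలుగు  ");
    assert_eq!(rpad_chars("IfIWasARoadIWouldBeBent", 10, true), "IfIWasARoa");
    assert_eq!(rpad_chars("abc", 2, false), "abc"); // read-side padding never truncates
}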

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 7 additions & 0 deletions
@@ -407,6 +407,13 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
+  test("Verify rpad expr support for second arg instead of just literal") {
+    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
+    withParquetTable(data, "t1") {
+      val res = sql("select rpad(_1,_2) , rpad(_1,2) from t1 order by _1")
+      checkSparkAnswerAndOperator(res)
+    }
+  }
 
   test("dictionary arithmetic") {
     // TODO: test ANSI mode
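
The new test covers both forms in one query: `rpad(_1,_2)` takes its length from the `_2` column while `rpad(_1,2)` keeps the literal form, and the Telugu row exercises char-based rather than byte-based lengths. `checkSparkAnswerAndOperator` asserts both that the result matches Spark and that the expression actually ran as a native Comet operator.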
