
Commit a26b3de

feat: Add support for rpad (#1470)
* use stable toolchain
* clippy
* fmt
* add support for lpad and rpad
* tests pass
* enable read-side padding in TPC stability suite
* format
* revert a change
* address feedback
* re-implement
* re-implement
* re-implement
* format
* address feedback
1 parent 26b406d commit a26b3de

7 files changed: +82 −12 lines

docs/source/user-guide/expressions.md

Lines changed: 2 additions & 1 deletion
@@ -68,7 +68,7 @@ The following Spark expressions are currently available. Any known compatibility
 ## String Functions
 
 | Expression | Notes |
-| --------------- | ----------------------------------------------------------------------------------------------------------- |
+|-----------------| ----------------------------------------------------------------------------------------------------------- |
 | Ascii | |
 | BitLength | |
 | Chr | |
@@ -85,6 +85,7 @@ The following Spark expressions are currently available. Any known compatibility
 | Replace | |
 | Reverse | |
 | StartsWith | |
+| StringRPad | |
 | StringSpace | |
 | StringTrim | |
 | StringTrimBoth | |

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 5 additions & 1 deletion
@@ -19,7 +19,7 @@ use crate::hash_funcs::*;
 use crate::{
     spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_decimal_integral_div,
     spark_floor, spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round,
-    spark_unhex, spark_unscaled_value, SparkChrFunc,
+    spark_rpad, spark_unhex, spark_unscaled_value, SparkChrFunc,
 };
 use arrow_schema::DataType;
 use datafusion_common::{DataFusionError, Result as DataFusionResult};
@@ -69,6 +69,10 @@ pub fn create_comet_physical_fun(
             let func = Arc::new(spark_read_side_padding);
             make_comet_scalar_udf!("read_side_padding", func, without data_type)
         }
+        "rpad" => {
+            let func = Arc::new(spark_rpad);
+            make_comet_scalar_udf!("rpad", func, without data_type)
+        }
         "round" => {
             make_comet_scalar_udf!("round", spark_round, data_type)
         }

native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@
 
 mod read_side_padding;
 
-pub use read_side_padding::spark_read_side_padding;
+pub use read_side_padding::{spark_read_side_padding, spark_rpad};

native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs

Lines changed: 27 additions & 3 deletions
@@ -26,11 +26,25 @@ use std::sync::Arc;
 
 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
 pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    spark_read_side_padding2(args, false)
+}
+
+/// Custom `rpad` because DataFusion's `rpad` has differences in unicode handling
+pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    spark_read_side_padding2(args, true)
+}
+
+fn spark_read_side_padding2(
+    args: &[ColumnarValue],
+    truncate: bool,
+) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
             match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
-                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
+                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length, truncate),
+                DataType::LargeUtf8 => {
+                    spark_read_side_padding_internal::<i64>(array, *length, truncate)
+                }
                 // TODO: handle Dictionary types
                 other => Err(DataFusionError::Internal(format!(
                     "Unsupported data type {other:?} for function read_side_padding",
@@ -46,6 +60,7 @@ pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue,
 fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
     length: i32,
+    truncate: bool,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
     let length = 0.max(length) as usize;
@@ -61,7 +76,16 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
             // https://stackoverflow.com/a/46290728
             let char_len = string.chars().count();
             if length <= char_len {
-                builder.append_value(string);
+                if truncate {
+                    let idx = string
+                        .char_indices()
+                        .nth(length)
+                        .map(|(i, _)| i)
+                        .unwrap_or(string.len());
+                    builder.append_value(&string[..idx]);
+                } else {
+                    builder.append_value(string);
+                }
             } else {
                 // write_str updates only the value buffer, not null nor offset buffer
                 // This is convenient for concatenating str(s)
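
The interesting detail in this hunk is the char-aware truncation: `length` counts characters, not bytes, and slicing a UTF-8 string at a byte offset inside a multi-byte character would panic, so the code first resolves the byte index of the `length`-th character via `char_indices()`. Below is a minimal, self-contained sketch of the same pad-or-truncate logic in plain Rust; `rpad_spaces` is a hypothetical helper name for illustration, not part of the Comet sources:

```rust
/// Hypothetical helper mirroring the `rpad` semantics above: pad with
/// spaces up to `length` characters, or truncate on a character
/// boundary when the input is already longer.
fn rpad_spaces(s: &str, length: usize) -> String {
    let char_len = s.chars().count();
    if length <= char_len {
        // Find the byte offset of the `length`-th character so the
        // slice never splits a multi-byte UTF-8 sequence.
        let idx = s
            .char_indices()
            .nth(length)
            .map(|(i, _)| i)
            .unwrap_or(s.len());
        s[..idx].to_string()
    } else {
        // Pad on the right with spaces, counting characters, not bytes.
        let mut out = String::with_capacity(s.len() + (length - char_len));
        out.push_str(s);
        out.extend(std::iter::repeat(' ').take(length - char_len));
        out
    }
}

fn main() {
    assert_eq!(rpad_spaces("abc", 5), "abc  ");  // padded to 5 chars
    assert_eq!(rpad_spaces("école", 1), "é");    // 'é' is 2 bytes but 1 char
    assert_eq!(rpad_spaces("abcdef", 3), "abc"); // truncated, unlike read-side padding
    println!("ok");
}
```

Truncating at `char_indices().nth(length)` rather than `&s[..length]` is what distinguishes this from a naive byte slice, and is presumably the "differences in unicode handling" the doc comment refers to; read-side padding (`truncate = false`) takes the non-truncating branch instead.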

native/spark-expr/src/static_invoke/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@
 
 mod char_varchar_utils;
 
-pub use char_varchar_utils::spark_read_side_padding;
+pub use char_varchar_utils::{spark_read_side_padding, spark_rpad};

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 19 additions & 5 deletions
@@ -1745,16 +1745,30 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             exprToProtoInternal(s.arguments(1), inputs, binding))
 
           if (argsExpr.forall(_.isDefined)) {
-            val builder = ExprOuterClass.ScalarFunc.newBuilder()
-            builder.setFunc("read_side_padding")
-            argsExpr.foreach(arg => builder.addArgs(arg.get))
-
-            Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
+            scalarExprToProto("read_side_padding", argsExpr: _*)
          } else {
            withInfo(expr, s.arguments: _*)
            None
          }
 
+        // read-side padding in Spark 3.5.2+ is represented by rpad function
+        case StringRPad(srcStr, size, chars) =>
+          chars match {
+            case Literal(str, DataTypes.StringType) if str.toString == " " =>
+              val arg0 = exprToProtoInternal(srcStr, inputs, binding)
+              val arg1 = exprToProtoInternal(size, inputs, binding)
+              if (arg0.isDefined && arg1.isDefined) {
+                scalarExprToProto("rpad", arg0, arg1)
+              } else {
+                withInfo(expr, "rpad unsupported arguments", srcStr, size)
+                None
+              }
+
+            case _ =>
+              withInfo(expr, "rpad only supports padding with spaces")
+              None
+          }
+
        case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) =>
          val dataType = serializeDataType(expr.dataType)
          if (dataType.isEmpty) {
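
The guard in the `StringRPad` arm exists because the native `rpad` kernel registered in `comet_scalar_funcs.rs` takes only two arguments (the string column and the target length) and always pads with spaces, so the serde can delegate only when the pad literal is a single space; any other pad character falls back to Spark. A toy Rust model of that placement decision (hypothetical names, not the Comet serde itself):

```rust
// Toy model of the fallback decision: delegate to the native "rpad"
// kernel only for a literal single-space pad; otherwise keep the
// expression on the Spark side.
#[derive(Debug, PartialEq)]
enum Placement {
    Native,        // run in Comet's native rpad kernel
    SparkFallback, // keep in Spark's JVM implementation
}

fn place_rpad(pad_literal: Option<&str>) -> Placement {
    match pad_literal {
        Some(" ") => Placement::Native,
        _ => Placement::SparkFallback,
    }
}

fn main() {
    assert_eq!(place_rpad(Some(" ")), Placement::Native);
    assert_eq!(place_rpad(Some("*")), Placement::SparkFallback);
    assert_eq!(place_rpad(None), Placement::SparkFallback); // non-literal pad
}
```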

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 27 additions & 0 deletions
@@ -2122,6 +2122,33 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("rpad") {
+    val table = "rpad"
+    val gen = new DataGenerator(new Random(42))
+    withTable(table) {
+      // generate some data
+      val dataChars = "abc123"
+      sql(s"create table $table(id int, name1 char(8), name2 varchar(8)) using parquet")
+      val testData = gen.generateStrings(100, dataChars, 6) ++ Seq(
+        "é", // unicode 'e\\u{301}' (decomposed: 'e' + combining acute accent)
+        "é" // unicode '\\u{e9}' (precomposed)
+      )
+      testData.zipWithIndex.foreach { x =>
+        sql(s"insert into $table values(${x._2}, '${x._1}', '${x._1}')")
+      }
+      // test 2-arg version
+      checkSparkAnswerAndOperator(
+        s"SELECT id, rpad(name1, 10), rpad(name2, 10) FROM $table ORDER BY id")
+      // test 3-arg version
+      for (length <- Seq(2, 10)) {
+        checkSparkAnswerAndOperator(
+          s"SELECT id, name1, rpad(name1, $length, ' ') FROM $table ORDER BY id")
+        checkSparkAnswerAndOperator(
+          s"SELECT id, name2, rpad(name2, $length, ' ') FROM $table ORDER BY id")
+      }
+    }
+  }
+
   test("isnan") {
     Seq("true", "false").foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary) {
