
Commit 5062735

feat: implement comet native lpad expr (#2102)
* implement lpad expression
1 parent e684cf8 commit 5062735

File tree

7 files changed: +93 -11 lines changed

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 6 additions & 2 deletions

@@ -20,8 +20,8 @@ use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mu
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
-    spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
-    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
+    spark_hex, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round,
+    spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
     SparkDateTrunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
@@ -114,6 +114,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(spark_rpad);
             make_comet_scalar_udf!("rpad", func, without data_type)
         }
+        "lpad" => {
+            let func = Arc::new(spark_lpad);
+            make_comet_scalar_udf!("lpad", func, without data_type)
+        }
         "round" => {
             make_comet_scalar_udf!("round", spark_round, data_type)
         }
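For context, `create_comet_physical_fun_with_eval_mode` is a name-based dispatch, and the new `"lpad"` arm mirrors the existing `"rpad"` arm one-for-one. A minimal standalone sketch of that dispatch shape (the stub names and signatures here are illustrative, not Comet's real types):

// A minimal sketch of name-based scalar-function dispatch. Comet's real
// version wraps an Arc'd function pointer in a DataFusion scalar UDF via
// the make_comet_scalar_udf! macro; these stubs only show the shape.
type ScalarFn = fn(&str) -> String;

fn rpad_stub(s: &str) -> String {
    format!("{s}  ") // placeholder: two trailing pad spaces
}

fn lpad_stub(s: &str) -> String {
    format!("  {s}") // placeholder: two leading pad spaces
}

fn create_physical_fun(name: &str) -> Option<ScalarFn> {
    match name {
        "rpad" => Some(rpad_stub),
        "lpad" => Some(lpad_stub), // the new arm mirrors the existing "rpad" arm
        _ => None,
    }
}

fn main() {
    let f = create_physical_fun("lpad").expect("lpad should be registered");
    assert_eq!(f("hi"), "  hi");
}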

native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs

Lines changed: 1 addition & 1 deletion

@@ -17,4 +17,4 @@
 
 mod read_side_padding;
 
-pub use read_side_padding::{spark_read_side_padding, spark_rpad};
+pub use read_side_padding::{spark_lpad, spark_read_side_padding, spark_rpad};

native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs

Lines changed: 36 additions & 6 deletions

@@ -28,17 +28,23 @@ use std::sync::Arc;
 const SPACE: &str = " ";
 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
 pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    spark_read_side_padding2(args, false)
+    spark_read_side_padding2(args, false, false)
 }
 
 /// Custom `rpad` because DataFusion's `rpad` has differences in unicode handling
 pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    spark_read_side_padding2(args, true)
+    spark_read_side_padding2(args, true, false)
+}
+
+/// Custom `lpad` because DataFusion's `lpad` has differences in unicode handling
+pub fn spark_lpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    spark_read_side_padding2(args, true, true)
 }
 
 fn spark_read_side_padding2(
     args: &[ColumnarValue],
     truncate: bool,
+    is_left_pad: bool,
 ) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
@@ -48,12 +54,14 @@ fn spark_read_side_padding2(
                 truncate,
                 ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                 SPACE,
+                is_left_pad,
             ),
             DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
                 array,
                 truncate,
                 ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                 SPACE,
+                is_left_pad,
             ),
             // Dictionary support required for SPARK-48498
             DataType::Dictionary(_, value_type) => {
@@ -64,13 +72,15 @@ fn spark_read_side_padding2(
                     truncate,
                     ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                     SPACE,
+                    is_left_pad,
                 )?
             } else {
                 spark_read_side_padding_internal::<i64>(
                     dict.values(),
                     truncate,
                     ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                     SPACE,
+                    is_left_pad,
                 )?
             };
             // col consists of an array, so arg of to_array() is not used. Can be anything
@@ -91,12 +101,14 @@ fn spark_read_side_padding2(
                 truncate,
                 ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                 string,
+                is_left_pad,
             ),
             DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
                 array,
                 truncate,
                 ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                 string,
+                is_left_pad,
             ),
             // Dictionary support required for SPARK-48498
             DataType::Dictionary(_, value_type) => {
@@ -107,13 +119,15 @@ fn spark_read_side_padding2(
                     truncate,
                     ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                     SPACE,
+                    is_left_pad,
                 )?
             } else {
                 spark_read_side_padding_internal::<i64>(
                     dict.values(),
                     truncate,
                     ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                     SPACE,
+                    is_left_pad,
                 )?
             };
             // col consists of an array, so arg of to_array() is not used. Can be anything
@@ -122,7 +136,7 @@ fn spark_read_side_padding2(
             Ok(ColumnarValue::Array(make_array(result.into())))
         }
         other => Err(DataFusionError::Internal(format!(
-            "Unsupported data type {other:?} for function rpad/read_side_padding",
+            "Unsupported data type {other:?} for function rpad/lpad/read_side_padding",
         ))),
     }
 }
@@ -132,15 +146,17 @@ fn spark_read_side_padding2(
                 truncate,
                 ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
                 SPACE,
+                is_left_pad,
             ),
             DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
                 array,
                 truncate,
                 ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
                 SPACE,
+                is_left_pad,
             ),
             other => Err(DataFusionError::Internal(format!(
-                "Unsupported data type {other:?} for function rpad/read_side_padding",
+                "Unsupported data type {other:?} for function rpad/lpad/read_side_padding",
             ))),
         },
         [ColumnarValue::Array(array), ColumnarValue::Array(array_int), ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))] => {
@@ -150,20 +166,22 @@ fn spark_read_side_padding2(
                 truncate,
                 ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
                 string,
+                is_left_pad,
             ),
             DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
                 array,
                 truncate,
                 ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
                 string,
+                is_left_pad,
             ),
             other => Err(DataFusionError::Internal(format!(
                 "Unsupported data type {other:?} for function rpad/read_side_padding",
             ))),
         }
     }
     other => Err(DataFusionError::Internal(format!(
-        "Unsupported arguments {other:?} for function rpad/read_side_padding",
+        "Unsupported arguments {other:?} for function rpad/lpad/read_side_padding",
     ))),
 }
@@ -173,6 +191,7 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     truncate: bool,
     pad_type: ColumnarValue,
     pad_string: &str,
+    is_left_pad: bool,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
     match pad_type {
@@ -191,6 +210,7 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
                     length.unwrap() as usize,
                     truncate,
                     pad_string,
+                    is_left_pad,
                 )?),
                 _ => builder.append_null(),
             }
@@ -212,6 +232,7 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
                     length,
                     truncate,
                     pad_string,
+                    is_left_pad,
                 )?),
                 _ => builder.append_null(),
             }
@@ -226,6 +247,7 @@ fn add_padding_string(
     length: usize,
     truncate: bool,
     pad_string: &str,
+    is_left_pad: bool,
 ) -> Result<String, DataFusionError> {
     // It looks Spark's UTF8String is closer to chars rather than graphemes
     // https://stackoverflow.com/a/46290728
@@ -250,6 +272,14 @@ fn add_padding_string(
     } else {
         let pad_needed = length - char_len;
         let pad: String = pad_string.chars().cycle().take(pad_needed).collect();
-        Ok(string + &pad)
+        let mut result = String::with_capacity(string.len() + pad.len());
+        if is_left_pad {
+            result.push_str(&pad);
+            result.push_str(&string);
+        } else {
+            result.push_str(&string);
+            result.push_str(&pad);
+        }
+        Ok(result)
     }
 }
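Taken together, the change threads a single `is_left_pad` flag through to `add_padding_string`, which now only decides which side receives the cycled pad string; truncation and char-based (rather than byte-based) length handling are shared with `rpad`. A standalone sketch of the resulting semantics as I read the diff; this is illustrative, not the Comet code itself:

// Sketch of the lpad/rpad semantics implemented above (my reading of the
// diff): lengths count chars, the pad string cycles, and the result is
// truncated when the input is already longer than the target length.
fn pad(s: &str, length: usize, pad_string: &str, is_left_pad: bool) -> String {
    // Lengths are counted in chars, not bytes (matching the comment in
    // add_padding_string about Spark's UTF8String semantics).
    let char_len = s.chars().count();
    if char_len >= length {
        // lpad/rpad truncate to `length` chars; read_side_padding would not.
        return s.chars().take(length).collect();
    }
    // The pad string cycles until the requested width is reached.
    let pad: String = pad_string.chars().cycle().take(length - char_len).collect();
    if is_left_pad {
        format!("{pad}{s}")
    } else {
        format!("{s}{pad}")
    }
}

fn main() {
    assert_eq!(pad("hi", 5, "?", true), "???hi");   // lpad(_1, 5, '?')
    assert_eq!(pad("hi", 5, "??", false), "hi???"); // rpad cycles the pad string
    // Char-based length: this Telugu string is 6 chars but 18 bytes.
    assert_eq!(pad("తెలుగు", 2, " ", true).chars().count(), 2);
}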

native/spark-expr/src/static_invoke/mod.rs

Lines changed: 1 addition & 1 deletion

@@ -17,4 +17,4 @@
 
 mod char_varchar_utils;
 
-pub use char_varchar_utils::{spark_read_side_padding, spark_rpad};
+pub use char_varchar_utils::{spark_lpad, spark_read_side_padding, spark_rpad};

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions

@@ -175,6 +175,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[StringRepeat] -> CometStringRepeat,
     classOf[StringReplace] -> CometScalarFunction("replace"),
     classOf[StringRPad] -> CometStringRPad,
+    classOf[StringLPad] -> CometStringLPad,
     classOf[StringSpace] -> CometScalarFunction("string_space"),
     classOf[StringTranslate] -> CometScalarFunction("translate"),
     classOf[StringTrim] -> CometScalarFunction("trim"),

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 30 additions & 1 deletion

@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import java.util.Locale
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Like, Literal, Lower, RLike, StringRepeat, StringRPad, Substring, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Like, Literal, Lower, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
 import org.apache.spark.sql.types.{DataTypes, LongType, StringType}
 
 import org.apache.comet.CometConf
@@ -168,6 +168,35 @@ object CometStringRPad extends CometExpressionSerde[StringRPad] {
   }
 }
 
+object CometStringLPad extends CometExpressionSerde[StringLPad] {
+
+  /**
+   * Convert a Spark expression into a protocol buffer representation that can be passed into
+   * native code.
+   *
+   * @param expr
+   *   The Spark expression.
+   * @param inputs
+   *   The input attributes.
+   * @param binding
+   *   Whether the attributes are bound (this is only relevant in aggregate expressions).
+   * @return
+   *   Protocol buffer representation, or None if the expression could not be converted. In this
+   *   case it is expected that the input expression will have been tagged with reasons why it
+   *   could not be converted.
+   */
+  override def convert(
+      expr: StringLPad,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    scalarFunctionExprToProto(
+      "lpad",
+      exprToProtoInternal(expr.str, inputs, binding),
+      exprToProtoInternal(expr.len, inputs, binding),
+      exprToProtoInternal(expr.pad, inputs, binding))
+  }
+}
+
 trait CommonStringExprs {
 
   def stringDecode(
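On the JVM side, `CometStringLPad` serializes `StringLPad(str, len, pad)` as a plain scalar function call named "lpad", so the native entry point receives a two- or three-element argument slice and dispatches on its shape, exactly as `spark_read_side_padding2` does above. A toy stand-in for that slice-pattern dispatch (the `Value` enum here is hypothetical, not DataFusion's `ColumnarValue`/`ScalarValue`):

// Toy stand-in for the slice-pattern dispatch used by spark_read_side_padding2.
// `Value` is a hypothetical simplification, not DataFusion's types.
enum Value {
    Utf8(String),
    Int32(i32),
}

fn lpad_call(args: &[Value]) -> Option<String> {
    match args {
        // lpad(str, len): the pad string defaults to a single space
        [Value::Utf8(s), Value::Int32(len)] => Some(pad_left(s, *len, " ")),
        // lpad(str, len, pad)
        [Value::Utf8(s), Value::Int32(len), Value::Utf8(p)] => Some(pad_left(s, *len, p)),
        // Anything else is rejected, mirroring the DataFusionError::Internal arms
        _ => None,
    }
}

fn pad_left(s: &str, len: i32, pad: &str) -> String {
    let len = len.max(0) as usize; // sketch: clamp negative lengths to zero
    let n = s.chars().count();
    if n >= len {
        s.chars().take(len).collect()
    } else {
        pad.chars().cycle().take(len - n).chain(s.chars()).collect()
    }
}

fn main() {
    let args = [Value::Utf8("hi".into()), Value::Int32(5)];
    assert_eq!(lpad_call(&args), Some("   hi".to_string()));
}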

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 18 additions & 0 deletions

@@ -431,6 +431,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("test lpad expression support") {
+    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
+    withParquetTable(data, "t1") {
+      val res = sql("select lpad(_1,_2) , lpad(_1,2) from t1 order by _1")
+      checkSparkAnswerAndOperator(res)
+    }
+  }
+
+  test("LPAD with character support other than default space") {
+    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
+    withParquetTable(data, "t1") {
+      val res = sql(
+        """ select lpad(_1,_2,'?'), lpad(_1,_2,'??') , lpad(_1,2, '??'), hex(lpad(unhex('aabb'), 5)),
+          rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
+      checkSparkAnswerAndOperator(res)
+    }
+  }
+
   test("dictionary arithmetic") {
     // TODO: test ANSI mode
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") {
