diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 961309b788..d20f2c9d2d 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -19,9 +19,10 @@ use crate::hash_funcs::*;
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
-    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
-    SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
+    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_left_side_padding,
+    spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
+    spark_unscaled_value, SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkDateTrunc,
+    SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -83,6 +84,14 @@ pub fn create_comet_physical_fun(
            let func = Arc::new(spark_read_side_padding);
            make_comet_scalar_udf!("read_side_padding", func, without data_type)
        }
+        "lpad" => {
+            let func = Arc::new(spark_lpad);
+            make_comet_scalar_udf!("lpad", func, without data_type)
+        }
+        "left_side_padding" => {
+            let func = Arc::new(spark_left_side_padding);
+            make_comet_scalar_udf!("left_side_padding", func, without data_type)
+        }
         "rpad" => {
             let func = Arc::new(spark_rpad);
             make_comet_scalar_udf!("rpad", func, without data_type)
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/left_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/left_side_padding.rs
new file mode 100644
index 0000000000..c54ba0bb23
--- /dev/null
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/left_side_padding.rs
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
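+
+//! Left-side string padding, matching Spark's `lpad` semantics.
+//!
+//! Illustrative examples (editorial sketch, assuming Spark semantics):
+//! `lpad('ab', 4)` returns `'  ab'` and `lpad('abcdef', 3)` truncates to
+//! `'abc'`, while `left_side_padding` pads but never truncates. Lengths are
+//! measured in `char`s rather than bytes or grapheme clusters.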
+
+use arrow::array::builder::GenericStringBuilder;
+use arrow::array::cast::as_dictionary_array;
+use arrow::array::types::Int32Type;
+use arrow::array::{make_array, Array, DictionaryArray};
+use arrow::array::{ArrayRef, OffsetSizeTrait};
+use arrow::datatypes::DataType;
+use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
+use datafusion::physical_plan::ColumnarValue;
+use std::sync::Arc;
+
+/// Similar to DataFusion's `lpad`, but does not truncate when the string is already
+/// longer than the requested length
+pub fn spark_left_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    spark_left_side_padding2(args, false)
+}
+
+/// Custom `lpad` implementation, because DataFusion's `lpad` differs from Spark's in
+/// its Unicode handling
+pub fn spark_lpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    spark_left_side_padding2(args, true)
+}
+
+fn spark_left_side_padding2(
+    args: &[ColumnarValue],
+    truncate: bool,
+) -> Result<ColumnarValue, DataFusionError> {
+    match args {
+        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
+            match array.data_type() {
+                DataType::Utf8 => spark_left_side_padding_internal::<i32>(array, *length, truncate),
+                DataType::LargeUtf8 => {
+                    spark_left_side_padding_internal::<i64>(array, *length, truncate)
+                }
+                // Dictionary support required for SPARK-48498
+                DataType::Dictionary(_, value_type) => {
+                    let dict = as_dictionary_array::<Int32Type>(array);
+                    let col = if value_type.as_ref() == &DataType::Utf8 {
+                        spark_left_side_padding_internal::<i32>(dict.values(), *length, truncate)?
+                    } else {
+                        spark_left_side_padding_internal::<i64>(dict.values(), *length, truncate)?
+                    };
+                    // `col` is always an array here, so the row-count argument to
+                    // `to_array()` is unused and can be anything
+                    let values = col.to_array(0)?;
+                    let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
+                    Ok(ColumnarValue::Array(make_array(result.into())))
+                }
+                other => Err(DataFusionError::Internal(format!(
+                    "Unsupported data type {other:?} for function lpad/left_side_padding",
+                ))),
+            }
+        }
+        other => Err(DataFusionError::Internal(format!(
+            "Unsupported arguments {other:?} for function lpad/left_side_padding",
+        ))),
+    }
+}
+
+fn spark_left_side_padding_internal<T: OffsetSizeTrait>(
+    array: &ArrayRef,
+    length: i32,
+    truncate: bool,
+) -> Result<ColumnarValue, DataFusionError> {
+    let string_array = as_generic_string_array::<T>(array)?;
+    let length = 0.max(length) as usize;
+    let space_string = " ".repeat(length);
+
+    let mut builder =
+        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
+
+    for string in string_array.iter() {
+        match string {
+            Some(string) => {
+                // Spark's UTF8String appears to count chars rather than graphemes
+                // https://stackoverflow.com/a/46290728
+                let char_len = string.chars().count();
+                if length <= char_len {
+                    if truncate {
+                        let idx = string
+                            .char_indices()
+                            .nth(length)
+                            .map(|(i, _)| i)
+                            .unwrap_or(string.len());
+                        builder.append_value(&string[..idx]);
+                    } else {
+                        builder.append_value(string);
+                    }
+                } else {
+                    // Prepend `length - char_len` spaces, then the original string
+                    let mut padded = String::with_capacity(length);
+                    padded.push_str(&space_string[char_len..]);
+                    padded.push_str(string);
+                    builder.append_value(&padded);
+                }
+            }
+            _ => builder.append_null(),
+        }
+    }
+    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+}
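+
+#[cfg(test)]
+mod tests {
+    // Illustrative test sketch (editorial addition, not part of the original
+    // change): checks the assumed Spark semantics that padding and truncation
+    // are measured in chars, not bytes.
+    use super::*;
+    use arrow::array::StringArray;
+
+    #[test]
+    fn test_spark_lpad_counts_chars_not_bytes() -> Result<(), DataFusionError> {
+        let input: ArrayRef = Arc::new(StringArray::from(vec![Some("ab"), Some("héllo"), None]));
+        let args = [
+            ColumnarValue::Array(input),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(4))),
+        ];
+        let ColumnarValue::Array(result) = spark_lpad(&args)? else {
+            panic!("expected an array result");
+        };
+        let result = as_generic_string_array::<i32>(&result)?;
+        assert_eq!(result.value(0), "  ab"); // padded on the left to 4 chars
+        assert_eq!(result.value(1), "héll"); // truncated to 4 chars, not 4 bytes
+        assert!(result.is_null(2)); // nulls propagate
+        Ok(())
+    }
+}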
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
index 0a8d8f3c55..4fc06d6a25 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod left_side_padding;
 mod read_side_padding;
 
+pub use left_side_padding::{spark_left_side_padding, spark_lpad};
 pub use read_side_padding::{spark_read_side_padding, spark_rpad};
diff --git a/native/spark-expr/src/static_invoke/mod.rs b/native/spark-expr/src/static_invoke/mod.rs
index 39735f1569..12024c22b0 100644
--- a/native/spark-expr/src/static_invoke/mod.rs
+++ b/native/spark-expr/src/static_invoke/mod.rs
@@ -17,4 +17,5 @@
 
 mod char_varchar_utils;
 
+pub use char_varchar_utils::{spark_left_side_padding, spark_lpad};
 pub use char_varchar_utils::{spark_read_side_padding, spark_rpad};
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 1b72521270..0ab28aaf65 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -158,7 +158,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[DateAdd] -> CometDateAdd,
     classOf[DateSub] -> CometDateSub,
     classOf[TruncDate] -> CometTruncDate,
-    classOf[TruncTimestamp] -> CometTruncTimestamp)
+    classOf[TruncTimestamp] -> CometTruncTimestamp,
+    classOf[StringLPad] -> CometStringLPad)
 
   /**
    * Mapping of Spark aggregate expression class to Comet expression handler.
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index 75e7e8bd4c..34b32f7dfe 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -19,7 +19,7 @@
 
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Like, Literal, RLike, StringRPad, Substring}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Like, Literal, RLike, StringLPad, StringRPad, Substring}
 import org.apache.spark.sql.types.{DataTypes, LongType, StringType}
 
 import org.apache.comet.CometConf
@@ -180,3 +180,23 @@ object CometStringRPad extends CometExpressionSerde {
     }
   }
 }
+
+object CometStringLPad extends CometExpressionSerde {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    val stringLPad = expr.asInstanceOf[StringLPad]
+    stringLPad.pad match {
+      case Literal(str, DataTypes.StringType) if str.toString == " " =>
+        scalarFunctionExprToProto(
+          "lpad",
+          exprToProtoInternal(stringLPad.str, inputs, binding),
+          exprToProtoInternal(stringLPad.len, inputs, binding))
+      case _ =>
+        withInfo(expr, "StringLPad with a non-space pad character is not supported")
+        None
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 60435cee7b..afcf43e813 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -323,6 +323,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("lpad expression") {
+    withTable("t1") {
+      sql("create table t1(c1 varchar(100), c2 int) using parquet")
+      sql("insert into t1 values('IfIWasARoadIWouldBeBent', 10)")
+      val res = sql("select lpad(c1, 50), lpad(c1, 5) from t1 order by c1")
+      checkSparkAnswerAndOperator(res)
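+      // Editorial sketch (assumed fallback behavior): a non-space pad
+      // character is not converted natively by CometStringLPad, so this
+      // query should still return correct results via Spark fallback.
+      checkSparkAnswer(sql("select lpad(c1, 50, 'x') from t1 order by c1"))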
+    }
+  }
+
   test("dictionary arithmetic") {
     // TODO: test ANSI mode
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") {