From ceb9efddc686a7e25546e6dee49a8424890170c9 Mon Sep 17 00:00:00 2001 From: Bhargava Vadlamani Date: Sat, 9 Aug 2025 11:21:24 -0700 Subject: [PATCH 1/2] implement_comet_native_lpad_expr --- native/spark-expr/src/comet_scalar_funcs.rs | 15 +- .../char_varchar_utils/left_side_padding.rs | 115 +++++++++++++ .../static_invoke/char_varchar_utils/mod.rs | 2 + .../char_varchar_utils/read_side_padding.rs | 155 ++++++++++++++---- native/spark-expr/src/static_invoke/mod.rs | 1 + .../apache/comet/serde/QueryPlanSerde.scala | 3 +- .../org/apache/comet/serde/strings.scala | 22 ++- .../apache/comet/CometExpressionSuite.scala | 20 +++ 8 files changed, 294 insertions(+), 39 deletions(-) create mode 100644 native/spark-expr/src/static_invoke/char_varchar_utils/left_side_padding.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 961309b788..d20f2c9d2d 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -19,9 +19,10 @@ use crate::hash_funcs::*; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, - spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal, - spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, - SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace, + spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_left_side_padding, + spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, + spark_unscaled_value, SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkDateTrunc, + SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -83,6 +84,14 @@ pub fn create_comet_physical_fun( let func = Arc::new(spark_read_side_padding); make_comet_scalar_udf!("read_side_padding", func, without data_type) } + "lpad" => { + let func = Arc::new(spark_lpad); + make_comet_scalar_udf!("lpad", func, without data_type) + } + "left_side_padding" => { + let func = Arc::new(spark_left_side_padding); + make_comet_scalar_udf!("left_side_padding", func, without data_type) + } "rpad" => { let func = Arc::new(spark_rpad); make_comet_scalar_udf!("rpad", func, without data_type) diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/left_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/left_side_padding.rs new file mode 100644 index 0000000000..c54ba0bb23 --- /dev/null +++ b/native/spark-expr/src/static_invoke/char_varchar_utils/left_side_padding.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
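+
+// Usage sketch (illustrative only; Spark counts pad length in Unicode chars):
+//   lpad('ab', 4)     => '  ab'   -- spark_lpad pads on the left with spaces
+//   lpad('abcdef', 4) => 'abcd'   -- and truncates when the input is longer
+// spark_left_side_padding applies the same left padding but never truncates,
+// mirroring the existing read_side_padding/rpad pair in this module.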
+
+use arrow::array::builder::GenericStringBuilder;
+use arrow::array::cast::as_dictionary_array;
+use arrow::array::types::Int32Type;
+use arrow::array::{make_array, Array, DictionaryArray};
+use arrow::array::{ArrayRef, OffsetSizeTrait};
+use arrow::datatypes::DataType;
+use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
+use datafusion::physical_plan::ColumnarValue;
+use std::sync::Arc;
+
+/// Similar to DataFusion `lpad`, but does not truncate when the string is already longer than the length
+pub fn spark_left_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    spark_left_side_padding2(args, false)
+}
+
+/// Custom `lpad` because DataFusion's `lpad` differs from Spark in its Unicode handling
+pub fn spark_lpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    spark_left_side_padding2(args, true)
+}
+
+fn spark_left_side_padding2(
+    args: &[ColumnarValue],
+    truncate: bool,
+) -> Result<ColumnarValue, DataFusionError> {
+    match args {
+        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
+            match array.data_type() {
+                DataType::Utf8 => spark_left_side_padding_internal::<i32>(array, *length, truncate),
+                DataType::LargeUtf8 => {
+                    spark_left_side_padding_internal::<i64>(array, *length, truncate)
+                }
+                // Dictionary support required for SPARK-48498
+                DataType::Dictionary(_, value_type) => {
+                    let dict = as_dictionary_array::<Int32Type>(array);
+                    let col = if value_type.as_ref() == &DataType::Utf8 {
+                        spark_left_side_padding_internal::<i32>(dict.values(), *length, truncate)?
+                    } else {
+                        spark_left_side_padding_internal::<i64>(dict.values(), *length, truncate)?
+                    };
+                    // col is a single array, so the row-count argument to to_array() is not used and can be anything
+                    let values = col.to_array(0)?;
+                    let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
+                    Ok(ColumnarValue::Array(make_array(result.into())))
+                }
+                other => Err(DataFusionError::Internal(format!(
+                    "Unsupported data type {other:?} for function lpad/left_side_padding",
+                ))),
+            }
+        }
+        other => Err(DataFusionError::Internal(format!(
+            "Unsupported arguments {other:?} for function lpad/left_side_padding",
+        ))),
+    }
+}
+
+fn spark_left_side_padding_internal<T: OffsetSizeTrait>(
+    array: &ArrayRef,
+    length: i32,
+    truncate: bool,
+) -> Result<ColumnarValue, DataFusionError> {
+    let string_array = as_generic_string_array::<T>(array)?;
+    let length = 0.max(length) as usize;
+    let space_string = " ".repeat(length);
+
+    let mut builder =
+        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
+
+    for string in string_array.iter() {
+        match string {
+            Some(string) => {
+                // It looks like Spark's UTF8String counts chars rather than graphemes
+                // https://stackoverflow.com/a/46290728
+                let char_len = string.chars().count();
+                if length <= char_len {
+                    if truncate {
+                        let idx = string
+                            .char_indices()
+                            .nth(length)
+                            .map(|(i, _)| i)
+                            .unwrap_or(string.len());
+                        builder.append_value(&string[..idx]);
+                    } else {
+                        builder.append_value(string);
+                    }
+                } else {
+                    // Prepend the length - char_len padding spaces, then the original string
+                    let mut padded = String::with_capacity(length);
+                    padded.push_str(&space_string[char_len..]);
+                    padded.push_str(string);
+                    builder.append_value(&padded);
+                }
+            }
+            _ => builder.append_null(),
+        }
+    }
+    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+}
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
index 0a8d8f3c55..4fc06d6a25 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod left_side_padding;
 mod read_side_padding;
 
+pub use left_side_padding::{spark_left_side_padding, spark_lpad};
 pub use read_side_padding::{spark_read_side_padding, spark_rpad};
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index 6e56d9d86f..e92059ccf0 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -18,10 +18,10 @@
 use arrow::array::builder::GenericStringBuilder;
 use arrow::array::cast::as_dictionary_array;
 use arrow::array::types::Int32Type;
-use arrow::array::{make_array, Array, DictionaryArray};
+use arrow::array::{make_array, Array, AsArray, DictionaryArray};
 use arrow::array::{ArrayRef, OffsetSizeTrait};
 use arrow::datatypes::DataType;
-use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
+use datafusion::common::{cast::as_generic_string_array, DataFusionError, HashMap, ScalarValue};
 use datafusion::physical_plan::ColumnarValue;
 use std::fmt::Write;
 use std::sync::Arc;
@@ -42,18 +42,48 @@ fn spark_read_side_padding2(
 ) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
+            let rpad_arg = RPadArgument::ConstLength(*length);
             match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length, truncate),
+                DataType::Utf8 => {
+                    spark_read_side_padding_internal::<i32>(array, truncate, rpad_arg)
+                }
+                DataType::LargeUtf8 => {
+                    spark_read_side_padding_internal::<i64>(array, truncate, rpad_arg)
+                }
+                // Dictionary support required for SPARK-48498
+                DataType::Dictionary(_, value_type) => {
+                    let dict = as_dictionary_array::<Int32Type>(array);
+                    let col = if value_type.as_ref() == &DataType::Utf8 {
+                        spark_read_side_padding_internal::<i32>(dict.values(), truncate, rpad_arg)?
+                    } else {
+                        spark_read_side_padding_internal::<i64>(dict.values(), truncate, rpad_arg)?
+                    };
+                    // col consists of an array, so arg of to_array() is not used. Can be anything
+                    let values = col.to_array(0)?;
+                    let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
+                    Ok(ColumnarValue::Array(make_array(result.into())))
+                }
+                other => Err(DataFusionError::Internal(format!(
+                    "Unsupported data type {other:?} for function rpad/read_side_padding",
+                ))),
+            }
+        }
+        [ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => {
+            let rpad_arg = RPadArgument::ColArray(Arc::clone(array_int));
+            match array.data_type() {
+                DataType::Utf8 => {
+                    spark_read_side_padding_internal::<i32>(array, truncate, rpad_arg)
+                }
                 DataType::LargeUtf8 => {
-                    spark_read_side_padding_internal::<i64>(array, *length, truncate)
+                    spark_read_side_padding_internal::<i64>(array, truncate, rpad_arg)
                 }
                 // Dictionary support required for SPARK-48498
                 DataType::Dictionary(_, value_type) => {
                     let dict = as_dictionary_array::<Int32Type>(array);
                     let col = if value_type.as_ref() == &DataType::Utf8 {
-                        spark_read_side_padding_internal::<i32>(dict.values(), *length, truncate)?
+                        spark_read_side_padding_internal::<i32>(dict.values(), truncate, rpad_arg)?
                     } else {
-                        spark_read_side_padding_internal::<i64>(dict.values(), *length, truncate)?
+                        spark_read_side_padding_internal::<i64>(dict.values(), truncate, rpad_arg)?
                     };
                     // col consists of an array, so arg of to_array() is not used.
Can be anything
                     let values = col.to_array(0)?;
@@ -71,44 +101,101 @@ fn spark_read_side_padding2(
     }
 }
 
+enum RPadArgument {
+    ConstLength(i32),
+    ColArray(ArrayRef),
+}
+
 fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
-    length: i32,
     truncate: bool,
+    rpad_argument: RPadArgument,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
-    let length = 0.max(length) as usize;
-    let space_string = " ".repeat(length);
+    match rpad_argument {
+        RPadArgument::ColArray(array_int) => {
+            let int_pad_array = array_int.as_primitive::<Int32Type>();
+            let mut str_pad_value_map = HashMap::new();
+            for i in 0..string_array.len() {
+                if string_array.is_null(i) || int_pad_array.is_null(i) {
+                    continue; // skip nulls
+                }
+                str_pad_value_map.insert(string_array.value(i), int_pad_array.value(i));
+            }
 
-    let mut builder =
-        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
+            let mut builder = GenericStringBuilder::<T>::with_capacity(
+                str_pad_value_map.len(),
+                str_pad_value_map.len() * int_pad_array.len(),
+            );
 
-    for string in string_array.iter() {
-        match string {
-            Some(string) => {
-                // It looks Spark's UTF8String is closer to chars rather than graphemes
-                // https://stackoverflow.com/a/46290728
-                let char_len = string.chars().count();
-                if length <= char_len {
-                    if truncate {
-                        let idx = string
-                            .char_indices()
-                            .nth(length)
-                            .map(|(i, _)| i)
-                            .unwrap_or(string.len());
-                        builder.append_value(&string[..idx]);
-                    } else {
-                        builder.append_value(string);
+            for string in string_array.iter() {
+                match string {
+                    Some(string) => {
+                        // It looks Spark's UTF8String is closer to chars rather than graphemes
+                        // https://stackoverflow.com/a/46290728
+                        let char_len = string.chars().count();
+                        let length: usize = 0.max(*str_pad_value_map.get(string).unwrap()) as usize;
+                        let space_string = " ".repeat(length);
+                        if length <= char_len {
+                            if truncate {
+                                let idx = string
+                                    .char_indices()
+                                    .nth(length)
+                                    .map(|(i, _)| i)
+                                    .unwrap_or(string.len());
+                                builder.append_value(&string[..idx]);
+                            } else {
+                                builder.append_value(string);
+                            }
+                        } else {
+                            // write_str updates only the value buffer, not null nor offset buffer
+                            // This is convenient for concatenating str(s)
+                            builder.write_str(string)?;
+                            builder.append_value(&space_string[char_len..]);
+                        }
+                    }
+                    _ => builder.append_null(),
+                }
+            }
+            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+        }
+        RPadArgument::ConstLength(length) => {
+            let length = 0.max(length) as usize;
+            let space_string = " ".repeat(length);
+
+            let mut builder = GenericStringBuilder::<T>::with_capacity(
+                string_array.len(),
+                string_array.len() * length,
+            );
+
+            for string in string_array.iter() {
+                match string {
+                    Some(string) => {
+                        // It looks Spark's UTF8String is closer to chars rather than graphemes
+                        // https://stackoverflow.com/a/46290728
+                        let char_len = string.chars().count();
+                        if length <= char_len {
+                            if truncate {
+                                let idx = string
+                                    .char_indices()
+                                    .nth(length)
+                                    .map(|(i, _)| i)
+                                    .unwrap_or(string.len());
+                                builder.append_value(&string[..idx]);
+                            } else {
+                                builder.append_value(string);
+                            }
+                        } else {
+                            // write_str updates only the value buffer, not null nor offset buffer
+                            // This is convenient for concatenating str(s)
+                            builder.write_str(string)?;
+                            builder.append_value(&space_string[char_len..]);
+                        }
                     }
-                } else {
-                    // write_str updates only the value buffer, not null nor offset buffer
-                    // This is convenient for concatenating str(s)
-                    builder.write_str(string)?;
-                    builder.append_value(&space_string[char_len..]);
+                    _ =>
builder.append_null(), } } - _ => builder.append_null(), + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } } - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } diff --git a/native/spark-expr/src/static_invoke/mod.rs b/native/spark-expr/src/static_invoke/mod.rs index 39735f1569..12024c22b0 100644 --- a/native/spark-expr/src/static_invoke/mod.rs +++ b/native/spark-expr/src/static_invoke/mod.rs @@ -17,4 +17,5 @@ mod char_varchar_utils; +pub use char_varchar_utils::{spark_left_side_padding, spark_lpad}; pub use char_varchar_utils::{spark_read_side_padding, spark_rpad}; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 1b72521270..0ab28aaf65 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -158,7 +158,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[DateAdd] -> CometDateAdd, classOf[DateSub] -> CometDateSub, classOf[TruncDate] -> CometTruncDate, - classOf[TruncTimestamp] -> CometTruncTimestamp) + classOf[TruncTimestamp] -> CometTruncTimestamp, + classOf[StringLPad] -> CometStringLPad) /** * Mapping of Spark aggregate expression class to Comet expression handler. diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 75e7e8bd4c..34b32f7dfe 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Like, Literal, RLike, StringRPad, Substring} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Like, Literal, RLike, StringLPad, StringRPad, Substring} import org.apache.spark.sql.types.{DataTypes, LongType, StringType} import org.apache.comet.CometConf @@ -180,3 +180,23 @@ object CometStringRPad extends CometExpressionSerde { } } } + +object CometStringLPad extends CometExpressionSerde { + + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val stringLPad = expr.asInstanceOf[StringLPad] + stringLPad.pad match { + case Literal(str, DataTypes.StringType) if str.toString == " " => + scalarFunctionExprToProto( + "lpad", + exprToProtoInternal(stringLPad.str, inputs, binding), + exprToProtoInternal(stringLPad.len, inputs, binding)) + case _ => + withInfo(expr, "StringLPad with non-space characters is not supported") + None + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 60435cee7b..1bc8704a9e 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -323,6 +323,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("fix_rpad") { + withTable("t1") { + sql("create table t1(c1 varchar(100), c2 int) using parquet") + sql("insert into t1 values('IfIWasARoadIWouldBeBent', 10)") + sql("insert into t1 values('IfIWereATrainIwouldBeLate', 9)") + sql("insert into t1 values(NULL, 10)") + val res = sql("select rpad(c1,c2) , rpad(c1,5) from t1 order by c1") + checkSparkAnswerAndOperator(res) + } + } + + test("test_lpad_expression") { + withTable("t1") { + sql("create table 
t1(c1 varchar(100), c2 int) using parquet")
+      sql("insert into t1 values('IfIWasARoadIWouldBeBent', 10)")
+      val res = sql("select lpad(c1, 50), lpad(c1, 5) from t1 order by c1")
+      checkSparkAnswerAndOperator(res)
+    }
+  }
+
   test("dictionary arithmetic") {
     // TODO: test ANSI mode
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") {

From a586a5739aa5129e95fc865788d092097ce9dd4f Mon Sep 17 00:00:00 2001
From: Bhargava Vadlamani
Date: Sat, 9 Aug 2025 11:33:45 -0700
Subject: [PATCH 2/2] implement_comet_native_lpad_expr

---
 .../char_varchar_utils/read_side_padding.rs   | 155 ++++--------------
 .../apache/comet/CometExpressionSuite.scala   |  11 --
 2 files changed, 34 insertions(+), 132 deletions(-)

diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index e92059ccf0..6e56d9d86f 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -18,10 +18,10 @@
 use arrow::array::builder::GenericStringBuilder;
 use arrow::array::cast::as_dictionary_array;
 use arrow::array::types::Int32Type;
-use arrow::array::{make_array, Array, AsArray, DictionaryArray};
+use arrow::array::{make_array, Array, DictionaryArray};
 use arrow::array::{ArrayRef, OffsetSizeTrait};
 use arrow::datatypes::DataType;
-use datafusion::common::{cast::as_generic_string_array, DataFusionError, HashMap, ScalarValue};
+use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
 use datafusion::physical_plan::ColumnarValue;
 use std::fmt::Write;
 use std::sync::Arc;
@@ -42,48 +42,18 @@ fn spark_read_side_padding2(
 ) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
-            let rpad_arg = RPadArgument::ConstLength(*length);
             match array.data_type() {
-                DataType::Utf8 => {
-                    spark_read_side_padding_internal::<i32>(array, truncate, rpad_arg)
-                }
-                DataType::LargeUtf8 => {
-                    spark_read_side_padding_internal::<i64>(array, truncate, rpad_arg)
-                }
-                // Dictionary support required for SPARK-48498
-                DataType::Dictionary(_, value_type) => {
-                    let dict = as_dictionary_array::<Int32Type>(array);
-                    let col = if value_type.as_ref() == &DataType::Utf8 {
-                        spark_read_side_padding_internal::<i32>(dict.values(), truncate, rpad_arg)?
-                    } else {
-                        spark_read_side_padding_internal::<i64>(dict.values(), truncate, rpad_arg)?
-                    };
-                    // col consists of an array, so arg of to_array() is not used. Can be anything
-                    let values = col.to_array(0)?;
-                    let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
-                    Ok(ColumnarValue::Array(make_array(result.into())))
-                }
-                other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad/read_side_padding",
-                ))),
-            }
-        }
-        [ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => {
-            let rpad_arg = RPadArgument::ColArray(Arc::clone(array_int));
-            match array.data_type() {
-                DataType::Utf8 => {
-                    spark_read_side_padding_internal::<i32>(array, truncate, rpad_arg)
-                }
+                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length, truncate),
                 DataType::LargeUtf8 => {
-                    spark_read_side_padding_internal::<i64>(array, truncate, rpad_arg)
+                    spark_read_side_padding_internal::<i64>(array, *length, truncate)
                 }
                 // Dictionary support required for SPARK-48498
                 DataType::Dictionary(_, value_type) => {
                     let dict = as_dictionary_array::<Int32Type>(array);
                     let col = if value_type.as_ref() == &DataType::Utf8 {
-                        spark_read_side_padding_internal::<i32>(dict.values(), truncate, rpad_arg)?
+                        spark_read_side_padding_internal::<i32>(dict.values(), *length, truncate)?
                     } else {
-                        spark_read_side_padding_internal::<i64>(dict.values(), truncate, rpad_arg)?
+                        spark_read_side_padding_internal::<i64>(dict.values(), *length, truncate)?
                     };
                     // col consists of an array, so arg of to_array() is not used.
Can be anything
                     let values = col.to_array(0)?;
@@ -101,101 +71,44 @@ fn spark_read_side_padding2(
     }
 }
 
-enum RPadArgument {
-    ConstLength(i32),
-    ColArray(ArrayRef),
-}
-
 fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
+    length: i32,
     truncate: bool,
-    rpad_argument: RPadArgument,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
-    match rpad_argument {
-        RPadArgument::ColArray(array_int) => {
-            let int_pad_array = array_int.as_primitive::<Int32Type>();
-            let mut str_pad_value_map = HashMap::new();
-            for i in 0..string_array.len() {
-                if string_array.is_null(i) || int_pad_array.is_null(i) {
-                    continue; // skip nulls
-                }
-                str_pad_value_map.insert(string_array.value(i), int_pad_array.value(i));
-            }
-
-            let mut builder = GenericStringBuilder::<T>::with_capacity(
-                str_pad_value_map.len(),
-                str_pad_value_map.len() * int_pad_array.len(),
-            );
-
-            for string in string_array.iter() {
-                match string {
-                    Some(string) => {
-                        // It looks Spark's UTF8String is closer to chars rather than graphemes
-                        // https://stackoverflow.com/a/46290728
-                        let char_len = string.chars().count();
-                        let length: usize = 0.max(*str_pad_value_map.get(string).unwrap()) as usize;
-                        let space_string = " ".repeat(length);
-                        if length <= char_len {
-                            if truncate {
-                                let idx = string
-                                    .char_indices()
-                                    .nth(length)
-                                    .map(|(i, _)| i)
-                                    .unwrap_or(string.len());
-                                builder.append_value(&string[..idx]);
-                            } else {
-                                builder.append_value(string);
-                            }
-                        } else {
-                            // write_str updates only the value buffer, not null nor offset buffer
-                            // This is convenient for concatenating str(s)
-                            builder.write_str(string)?;
-                            builder.append_value(&space_string[char_len..]);
-                        }
-                    }
-                    _ => builder.append_null(),
-                }
-            }
-            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
-        }
-        RPadArgument::ConstLength(length) => {
-            let length = 0.max(length) as usize;
-            let space_string = " ".repeat(length);
+    let length = 0.max(length) as usize;
+    let space_string = " ".repeat(length);
 
-            let mut builder = GenericStringBuilder::<T>::with_capacity(
-                string_array.len(),
-                string_array.len() * length,
-            );
+    let mut builder =
+        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
 
-            for string
in string_array.iter() { - match string { - Some(string) => { - // It looks Spark's UTF8String is closer to chars rather than graphemes - // https://stackoverflow.com/a/46290728 - let char_len = string.chars().count(); - if length <= char_len { - if truncate { - let idx = string - .char_indices() - .nth(length) - .map(|(i, _)| i) - .unwrap_or(string.len()); - builder.append_value(&string[..idx]); - } else { - builder.append_value(string); - } - } else { - // write_str updates only the value buffer, not null nor offset buffer - // This is convenient for concatenating str(s) - builder.write_str(string)?; - builder.append_value(&space_string[char_len..]); - } + for string in string_array.iter() { + match string { + Some(string) => { + // It looks Spark's UTF8String is closer to chars rather than graphemes + // https://stackoverflow.com/a/46290728 + let char_len = string.chars().count(); + if length <= char_len { + if truncate { + let idx = string + .char_indices() + .nth(length) + .map(|(i, _)| i) + .unwrap_or(string.len()); + builder.append_value(&string[..idx]); + } else { + builder.append_value(string); } - _ => builder.append_null(), + } else { + // write_str updates only the value buffer, not null nor offset buffer + // This is convenient for concatenating str(s) + builder.write_str(string)?; + builder.append_value(&space_string[char_len..]); } } - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + _ => builder.append_null(), } } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 1bc8704a9e..afcf43e813 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -323,17 +323,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("fix_rpad") { - withTable("t1") { - sql("create table t1(c1 varchar(100), c2 int) using parquet") - sql("insert into t1 values('IfIWasARoadIWouldBeBent', 10)") - sql("insert into t1 values('IfIWereATrainIwouldBeLate', 9)") - sql("insert into t1 values(NULL, 10)") - val res = sql("select rpad(c1,c2) , rpad(c1,5) from t1 order by c1") - checkSparkAnswerAndOperator(res) - } - } - test("test_lpad_expression") { withTable("t1") { sql("create table t1(c1 varchar(100), c2 int) using parquet")