diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index d969b6279b..89485ddec4 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -204,14 +204,22 @@ fn spark_read_side_padding_internal(
             );
 
             for (string, length) in string_array.iter().zip(int_pad_array) {
-                match string {
-                    Some(string) => builder.append_value(add_padding_string(
-                        string.parse().unwrap(),
-                        length.unwrap() as usize,
-                        truncate,
-                        pad_string,
-                        is_left_pad,
-                    )?),
+                // Spark semantics: the result is null if either the input string or the
+                // target length is null, and an empty string if the length is negative
+                match (string, length) {
+                    (Some(string), Some(length)) => {
+                        if length >= 0 {
+                            builder.append_value(add_padding_string(
+                                string.parse().unwrap(),
+                                length as usize,
+                                truncate,
+                                pad_string,
+                                is_left_pad,
+                            )?)
+                        } else {
+                            builder.append_value("");
+                        }
+                    }
                     _ => builder.append_null(),
                 }
             }
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index c6f5a85089..3d4bacfa26 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -162,6 +162,16 @@ object CometRLike extends CometExpressionSerde[RLike] {
 
 object CometStringRPad extends CometExpressionSerde[StringRPad] {
 
+  override def getSupportLevel(expr: StringRPad): SupportLevel = {
+    if (expr.str.isInstanceOf[Literal]) {
+      return Unsupported(Some("Scalar values are not supported for the str argument"))
+    }
+    if (!expr.pad.isInstanceOf[Literal]) {
+      return Unsupported(Some("Only scalar values are supported for the pad argument"))
+    }
+    Compatible()
+  }
+
   override def convert(
       expr: StringRPad,
       inputs: Seq[Attribute],
@@ -177,21 +187,16 @@ object CometStringRPad extends CometExpressionSerde[StringRPad] {
 
 object CometStringLPad extends CometExpressionSerde[StringLPad] {
 
-  /**
-   * Convert a Spark expression into a protocol buffer representation that can be passed into
-   * native code.
-   *
-   * @param expr
-   *   The Spark expression.
-   * @param inputs
-   *   The input attributes.
-   * @param binding
-   *   Whether the attributes are bound (this is only relevant in aggregate expressions).
-   * @return
-   *   Protocol buffer representation, or None if the expression could not be converted. In this
-   *   case it is expected that the input expression will have been tagged with reasons why it
-   *   could not be converted.
-   */
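+  // Same support rules as CometStringRPad: lpad and rpad are served by one native
+  // kernel (read_side_padding), which takes the pad characters as a scalar, so only
+  // literal pad arguments are supported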
+  override def getSupportLevel(expr: StringLPad): SupportLevel = {
+    if (expr.str.isInstanceOf[Literal]) {
+      return Unsupported(Some("Scalar values are not supported for the str argument"))
+    }
+    if (!expr.pad.isInstanceOf[Literal]) {
+      return Unsupported(Some("Only scalar values are supported for the pad argument"))
+    }
+    Compatible()
+  }
+
   override def convert(
       expr: StringLPad,
       inputs: Seq[Attribute],
diff --git a/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala b/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
index 188da1d799..5363fda15e 100644
--- a/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
+++ b/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
@@ -194,11 +194,13 @@ object FuzzDataGenerator {
           case 1 => r.nextInt().toByte.toString
           case 2 => r.nextLong().toString
           case 3 => r.nextDouble().toString
-          case 4 => RandomStringUtils.randomAlphabetic(8)
+          case 4 => RandomStringUtils.randomAlphabetic(options.maxStringLength)
           case 5 =>
             // use a constant value to trigger dictionary encoding
            "dict_encode_me!"
-          case _ => r.nextString(8)
+          case 6 if options.customStrings.nonEmpty =>
+            randomChoice(options.customStrings, r)
+          case _ => r.nextString(options.maxStringLength)
         }
       })
     case DataTypes.BinaryType =>
@@ -221,6 +223,11 @@ object FuzzDataGenerator {
       case _ => throw new IllegalStateException(s"Cannot generate data for $dataType yet")
     }
   }
+
+  private def randomChoice[T](list: Seq[T], r: Random): T = {
+    list(r.nextInt(list.length))
+  }
+
 }
 
 object SchemaGenOptions {
@@ -250,4 +257,6 @@ case class SchemaGenOptions(
 case class DataGenOptions(
     allowNull: Boolean = true,
     generateNegativeZero: Boolean = true,
-    baseDate: Long = FuzzDataGenerator.defaultBaseDate)
+    baseDate: Long = FuzzDataGenerator.defaultBaseDate,
+    customStrings: Seq[String] = Seq.empty,
+    maxStringLength: Int = 8)
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 1eca17dccc..ddbe7d14e2 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -414,41 +414,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
-  test("Verify rpad expr support for second arg instead of just literal") {
-    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
-    withParquetTable(data, "t1") {
-      val res = sql("select rpad(_1,_2) , rpad(_1,2) from t1 order by _1")
-      checkSparkAnswerAndOperator(res)
-    }
-  }
-
-  test("RPAD with character support other than default space") {
-    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
-    withParquetTable(data, "t1") {
-      val res = sql(
-        """ select rpad(_1,_2,'?'), rpad(_1,_2,'??') , rpad(_1,2, '??'), hex(rpad(unhex('aabb'), 5)),
-         rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
-      checkSparkAnswerAndOperator(res)
-    }
-  }
-
-  test("test lpad expression support") {
-    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
-    withParquetTable(data, "t1") {
-      val res = sql("select lpad(_1,_2) , lpad(_1,2) from t1 order by _1")
-      checkSparkAnswerAndOperator(res)
-    }
-  }
-
-  test("LPAD with character support other than default space") {
-    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
-    withParquetTable(data, "t1") {
-      val res = sql(
-        """ select lpad(_1,_2,'?'), lpad(_1,_2,'??') , lpad(_1,2, '??'), hex(lpad(unhex('aabb'), 5)),
-         rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
-      checkSparkAnswerAndOperator(res)
-    }
-  }
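+  // rpad/lpad coverage now lives in CometStringExpressionSuite, which exercises
+  // string and binary inputs across scalar and array argument combinations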
 
   test("dictionary arithmetic") {
     // TODO: test ANSI mode
@@ -2292,33 +2257,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("rpad") {
-    val table = "rpad"
-    val gen = new DataGenerator(new Random(42))
-    withTable(table) {
-      // generate some data
-      val dataChars = "abc123"
-      sql(s"create table $table(id int, name1 char(8), name2 varchar(8)) using parquet")
-      val testData = gen.generateStrings(100, dataChars, 6) ++ Seq(
-        "é", // unicode 'e\\u{301}'
-        "é" // unicode '\\u{e9}'
-      )
-      testData.zipWithIndex.foreach { x =>
-        sql(s"insert into $table values(${x._2}, '${x._1}', '${x._1}')")
-      }
-      // test 2-arg version
-      checkSparkAnswerAndOperator(
-        s"SELECT id, rpad(name1, 10), rpad(name2, 10) FROM $table ORDER BY id")
-      // test 3-arg version
-      for (length <- Seq(2, 10)) {
-        checkSparkAnswerAndOperator(
-          s"SELECT id, name1, rpad(name1, $length, ' ') FROM $table ORDER BY id")
-        checkSparkAnswerAndOperator(
-          s"SELECT id, name2, rpad(name2, $length, ' ') FROM $table ORDER BY id")
-      }
-    }
-  }
-
   test("isnan") {
     Seq("true", "false").foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary) {
diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
index 44d40cf1c1..a63aba8da9 100644
--- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
@@ -19,12 +19,133 @@
 
 package org.apache.comet
 
+import scala.util.Random
+
 import org.apache.parquet.hadoop.ParquetOutputFormat
 import org.apache.spark.sql.{CometTestBase, DataFrame}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+
+import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
 
 class CometStringExpressionSuite extends CometTestBase {
 
+  test("lpad string") {
+    testStringPadding("lpad")
+  }
+
+  test("rpad string") {
+    testStringPadding("rpad")
+  }
+
+  test("lpad binary") {
+    testBinaryPadding("lpad")
+  }
+
+  test("rpad binary") {
+    testBinaryPadding("rpad")
+  }
+
+  private def testStringPadding(expr: String): Unit = {
+    val r = new Random(42)
+    val schema = StructType(
+      Seq(
+        StructField("str", DataTypes.StringType, nullable = true),
+        StructField("len", DataTypes.IntegerType, nullable = true),
+        StructField("pad", DataTypes.StringType, nullable = true)))
+    // scalastyle:off
+    val edgeCases = Seq(
+      "é", // unicode 'e\\u{301}'
+      "é", // unicode '\\u{e9}'
+      "తెలుగు")
+    // scalastyle:on
+    val df = FuzzDataGenerator.generateDataFrame(
+      r,
+      spark,
+      schema,
+      1000,
+      DataGenOptions(maxStringLength = 6, customStrings = edgeCases))
+    df.createOrReplaceTempView("t1")
+
+    // test all combinations of scalar and array arguments
+    for (str <- Seq("'hello'", "str")) {
+      for (len <- Seq("6", "-6", "0", "len % 10")) {
+        for (pad <- Seq(Some("'x'"), Some("'zzz'"), Some("pad"), None)) {
+          val sql = pad match {
+            case Some(p) =>
+              // 3 args
+              s"SELECT $str, $len, $expr($str, $len, $p) FROM t1 ORDER BY str, len, pad"
+            case _ =>
+              // 2 args (default pad of ' ')
+              s"SELECT $str, $len, $expr($str, $len) FROM t1 ORDER BY str, len, pad"
+          }
+          val isLiteralStr = str == "'hello'"
+          val isLiteralLen = !len.contains("len")
+          val isLiteralPad = !pad.contains("pad")
+          if (isLiteralStr && isLiteralLen && isLiteralPad) {
+            // all arguments are literal, so Spark constant folding will kick in
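+            // (e.g. rpad('hello', 6, 'x') is folded to the literal 'hellox' at
+            // planning time)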
+            // and pad function will not be evaluated by Comet
+            checkSparkAnswer(sql)
+          } else if (isLiteralStr) {
+            checkSparkAnswerAndFallbackReason(
+              sql,
+              "Scalar values are not supported for the str argument")
+          } else if (!isLiteralPad) {
+            checkSparkAnswerAndFallbackReason(
+              sql,
+              "Only scalar values are supported for the pad argument")
+          } else {
+            checkSparkAnswerAndOperator(sql)
+          }
+        }
+      }
+    }
+  }
+
+  private def testBinaryPadding(expr: String): Unit = {
+    val r = new Random(42)
+    val schema = StructType(
+      Seq(
+        StructField("str", DataTypes.BinaryType, nullable = true),
+        StructField("len", DataTypes.IntegerType, nullable = true),
+        StructField("pad", DataTypes.BinaryType, nullable = true)))
+    val df = FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions())
+    df.createOrReplaceTempView("t1")
+
+    // test all combinations of scalar and array arguments
+    for (str <- Seq("unhex('DDEEFF')", "str")) {
+      // Spark does not support negative length for lpad/rpad with binary input and Comet does
+      // not support abs yet, so use `10 + len % 10` to avoid negative length
+      for (len <- Seq("6", "0", "10 + len % 10")) {
+        for (pad <- Seq(Some("unhex('CAFE')"), Some("pad"), None)) {
+
+          val sql = pad match {
+            case Some(p) =>
+              // 3 args
+              s"SELECT $str, $len, $expr($str, $len, $p) FROM t1 ORDER BY str, len, pad"
+            case _ =>
+              // 2 args (binary input is padded with zero bytes by default)
+              s"SELECT $str, $len, $expr($str, $len) FROM t1 ORDER BY str, len, pad"
+          }
+
+          val isLiteralStr = str != "str"
+          val isLiteralLen = !len.contains("len")
+          val isLiteralPad = !pad.contains("pad")
+
+          if (isLiteralStr && isLiteralLen && isLiteralPad) {
+            // all arguments are literal, so Spark constant folding will kick in
+            // and pad function will not be evaluated by Comet
+            checkSparkAnswer(sql)
+          } else {
+            // Comet will fall back to Spark because the plan contains a staticinvoke
+            // instruction, which is not supported
+            checkSparkAnswerAndFallbackReason(sql, "staticinvoke is not supported")
+          }
+        }
+      }
+    }
+  }
+
   test("Various String scalar functions") {
     val table = "names"
     withTable(table) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 3a4e52b4ad..2308858f61 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -166,6 +166,13 @@ abstract class CometTestBase
     (sparkPlan, dfComet.queryExecution.executedPlan)
   }
 
+  /** Check for the correct results as well as the expected fallback reason */
+  def checkSparkAnswerAndFallbackReason(sql: String, fallbackReason: String): Unit = {
+    val (_, cometPlan) = checkSparkAnswer(sql)
+    val explain = new ExtendedExplainInfo().generateVerboseExtendedInfo(cometPlan)
+    assert(explain.contains(fallbackReason))
+  }
+
   protected def checkSparkAnswerAndOperator(query: String, excludedClasses: Class[_]*): Unit = {
     checkSparkAnswerAndOperator(sql(query), excludedClasses: _*)
   }