@@ -204,14 +204,21 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     );
 
     for (string, length) in string_array.iter().zip(int_pad_array) {
+        let length = length.unwrap();
         match string {
-            Some(string) => builder.append_value(add_padding_string(
-                string.parse().unwrap(),
-                length.unwrap() as usize,
-                truncate,
-                pad_string,
-                is_left_pad,
-            )?),
+            Some(string) => {
+                if length >= 0 {
+                    builder.append_value(add_padding_string(
+                        string.parse().unwrap(),
+                        length as usize,
+                        truncate,
+                        pad_string,
+                        is_left_pad,
+                    )?)
+                } else {
+                    builder.append_value("");
+                }
+            }
             _ => builder.append_null(),
        }
    }
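The new branch returns an empty string when the target length is negative, matching Spark's own padding semantics. As a minimal standalone check (plain Scala against a local SparkSession; an illustration only, not part of this PR's diff):

import org.apache.spark.sql.SparkSession

object NegativePadLengthCheck extends App {
  val spark = SparkSession.builder().master("local[1]").appName("pad-check").getOrCreate()
  // Spark yields an empty string (not null, and no error) for a negative pad length.
  spark.sql("SELECT rpad('abc', -1, 'x') AS r, lpad('abc', -1, 'x') AS l").show()
  spark.stop()
}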
35 changes: 20 additions & 15 deletions spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -162,6 +162,16 @@ object CometRLike extends CometExpressionSerde[RLike] {
 
 object CometStringRPad extends CometExpressionSerde[StringRPad] {
 
+  override def getSupportLevel(expr: StringRPad): SupportLevel = {
+    if (expr.str.isInstanceOf[Literal]) {
+      return Unsupported(Some("Scalar values are not supported for the str argument"))
+    }
+    if (!expr.pad.isInstanceOf[Literal]) {
+      return Unsupported(Some("Only scalar values are supported for the pad argument"))
+    }
+    Compatible()
+  }
+
   override def convert(
       expr: StringRPad,
       inputs: Seq[Attribute],
@@ -177,21 +187,16 @@ object CometStringRPad extends CometExpressionSerde[StringRPad] {
 
 object CometStringLPad extends CometExpressionSerde[StringLPad] {
 
-  /**
-   * Convert a Spark expression into a protocol buffer representation that can be passed into
-   * native code.
-   *
-   * @param expr
-   *   The Spark expression.
-   * @param inputs
-   *   The input attributes.
-   * @param binding
-   *   Whether the attributes are bound (this is only relevant in aggregate expressions).
-   * @return
-   *   Protocol buffer representation, or None if the expression could not be converted. In this
-   *   case it is expected that the input expression will have been tagged with reasons why it
-   *   could not be converted.
-   */
+  override def getSupportLevel(expr: StringLPad): SupportLevel = {
+    if (expr.str.isInstanceOf[Literal]) {
+      return Unsupported(Some("Scalar values are not supported for the str argument"))
+    }
+    if (!expr.pad.isInstanceOf[Literal]) {
+      return Unsupported(Some("Only scalar values are supported for the pad argument"))
+    }
+    Compatible()
+  }
+
   override def convert(
       expr: StringLPad,
       inputs: Seq[Attribute],

[Review thread on the line `if (!expr.pad.isInstanceOf[Literal]) {`]

Contributor: I don't know if we'll ever hit this. As far as I can see (in functions.lpad), Spark expects the pad argument to be a literal as well.

Member (Author): Spark doesn't require pad to be a literal:

scala> spark.sql("select a, lpad('foo', 6, a) from t1").show
+---+---------------+
|  a|lpad(foo, 6, a)|
+---+---------------+
|  $|         $$$foo|
|  @|         @@@foo|
+---+---------------+

Member (Author): case class StringRPad(str: Expression, len: Expression, pad: Expression = Literal(" "))

Member (Author): I suppose this is a good example of the benefit of fuzz testing (which is how this issue was discovered). The fuzzer will generate test cases that most developers would not consider. It does seem unlikely that anyone would want to use a column for the pad value, but I suppose it is possible that someone may have that requirement.
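Both overrides follow the same gate-then-convert pattern. As a rough sketch of how a SupportLevel result feeds into conversion (stand-in types for illustration; Comet's actual SupportLevel hierarchy and serde dispatch are defined elsewhere in the framework):

// Stand-in types, mirroring the names used above but simplified.
sealed trait SupportLevel
case class Compatible() extends SupportLevel
case class Unsupported(reason: Option[String]) extends SupportLevel

// Hypothetical gate: convert() is only attempted when the serde reports
// Compatible; otherwise the reason becomes the recorded fallback message.
def gateConversion[A](level: SupportLevel)(convert: => Option[A]): Option[A] =
  level match {
    case Unsupported(reason) =>
      println(s"Falling back to Spark: ${reason.getOrElse("unsupported expression")}")
      None
    case _ => convert
  }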
@@ -194,11 +194,13 @@ object FuzzDataGenerator {
           case 1 => r.nextInt().toByte.toString
           case 2 => r.nextLong().toString
           case 3 => r.nextDouble().toString
-          case 4 => RandomStringUtils.randomAlphabetic(8)
+          case 4 => RandomStringUtils.randomAlphabetic(options.maxStringLength)
           case 5 =>
             // use a constant value to trigger dictionary encoding
             "dict_encode_me!"
-          case _ => r.nextString(8)
+          case 6 if options.customStrings.nonEmpty =>
+            randomChoice(options.customStrings, r)
+          case _ => r.nextString(options.maxStringLength)
         }
       })
     case DataTypes.BinaryType =>
@@ -221,6 +223,11 @@
       case _ => throw new IllegalStateException(s"Cannot generate data for $dataType yet")
     }
   }
+
+  private def randomChoice[T](list: Seq[T], r: Random): T = {
+    list(r.nextInt(list.length))
+  }
+
 }
 
 object SchemaGenOptions {
@@ -250,4 +257,6 @@ case class SchemaGenOptions(
 case class DataGenOptions(
     allowNull: Boolean = true,
     generateNegativeZero: Boolean = true,
-    baseDate: Long = FuzzDataGenerator.defaultBaseDate)
+    baseDate: Long = FuzzDataGenerator.defaultBaseDate,
+    customStrings: Seq[String] = Seq.empty,
+    maxStringLength: Int = 8)
62 changes: 0 additions & 62 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -414,41 +414,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}
test("Verify rpad expr support for second arg instead of just literal") {
val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
withParquetTable(data, "t1") {
val res = sql("select rpad(_1,_2) , rpad(_1,2) from t1 order by _1")
checkSparkAnswerAndOperator(res)
}
}

test("RPAD with character support other than default space") {
val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
withParquetTable(data, "t1") {
val res = sql(
""" select rpad(_1,_2,'?'), rpad(_1,_2,'??') , rpad(_1,2, '??'), hex(rpad(unhex('aabb'), 5)),
rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
checkSparkAnswerAndOperator(res)
}
}

test("test lpad expression support") {
val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
withParquetTable(data, "t1") {
val res = sql("select lpad(_1,_2) , lpad(_1,2) from t1 order by _1")
checkSparkAnswerAndOperator(res)
}
}

test("LPAD with character support other than default space") {
val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
withParquetTable(data, "t1") {
val res = sql(
""" select lpad(_1,_2,'?'), lpad(_1,_2,'??') , lpad(_1,2, '??'), hex(lpad(unhex('aabb'), 5)),
rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
checkSparkAnswerAndOperator(res)
}
}

test("dictionary arithmetic") {
// TODO: test ANSI mode
@@ -2292,33 +2257,6 @@
}
}

test("rpad") {
val table = "rpad"
val gen = new DataGenerator(new Random(42))
withTable(table) {
// generate some data
val dataChars = "abc123"
sql(s"create table $table(id int, name1 char(8), name2 varchar(8)) using parquet")
val testData = gen.generateStrings(100, dataChars, 6) ++ Seq(
"é", // unicode 'e\\u{301}'
"é" // unicode '\\u{e9}'
)
testData.zipWithIndex.foreach { x =>
sql(s"insert into $table values(${x._2}, '${x._1}', '${x._1}')")
}
// test 2-arg version
checkSparkAnswerAndOperator(
s"SELECT id, rpad(name1, 10), rpad(name2, 10) FROM $table ORDER BY id")
// test 3-arg version
for (length <- Seq(2, 10)) {
checkSparkAnswerAndOperator(
s"SELECT id, name1, rpad(name1, $length, ' ') FROM $table ORDER BY id")
checkSparkAnswerAndOperator(
s"SELECT id, name2, rpad(name2, $length, ' ') FROM $table ORDER BY id")
}
}
}

test("isnan") {
Seq("true", "false").foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary) {
121 changes: 121 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
@@ -19,12 +19,133 @@

package org.apache.comet

import scala.util.Random

import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.spark.sql.{CometTestBase, DataFrame}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}

import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}

class CometStringExpressionSuite extends CometTestBase {

test("lpad string") {
testStringPadding("lpad")
}

test("rpad string") {
testStringPadding("rpad")
}

test("lpad binary") {
testBinaryPadding("lpad")
}

test("rpad binary") {
testBinaryPadding("rpad")
}

private def testStringPadding(expr: String): Unit = {
val r = new Random(42)
val schema = StructType(
Seq(
StructField("str", DataTypes.StringType, nullable = true),
StructField("len", DataTypes.IntegerType, nullable = true),
StructField("pad", DataTypes.StringType, nullable = true)))
// scalastyle:off
val edgeCases = Seq(
"é", // unicode 'e\\u{301}'
"é", // unicode '\\u{e9}'
"తెలుగు")
[Review thread]

Contributor: Out of curiosity, what makes this an edge case?

Member (Author): The first two were added in #772 to make sure Comet was consistent with Spark even though Rust and Java have different ways of representing unicode and graphemes.

// scalastyle:on
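(For context on the thread above: a quick standalone check in plain Scala, not part of this PR, showing why the two spellings of "é" are distinct test inputs.)

val combining = "e\u0301"  // 'e' followed by a combining acute accent
val precomposed = "\u00e9" // single precomposed 'e' with acute
println(combining.length)         // 2 UTF-16 code units
println(precomposed.length)       // 1 UTF-16 code unit
println(combining == precomposed) // false: String equality compares code units, not normalized forms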
val df = FuzzDataGenerator.generateDataFrame(
r,
spark,
schema,
1000,
DataGenOptions(maxStringLength = 6, customStrings = edgeCases))
df.createOrReplaceTempView("t1")

// test all combinations of scalar and array arguments
for (str <- Seq("'hello'", "str")) {
[Review thread]

Contributor (@hsiang-c, Oct 23, 2025): The Spark doc says it also supports binary string input, e.g. unhex('aabb').

Member (Author): Good catch, thanks. That opens up another set of issues! 😭

Member (Author): I added separate tests for binary inputs.

for (len <- Seq("6", "-6", "0", "len % 10")) {
[Review comment] Contributor: 👍
for (pad <- Seq(Some("'x'"), Some("'zzz'"), Some("pad"), None)) {
val sql = pad match {
case Some(p) =>
// 3 args
s"SELECT $str, $len, $expr($str, $len, $p) FROM t1 ORDER BY str, len, pad"
case _ =>
// 2 args (default pad of ' ')
s"SELECT $str, $len, $expr($str, $len) FROM t1 ORDER BY str, len, pad"
}
val isLiteralStr = str == "'hello'"
val isLiteralLen = !len.contains("len")
val isLiteralPad = !pad.contains("pad")
if (isLiteralStr && isLiteralLen && isLiteralPad) {
// all arguments are literal, so Spark constant folding will kick in
// and pad function will not be evaluated by Comet
checkSparkAnswer(sql)
} else if (isLiteralStr) {
checkSparkAnswerAndFallbackReason(
sql,
"Scalar values are not supported for the str argument")
} else if (!isLiteralPad) {
checkSparkAnswerAndFallbackReason(
sql,
"Only scalar values are supported for the pad argument")
} else {
checkSparkAnswerAndOperator(sql)
}
}
}
}
}

private def testBinaryPadding(expr: String): Unit = {
val r = new Random(42)
val schema = StructType(
Seq(
StructField("str", DataTypes.BinaryType, nullable = true),
StructField("len", DataTypes.IntegerType, nullable = true),
StructField("pad", DataTypes.BinaryType, nullable = true)))
val df = FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions())
df.createOrReplaceTempView("t1")

// test all combinations of scalar and array arguments
for (str <- Seq("unhex('DDEEFF')", "str")) {
// Spark does not support negative length for lpad/rpad with binary input and Comet does
// not support abs yet, so use `10 + len % 10` to avoid negative length
for (len <- Seq("6", "0", "10 + len % 10")) {
for (pad <- Seq(Some("unhex('CAFE')"), Some("pad"), None)) {

val sql = pad match {
case Some(p) =>
// 3 args
s"SELECT $str, $len, $expr($str, $len, $p) FROM t1 ORDER BY str, len, pad"
case _ =>
// 2 args (default pad; a zero byte for binary input)
s"SELECT $str, $len, $expr($str, $len) FROM t1 ORDER BY str, len, pad"
}

val isLiteralStr = str != "str"
val isLiteralLen = !len.contains("len")
val isLiteralPad = !pad.contains("pad")

if (isLiteralStr && isLiteralLen && isLiteralPad) {
// all arguments are literal, so Spark constant folding will kick in
// and pad function will not be evaluated by Comet
checkSparkAnswer(sql)
} else {
// Comet will fall back to Spark because the plan contains a staticinvoke instruction
// which is not supported
checkSparkAnswerAndFallbackReason(sql, "staticinvoke is not supported")
}
}
}
}
}
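As a standalone illustration of the binary variants exercised above (not part of this PR; per the Spark docs, omitting pad for a byte sequence pads with zero bytes):

import org.apache.spark.sql.SparkSession

object BinaryPadCheck extends App {
  val spark = SparkSession.builder().master("local[1]").appName("binary-pad").getOrCreate()
  // hex() makes the padded byte sequences readable.
  spark.sql(
    "SELECT hex(rpad(unhex('AABB'), 5)) AS r, " +
      "hex(lpad(unhex('AABB'), 5, unhex('CAFE'))) AS l").show()
  // expected per the docs: r = AABB000000, l = CAFECAAABB
  spark.stop()
}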

test("Various String scalar functions") {
val table = "names"
withTable(table) {
Expand Down
7 changes: 7 additions & 0 deletions spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -166,6 +166,13 @@ abstract class CometTestBase
(sparkPlan, dfComet.queryExecution.executedPlan)
}

/** Check for the correct results as well as the expected fallback reason */
def checkSparkAnswerAndFallbackReason(sql: String, fallbackReason: String): Unit = {
val (_, cometPlan) = checkSparkAnswer(sql)
val explain = new ExtendedExplainInfo().generateVerboseExtendedInfo(cometPlan)
assert(explain.contains(fallbackReason))
}

protected def checkSparkAnswerAndOperator(query: String, excludedClasses: Class[_]*): Unit = {
checkSparkAnswerAndOperator(sql(query), excludedClasses: _*)
}
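For reference, the new helper is exercised by the suite above along these lines (a sketch of a call from a test extending CometTestBase):

// Assert both that Comet's result matches Spark and that the expected
// fallback reason appears in the extended explain output.
checkSparkAnswerAndFallbackReason(
  "SELECT str, len, lpad('hello', len, 'x') FROM t1 ORDER BY str, len, pad",
  "Scalar values are not supported for the str argument")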