Commit b09002f

address_review_comments

1 parent b65602b commit b09002f

1 file changed: +95 −31 lines

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 95 additions & 31 deletions
@@ -575,8 +575,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   // CAST from StringType

   test("cast StringType to BooleanType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     val testValues =
       (Seq("TRUE", "True", "true", "FALSE", "False", "false", "1", "0", "", null) ++
         gen.generateStrings(dataSize, "truefalseTRUEFALSEyesno10" + whitespaceChars, 8)).toDF("a")
@@ -617,35 +615,27 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   )

   test("cast StringType to ByteType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
     // fuzz test
     castTest(gen.generateStrings(dataSize, numericPattern, 4).toDF("a"), DataTypes.ByteType)
   }

   test("cast StringType to ShortType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
     // fuzz test
     castTest(gen.generateStrings(dataSize, numericPattern, 5).toDF("a"), DataTypes.ShortType)
   }

   test("cast StringType to IntegerType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
     // fuzz test
     castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.IntegerType)
   }

   test("cast StringType to LongType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
     // fuzz test
@@ -672,7 +662,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     // https://github.com/apache/datafusion-comet/issues/326
     castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
   }
-
   test("cast StringType to DoubleType (partial support)") {
     withSQLConf(
       CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
@@ -684,21 +673,88 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

+  // Kept (ignored) so that the `all cast combinations are covered` check still passes
   ignore("cast StringType to DecimalType(10,2)") {
-    // https://github.com/apache/datafusion-comet/issues/325
-    val values = gen.generateStrings(dataSize, numericPattern, 8).toDF("a")
-    castTest(values, DataTypes.createDecimalType(10, 2))
+    val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+    castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
   }

-  test("cast StringType to DecimalType(10,2) (partial support)") {
-    withSQLConf(
-      CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
-      SQLConf.ANSI_ENABLED.key -> "false") {
-      val values = gen
-        .generateStrings(dataSize, "0123456789.", 8)
-        .filter(_.exists(_.isDigit))
-        .toDF("a")
-      castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
+  test("cast StringType to DecimalType(10,2) (does not support fullwidth unicode digits)") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(2,2)") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(38,10) high precision") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(10,2) basic values") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = Seq(
+        "123.45",
+        "-67.89",
+        "-67.89",
+        "-67.895",
+        "67.895",
+        "0.001",
+        "999.99",
+        "123.456",
+        "123.45D",
+        ".5",
+        "5.",
+        "+123.45",
+        " 123.45 ",
+        "inf",
+        "",
+        "abc",
+        null).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to Decimal type scientific notation") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = Seq(
+        "1.23E-5",
+        "1.23e10",
+        "1.23E+10",
+        "-1.23e-5",
+        "1e5",
+        "1E-2",
+        "-1.5e3",
+        "1.23E0",
+        "0e0",
+        "1.23e",
+        "e5",
+        null).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = ansiEnabled))
     }
   }
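
Note on the hand-picked values in the "basic values" test above: Spark rounds HALF_UP when a cast narrows the scale, which is presumably why the list pairs "-67.895" with "67.895". A quick spark-shell check of that behavior (illustrative only, not part of this commit; `spark` is the shell-provided session):

import spark.implicits._
Seq("67.895", "-67.895").toDF("a")
  .selectExpr("CAST(a AS DECIMAL(10,2))")
  .show() // expect 67.90 and -67.90 under half-up rounding
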

@@ -707,8 +763,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }

   test("cast StringType to DateType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     val validDates = Seq(
       "262142-01-01",
       "262142-01-01 ",
@@ -1290,15 +1344,26 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         else cometException.getMessage
       // this branch should only check decimal-to-decimal casts that fail because the output precision/scale overflows
       if (df.schema("a").dataType.typeName.contains("decimal") && toType.typeName
-        .contains("decimal") && sparkMessage.contains("cannot be represented as")) {
+          .contains("decimal") && sparkMessage.contains("cannot be represented as")) {
         assert(cometMessage.contains("too large to store"))
       } else {
         if (CometSparkSessionExtensions.isSpark40Plus) {
           // for Spark 4 we expect the sparkException to carry the message
-          assert(
-            sparkException.getMessage
-              .replace(".WITH_SUGGESTION] ", "]")
-              .startsWith(cometMessage))
+          assert(sparkMessage.contains("SQLSTATE"))
+          if (sparkMessage.startsWith("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]")) {
+            assert(
+              sparkMessage.replace(".WITH_SUGGESTION] ", "]").startsWith(cometMessage))
+          } else if (cometMessage.startsWith("[CAST_INVALID_INPUT]") || cometMessage
+              .startsWith("[CAST_OVERFLOW]")) {
+            assert(
+              sparkMessage.startsWith(
+                cometMessage
+                  .replace(
+                    "If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.",
+                    "")))
+          } else {
+            assert(sparkMessage.startsWith(cometMessage))
+          }
         } else {
           // for Spark 3.4 we expect to reproduce the error message exactly
           assert(cometMessage == sparkMessage)
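
The Spark-4 branch above normalizes error messages before comparing them. Distilled into a standalone predicate (the names here are ours for illustration, not the suite's), the logic it implements is:

val ansiSuggestion =
  "If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error."

def cometMatchesSpark(cometMessage: String, sparkMessage: String): Boolean =
  if (sparkMessage.startsWith("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]"))
    // drop the .WITH_SUGGESTION suffix from Spark's error class before comparing
    sparkMessage.replace(".WITH_SUGGESTION] ", "]").startsWith(cometMessage)
  else if (cometMessage.startsWith("[CAST_INVALID_INPUT]") ||
    cometMessage.startsWith("[CAST_OVERFLOW]"))
    // Comet still appends the ANSI suggestion that Spark 4 no longer emits
    sparkMessage.startsWith(cometMessage.replace(ansiSuggestion, ""))
  else
    sparkMessage.startsWith(cometMessage)
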
@@ -1325,5 +1390,4 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     df.write.mode(SaveMode.Overwrite).parquet(filename)
     spark.read.parquet(filename)
   }
-
 }
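
For context on what the new `testAnsi` toggling exercises: with spark.sql.ansi.enabled set, an invalid string-to-decimal cast raises a [CAST_INVALID_INPUT] error, while with it unset the cast yields NULL. A minimal standalone sketch of that contract (all names here are illustrative and not part of the commit):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DataTypes
import scala.util.Try

object AnsiCastSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
  import spark.implicits._

  // a few of the same shapes the suite generates: valid, boundary, invalid, scientific
  val df = Seq("123.45", "999.99", "abc", "1.23E-5").toDF("a")

  Seq(true, false).foreach { ansiEnabled =>
    spark.conf.set("spark.sql.ansi.enabled", ansiEnabled)
    // ANSI on: "abc" raises [CAST_INVALID_INPUT]; ANSI off: it becomes NULL
    val result = Try(df.select($"a".cast(DataTypes.createDecimalType(10, 2))).collect().toSeq)
    println(s"ansi=$ansiEnabled -> $result")
  }
  spark.stop()
}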
