Skip to content

Commit 3062de2

Browse files
author
himadripal
committed
add more test cases for negative scale and higher precision
1 parent 471c2a7 commit 3062de2

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -896,17 +896,38 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
896896
}
897897

898898
test("cast between decimals with different precision and scale") {
899-
// cast between default Decimal(38, 18) to Decimal(9,1)
899+
// cast between default Decimal(38, 18) to Decimal(7,2)
900900
val values = Seq(BigDecimal("12345.6789"), BigDecimal("9876.5432"), BigDecimal("123.4567"))
901901
val df = withNulls(values).toDF("a")
902902
castTest(df, DataTypes.createDecimalType(7, 2))
903903
}
904904

905-
test("cast two between decimals with different precision and scale") {
905+
test("cast between decimals with lower precision and scale") {
906906
// cast between Decimal(10, 2) to Decimal(9,1)
907907
castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(9, 1))
908908
}
909909

910+
test("cast between decimals with higher precision than source") {
911+
// cast from Decimal(10, 2) to Decimal(10, 4) — same precision, higher scale (fewer integer digits)
912+
withSQLConf("spark.comet.explainFallback.enabled" -> "true") {
913+
castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4))
914+
}
915+
}
916+
917+
test("cast between decimals with negative scale") {
918+
// cast to a negative scale, e.g. DECIMAL(10,-4); Spark rejects this at parse time
919+
checkSparkMaybeThrows(
920+
spark.sql("select a, cast(a as DECIMAL(10,-4)) from t order by a")) match {
921+
case (expected, actual) =>
922+
assert(expected.contains("PARSE_SYNTAX_ERROR") === actual.contains("PARSE_SYNTAX_ERROR"))
923+
}
924+
}
925+
926+
test("cast between decimals with zero scale") {
927+
// cast from Decimal(10, 2) to Decimal(10, 0)
928+
castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 0))
929+
}
930+
910931
private def generateFloats(): DataFrame = {
911932
withNulls(gen.generateFloats(dataSize)).toDF("a")
912933
}

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,9 @@ abstract class CometTestBase
231231
df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
232232
var expected: Option[Throwable] = None
233233
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
234-
val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
235-
expected = Try(dfSpark.collect()).failed.toOption
234+
expected = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
236235
}
237-
val dfComet = Dataset.ofRows(spark, df.logicalPlan)
238-
val actual = Try(dfComet.collect()).failed.toOption
236+
val actual = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
239237
(expected, actual)
240238
}
241239

0 commit comments

Comments
 (0)