@@ -896,17 +896,38 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
896896 }
897897
898898 test(" cast between decimals with different precision and scale" ) {
899- // cast between default Decimal(38, 18) to Decimal(9,1 )
899+ // cast between default Decimal(38, 18) to Decimal(7,2 )
900900 val values = Seq (BigDecimal (" 12345.6789" ), BigDecimal (" 9876.5432" ), BigDecimal (" 123.4567" ))
901901 val df = withNulls(values).toDF(" a" )
902902 castTest(df, DataTypes .createDecimalType(7 , 2 ))
903903 }
904904
905- test(" cast two between decimals with different precision and scale" ) {
905+ test(" cast between decimals with lower precision and scale" ) {
906906 // cast between Decimal(10, 2) to Decimal(9,1)
907907 castTest(generateDecimalsPrecision10Scale2(), DataTypes .createDecimalType(9 , 1 ))
908908 }
909909
910+ test(" cast between decimals with higher precision than source" ) {
911+ // cast between Decimal(10, 2) to Decimal(10,4)
912+ withSQLConf(" spark.comet.explainFallback.enabled" -> " true" ) {
913+ castTest(generateDecimalsPrecision10Scale2(), DataTypes .createDecimalType(10 , 4 ))
914+ }
915+ }
916+
917+ test(" cast between decimals with negative precision" ) {
918+ // cast to negative scale
919+ checkSparkMaybeThrows(
920+ spark.sql(" select a, cast(a as DECIMAL(10,-4)) from t order by a" )) match {
921+ case (expected, actual) =>
922+ assert(expected.contains(" PARSE_SYNTAX_ERROR" ) === actual.contains(" PARSE_SYNTAX_ERROR" ))
923+ }
924+ }
925+
926+ test(" cast between decimals with zero precision" ) {
927+ // cast between Decimal(10, 2) to Decimal(10,4)
928+ castTest(generateDecimalsPrecision10Scale2(), DataTypes .createDecimalType(10 , 0 ))
929+ }
930+
910931 private def generateFloats (): DataFrame = {
911932 withNulls(gen.generateFloats(dataSize)).toDF(" a" )
912933 }
0 commit comments