@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
 
-import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
 import org.apache.comet.expressions.{CometCast, CometEvalMode}
 import org.apache.comet.rules.CometScanTypeChecker
 import org.apache.comet.serde.Compatible
@@ -575,8 +574,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   // CAST from StringType
 
   test("cast StringType to BooleanType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     val testValues =
       (Seq("TRUE", "True", "true", "FALSE", "False", "false", "1", "0", "", null) ++
         gen.generateStrings(dataSize, "truefalseTRUEFALSEyesno10" + whitespaceChars, 8)).toDF("a")
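
For context: `gen.generateStrings(dataSize, alphabet, maxLen)` comes from the suite's fuzz-data utilities, which this diff does not show. A minimal sketch of the idea, assuming the helper simply draws random-length strings from a fixed alphabet; the object name, seed, and body below are illustrative, not Comet's actual implementation:

```scala
import scala.util.Random

// Illustrative sketch only: Comet's real generator lives in its test
// utilities and may differ. This version draws n strings of random
// length (0 to maxLen) from the given alphabet, so casts are exercised
// with near-miss inputs like "tru", "10 ", or "" alongside valid ones.
object FuzzSketch {
  private val random = new Random(42) // fixed seed keeps failures reproducible

  def generateStrings(n: Int, alphabet: String, maxLen: Int): Seq[String] =
    (0 until n).map { _ =>
      val len = random.nextInt(maxLen + 1)
      (0 until len).map(_ => alphabet(random.nextInt(alphabet.length))).mkString
    }
}
```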
@@ -617,35 +614,27 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   )
 
   test("cast StringType to ByteType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
     // fuzz test
     castTest(gen.generateStrings(dataSize, numericPattern, 4).toDF("a"), DataTypes.ByteType)
   }
 
   test("cast StringType to ShortType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
     // fuzz test
     castTest(gen.generateStrings(dataSize, numericPattern, 5).toDF("a"), DataTypes.ShortType)
   }
 
   test("cast StringType to IntegerType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
     // fuzz test
     castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.IntegerType)
   }
 
   test("cast StringType to LongType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
     // fuzz test
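
All four integral-cast tests above lean on `castTest`, which is defined elsewhere in the suite. As a rough, hypothetical sketch of the pattern, assuming a helper that evaluates the cast with Comet disabled and then enabled and requires matching results or matching failures (`CastTestSketch` and its internals are illustrative only; the suite's real helper is more thorough):

```scala
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DataType

object CastTestSketch {
  // Illustrative only: run the same cast with Comet off and on, then
  // require both engines to agree on the rows or on the failure.
  def castTestSketch(input: DataFrame, toType: DataType): Unit = {
    val spark = input.sparkSession

    def run(cometEnabled: Boolean): Either[Throwable, Seq[Row]] = {
      spark.conf.set("spark.comet.enabled", cometEnabled.toString)
      try Right(input.select(col("a").cast(toType)).collect().toSeq)
      catch { case t: Throwable => Left(t) }
    }

    (run(cometEnabled = false), run(cometEnabled = true)) match {
      case (Right(expected), Right(actual)) => assert(expected == actual)
      case (Left(sparkErr), Left(cometErr)) =>
        assert(sparkErr.getMessage == cometErr.getMessage)
      case other => throw new AssertionError(s"engines disagreed: $other")
    }
  }
}
```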
@@ -707,8 +696,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("cast StringType to DateType") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
     val validDates = Seq(
       "262142-01-01",
       "262142-01-01 ",
@@ -1295,10 +1282,21 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     } else {
       if (CometSparkSessionExtensions.isSpark40Plus) {
         // for Spark 4 we expect the sparkException to carry the message
-        assert(
-          sparkException.getMessage
-            .replace(".WITH_SUGGESTION] ", "]")
-            .startsWith(cometMessage))
+        assert(sparkMessage.contains("SQLSTATE"))
+        if (sparkMessage.startsWith("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]")) {
+          assert(
+            sparkMessage.replace(".WITH_SUGGESTION] ", "]").startsWith(cometMessage))
+        } else if (cometMessage.startsWith("[CAST_INVALID_INPUT]") || cometMessage
+            .startsWith("[CAST_OVERFLOW]")) {
+          assert(
+            sparkMessage.startsWith(
+              cometMessage
+                .replace(
+                  "If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.",
+                  "")))
+        } else {
+          assert(sparkMessage.startsWith(cometMessage))
+        }
       } else {
         // for Spark 3.4 we expect to reproduce the error message exactly
         assert(cometMessage == sparkMessage)
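
The new Spark 4 branch is easier to follow restated as a standalone function. Spark 4 messages begin with an error class, sometimes a `.WITH_SUGGESTION` sub-class, and carry a SQLSTATE, so only a normalized prefix comparison is meaningful. The sketch below mirrors the assertions above (`MessageCheckSketch` is illustrative, not part of the suite):

```scala
object MessageCheckSketch {
  // The ANSI-mode hint that appears in Comet's error messages.
  private val AnsiHint =
    "If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error."

  // Illustrative restatement of the Spark 4 assertions above.
  def matches(cometMessage: String, sparkMessage: String): Boolean =
    if (sparkMessage.startsWith("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]")) {
      // fold the sub-class back into the bare error class before comparing
      sparkMessage.replace(".WITH_SUGGESTION] ", "]").startsWith(cometMessage)
    } else if (cometMessage.startsWith("[CAST_INVALID_INPUT]") ||
      cometMessage.startsWith("[CAST_OVERFLOW]")) {
      // strip the ANSI hint from Comet's message before the prefix check
      sparkMessage.startsWith(cometMessage.replace(AnsiHint, ""))
    } else {
      sparkMessage.startsWith(cometMessage)
    }
}
```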
@@ -1325,5 +1323,4 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     df.write.mode(SaveMode.Overwrite).parquet(filename)
     spark.read.parquet(filename)
   }
-
 }