Skip to content

Commit bbbe54a

Browse files
rednaxelafxcloud-fan
authored andcommitted
[SPARK-27199][SQL][FOLLOWUP] Fix bug in codegen templates in UnixTime and FromUnixTime
## What changes were proposed in this pull request? SPARK-27199 introduced the use of `ZoneId` instead of `TimeZone` in a few date/time expressions. There were 3 occurrences of `ctx.addReferenceObj("zoneId", zoneId)` in that PR, which had a bug because while the `java.time.ZoneId` base type is public, the actual concrete implementation classes are not public, so using the 2-arg version of `CodegenContext.addReferenceObj` would incorrectly generate code that reference non-public types (`java.time.ZoneRegion`, to be specific). The 3-arg version should be used, with the class name of the referenced object explicitly specified to the public base type. One of such occurrences was caught in testing in the main PR of SPARK-27199 (apache#24141), for `DateFormatClass`. But the other 2 occurrences slipped through because there were no test cases that covered them. Example of this bug in the current Apache Spark master, in a Spark Shell: ``` scala> Seq(("2016-04-08", "yyyy-MM-dd")).toDF("s", "f").repartition(1).selectExpr("to_unix_timestamp(s, f)").show ... java.lang.IllegalAccessError: tried to access class java.time.ZoneRegion from class org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1 ``` This PR fixes the codegen issues and adds the corresponding unit tests. ## How was this patch tested? Enhanced tests in `DateExpressionsSuite` for `to_unix_timestamp` and `from_unixtime`. Closes apache#24352 from rednaxelafx/fix-spark-27199. Authored-by: Kris Mok <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 94adffa commit bbbe54a

File tree

2 files changed

+40
-18
lines changed

2 files changed

+40
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti
541541

542542
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
543543
val tf = TimestampFormatter.getClass.getName.stripSuffix("$")
544-
val zid = ctx.addReferenceObj("zoneId", zoneId, "java.time.ZoneId")
544+
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
545545
val locale = ctx.addReferenceObj("locale", Locale.US)
546546
defineCodeGen(ctx, ev, (timestamp, format) => {
547547
s"""UTF8String.fromString($tf$$.MODULE$$.apply($format.toString(), $zid, $locale)
@@ -710,13 +710,13 @@ abstract class UnixTime
710710
}""")
711711
}
712712
case StringType =>
713-
val tz = ctx.addReferenceObj("zoneId", zoneId)
713+
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
714714
val locale = ctx.addReferenceObj("locale", Locale.US)
715715
val tf = TimestampFormatter.getClass.getName.stripSuffix("$")
716716
nullSafeCodeGen(ctx, ev, (string, format) => {
717717
s"""
718718
try {
719-
${ev.value} = $tf$$.MODULE$$.apply($format.toString(), $tz, $locale)
719+
${ev.value} = $tf$$.MODULE$$.apply($format.toString(), $zid, $locale)
720720
.parse($string.toString()) / $MICROS_PER_SECOND;
721721
} catch (java.lang.IllegalArgumentException e) {
722722
${ev.isNull} = true;
@@ -849,13 +849,13 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
849849
}""")
850850
}
851851
} else {
852-
val tz = ctx.addReferenceObj("zoneId", zoneId)
852+
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
853853
val locale = ctx.addReferenceObj("locale", Locale.US)
854854
val tf = TimestampFormatter.getClass.getName.stripSuffix("$")
855855
nullSafeCodeGen(ctx, ev, (seconds, f) => {
856856
s"""
857857
try {
858-
${ev.value} = UTF8String.fromString($tf.apply($f.toString(), $tz, $locale).
858+
${ev.value} = UTF8String.fromString($tf.apply($f.toString(), $zid, $locale).
859859
format($seconds * 1000000L));
860860
} catch (java.lang.IllegalArgumentException e) {
861861
${ev.isNull} = true;

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ import java.util.concurrent.TimeUnit._
2626

2727
import org.apache.spark.SparkFunSuite
2828
import org.apache.spark.sql.AnalysisException
29+
import org.apache.spark.sql.catalyst.InternalRow
2930
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
3031
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter}
3132
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
3233
import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
3334
import org.apache.spark.sql.internal.SQLConf
3435
import org.apache.spark.sql.types._
35-
import org.apache.spark.unsafe.types.CalendarInterval
36+
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
3637

3738
class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
3839

@@ -652,7 +653,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
652653
}
653654

654655
test("from_unixtime") {
655-
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
656+
val fmt1 = "yyyy-MM-dd HH:mm:ss"
657+
val sdf1 = new SimpleDateFormat(fmt1, Locale.US)
656658
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
657659
val sdf2 = new SimpleDateFormat(fmt2, Locale.US)
658660
for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) {
@@ -661,10 +663,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
661663
sdf2.setTimeZone(tz)
662664

663665
checkEvaluation(
664-
FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
666+
FromUnixTime(Literal(0L), Literal(fmt1), timeZoneId),
665667
sdf1.format(new Timestamp(0)))
666668
checkEvaluation(FromUnixTime(
667-
Literal(1000L), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
669+
Literal(1000L), Literal(fmt1), timeZoneId),
668670
sdf1.format(new Timestamp(1000000)))
669671
checkEvaluation(
670672
FromUnixTime(Literal(-1000L), Literal(fmt2), timeZoneId),
@@ -673,13 +675,22 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
673675
FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType), timeZoneId),
674676
null)
675677
checkEvaluation(
676-
FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
678+
FromUnixTime(Literal.create(null, LongType), Literal(fmt1), timeZoneId),
677679
null)
678680
checkEvaluation(
679681
FromUnixTime(Literal(1000L), Literal.create(null, StringType), timeZoneId),
680682
null)
681683
checkEvaluation(
682684
FromUnixTime(Literal(0L), Literal("not a valid format"), timeZoneId), null)
685+
686+
// The codegen path for non-literal input should also work
687+
checkEvaluation(
688+
expression = FromUnixTime(
689+
BoundReference(ordinal = 0, dataType = LongType, nullable = true),
690+
BoundReference(ordinal = 1, dataType = StringType, nullable = true),
691+
timeZoneId),
692+
expected = UTF8String.fromString(sdf1.format(new Timestamp(0))),
693+
inputRow = InternalRow(0L, UTF8String.fromString(fmt1)))
683694
}
684695
}
685696

@@ -739,7 +750,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
739750
}
740751

741752
test("to_unix_timestamp") {
742-
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
753+
val fmt1 = "yyyy-MM-dd HH:mm:ss"
754+
val sdf1 = new SimpleDateFormat(fmt1, Locale.US)
743755
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
744756
val sdf2 = new SimpleDateFormat(fmt2, Locale.US)
745757
val fmt3 = "yy-MM-dd"
@@ -754,15 +766,15 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
754766

755767
val date1 = Date.valueOf("2015-07-24")
756768
checkEvaluation(ToUnixTimestamp(
757-
Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), 0L)
769+
Literal(sdf1.format(new Timestamp(0))), Literal(fmt1), timeZoneId), 0L)
758770
checkEvaluation(ToUnixTimestamp(
759-
Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
771+
Literal(sdf1.format(new Timestamp(1000000))), Literal(fmt1), timeZoneId),
760772
1000L)
761773
checkEvaluation(ToUnixTimestamp(
762-
Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")),
774+
Literal(new Timestamp(1000000)), Literal(fmt1)),
763775
1000L)
764776
checkEvaluation(
765-
ToUnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
777+
ToUnixTimestamp(Literal(date1), Literal(fmt1), timeZoneId),
766778
MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz)))
767779
checkEvaluation(
768780
ToUnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId),
@@ -772,21 +784,31 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
772784
MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(
773785
DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz)))
774786
val t1 = ToUnixTimestamp(
775-
CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long]
787+
CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long]
776788
val t2 = ToUnixTimestamp(
777-
CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long]
789+
CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long]
778790
assert(t2 - t1 <= 1)
779791
checkEvaluation(ToUnixTimestamp(
780792
Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), null)
781793
checkEvaluation(
782794
ToUnixTimestamp(
783-
Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId),
795+
Literal.create(null, DateType), Literal(fmt1), timeZoneId),
784796
null)
785797
checkEvaluation(ToUnixTimestamp(
786798
Literal(date1), Literal.create(null, StringType), timeZoneId),
787799
MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz)))
788800
checkEvaluation(
789801
ToUnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null)
802+
803+
// The codegen path for non-literal input should also work
804+
checkEvaluation(
805+
expression = ToUnixTimestamp(
806+
BoundReference(ordinal = 0, dataType = StringType, nullable = true),
807+
BoundReference(ordinal = 1, dataType = StringType, nullable = true),
808+
timeZoneId),
809+
expected = 0L,
810+
inputRow = InternalRow(
811+
UTF8String.fromString(sdf1.format(new Timestamp(0))), UTF8String.fromString(fmt1)))
790812
}
791813
}
792814
}

0 commit comments

Comments
 (0)