Skip to content

Commit 54e058d

Browse files
MaxGekksrowen
authored andcommitted
[SPARK-28416][SQL] Use java.time API in timestampAddInterval
## What changes were proposed in this pull request? The `DateTimeUtils.timestampAddInterval` method was rewritten by using Java 8 time API. To add months and microseconds, I used the `plusMonths()` and `plus()` methods of `ZonedDateTime`. Also the signature of `timestampAddInterval()` was changed to accept an `ZoneId` instance instead of `TimeZone`. Using `ZoneId` allows to avoid the conversion `TimeZone` -> `ZoneId` on every invoke of `timestampAddInterval()`. ## How was this patch tested? By existing test suites `DateExpressionsSuite`, `TypeCoercionSuite` and `CollectionExpressionsSuite`. Closes apache#25173 from MaxGekk/timestamp-add-interval. Authored-by: Maxim Gekk <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 0c21404 commit 54e058d

File tree

4 files changed

+24
-25
lines changed

4 files changed

+24
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
*/
1717
package org.apache.spark.sql.catalyst.expressions
1818

19-
import java.util.{Comparator, TimeZone}
19+
import java.time.ZoneId
20+
import java.util.Comparator
2021

2122
import scala.collection.mutable
2223
import scala.reflect.ClassTag
@@ -2459,10 +2460,10 @@ case class Sequence(
24592460
new IntegralSequenceImpl(iType)(ct, iType.integral)
24602461

24612462
case TimestampType =>
2462-
new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone)
2463+
new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId)
24632464

24642465
case DateType =>
2465-
new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone)
2466+
new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId)
24662467
}
24672468

24682469
override def eval(input: InternalRow): Any = {
@@ -2603,7 +2604,7 @@ object Sequence {
26032604
}
26042605

26052606
private class TemporalSequenceImpl[T: ClassTag]
2606-
(dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone)
2607+
(dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
26072608
(implicit num: Integral[T]) extends SequenceImpl {
26082609

26092610
override val defaultStep: DefaultStep = new DefaultStep(
@@ -2642,7 +2643,7 @@ object Sequence {
26422643
while (t < exclusiveItem ^ stepSign < 0) {
26432644
arr(i) = fromLong(t / scale)
26442645
i += 1
2645-
t = timestampAddInterval(startMicros, i * stepMonths, i * stepMicros, timeZone)
2646+
t = timestampAddInterval(startMicros, i * stepMonths, i * stepMicros, zoneId)
26462647
}
26472648

26482649
// truncate array to the correct length
@@ -2668,7 +2669,7 @@ object Sequence {
26682669
val exclusiveItem = ctx.freshName("exclusiveItem")
26692670
val t = ctx.freshName("t")
26702671
val i = ctx.freshName("i")
2671-
val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, classOf[TimeZone].getName)
2672+
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
26722673

26732674
val sequenceLengthCode =
26742675
s"""
@@ -2701,7 +2702,7 @@ object Sequence {
27012702
| $arr[$i] = ($elemType) ($t / ${scale}L);
27022703
| $i += 1;
27032704
| $t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval(
2704-
| $startMicros, $i * $stepMonths, $i * $stepMicros, $genTimeZone);
2705+
| $startMicros, $i * $stepMonths, $i * $stepMicros, $zid);
27052706
| }
27062707
|
27072708
| if ($arr.length > $i) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -996,14 +996,14 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
996996
override def nullSafeEval(start: Any, interval: Any): Any = {
997997
val itvl = interval.asInstanceOf[CalendarInterval]
998998
DateTimeUtils.timestampAddInterval(
999-
start.asInstanceOf[Long], itvl.months, itvl.microseconds, timeZone)
999+
start.asInstanceOf[Long], itvl.months, itvl.microseconds, zoneId)
10001000
}
10011001

10021002
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1003-
val tz = ctx.addReferenceObj("timeZone", timeZone)
1003+
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
10041004
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
10051005
defineCodeGen(ctx, ev, (sd, i) => {
1006-
s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $tz)"""
1006+
s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $zid)"""
10071007
})
10081008
}
10091009
}
@@ -1111,14 +1111,14 @@ case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[S
11111111
override def nullSafeEval(start: Any, interval: Any): Any = {
11121112
val itvl = interval.asInstanceOf[CalendarInterval]
11131113
DateTimeUtils.timestampAddInterval(
1114-
start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds, timeZone)
1114+
start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds, zoneId)
11151115
}
11161116

11171117
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1118-
val tz = ctx.addReferenceObj("timeZone", timeZone)
1118+
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
11191119
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
11201120
defineCodeGen(ctx, ev, (sd, i) => {
1121-
s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $tz)"""
1121+
s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $zid)"""
11221122
})
11231123
}
11241124
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.util
1919

2020
import java.sql.{Date, Timestamp}
2121
import java.time._
22-
import java.time.Year.isLeap
23-
import java.time.temporal.IsoFields
22+
import java.time.temporal.{ChronoUnit, IsoFields}
2423
import java.util.{Locale, TimeZone}
2524
import java.util.concurrent.TimeUnit._
2625

@@ -521,12 +520,12 @@ object DateTimeUtils {
521520
start: SQLTimestamp,
522521
months: Int,
523522
microseconds: Long,
524-
timeZone: TimeZone): SQLTimestamp = {
525-
val days = millisToDays(MICROSECONDS.toMillis(start), timeZone)
526-
val newDays = dateAddMonths(days, months)
527-
start +
528-
MILLISECONDS.toMicros(daysToMillis(newDays, timeZone) - daysToMillis(days, timeZone)) +
529-
microseconds
523+
zoneId: ZoneId): SQLTimestamp = {
524+
val resultTimestamp = microsToInstant(start)
525+
.atZone(zoneId)
526+
.plusMonths(months)
527+
.plus(microseconds, ChronoUnit.MICROS)
528+
instantToMicros(resultTimestamp.toInstant)
530529
}
531530

532531
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ import org.apache.spark.unsafe.types.UTF8String
3131
class DateTimeUtilsSuite extends SparkFunSuite {
3232

3333
val TimeZonePST = TimeZone.getTimeZone("PST")
34-
private def defaultTz = DateTimeUtils.defaultTimeZone()
3534
private def defaultZoneId = ZoneId.systemDefault()
3635

3736
test("nanoseconds truncation") {
@@ -366,13 +365,13 @@ class DateTimeUtilsSuite extends SparkFunSuite {
366365
test("timestamp add months") {
367366
val ts1 = date(1997, 2, 28, 10, 30, 0)
368367
val ts2 = date(2000, 2, 28, 10, 30, 0, 123000)
369-
assert(timestampAddInterval(ts1, 36, 123000, defaultTz) === ts2)
368+
assert(timestampAddInterval(ts1, 36, 123000, defaultZoneId) === ts2)
370369

371370
val ts3 = date(1997, 2, 27, 16, 0, 0, 0, TimeZonePST)
372371
val ts4 = date(2000, 2, 27, 16, 0, 0, 123000, TimeZonePST)
373372
val ts5 = date(2000, 2, 28, 0, 0, 0, 123000, TimeZoneGMT)
374-
assert(timestampAddInterval(ts3, 36, 123000, TimeZonePST) === ts4)
375-
assert(timestampAddInterval(ts3, 36, 123000, TimeZoneGMT) === ts5)
373+
assert(timestampAddInterval(ts3, 36, 123000, TimeZonePST.toZoneId) === ts4)
374+
assert(timestampAddInterval(ts3, 36, 123000, TimeZoneGMT.toZoneId) === ts5)
376375
}
377376

378377
test("monthsBetween") {

0 commit comments

Comments
 (0)