Skip to content

Commit e57ca0f

Browse files
committed
amend date, datetime, timestamp limits
1 parent 635d803 commit e57ca0f

File tree

5 files changed

+100
-85
lines changed

5 files changed

+100
-85
lines changed

sql/expression/arithmetic.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -521,14 +521,14 @@ func plus(lval, rval interface{}) (interface{}, error) {
521521
case time.Time:
522522
switch r := rval.(type) {
523523
case *TimeDelta:
524-
return types.ValidateTime(r.Add(l)), nil
524+
return types.ValidateDatetime(r.Add(l)), nil
525525
case time.Time:
526526
return l.Unix() + r.Unix(), nil
527527
}
528528
case *TimeDelta:
529529
switch r := rval.(type) {
530530
case time.Time:
531-
return types.ValidateTime(l.Add(r)), nil
531+
return types.ValidateDatetime(l.Add(r)), nil
532532
}
533533
}
534534

@@ -595,7 +595,7 @@ func minus(lval, rval interface{}) (interface{}, error) {
595595
case time.Time:
596596
switch r := rval.(type) {
597597
case *TimeDelta:
598-
return types.ValidateTime(r.Sub(l)), nil
598+
return types.ValidateDatetime(r.Sub(l)), nil
599599
case time.Time:
600600
return l.Unix() - r.Unix(), nil
601601
}

sql/expression/function/time_math.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ func (d *DateAdd) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
239239
}
240240

241241
// return appropriate type
242-
res := types.ValidateTime(delta.Add(dateVal.(time.Time)))
242+
res := types.ValidateDatetime(delta.Add(dateVal.(time.Time)))
243243
if res == nil {
244244
return nil, nil
245245
}
@@ -387,7 +387,7 @@ func (d *DateSub) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
387387
}
388388

389389
// return appropriate type
390-
res := types.ValidateTime(delta.Sub(dateVal.(time.Time)))
390+
res := types.ValidateDatetime(delta.Sub(dateVal.(time.Time)))
391391
if res == nil {
392392
return nil, nil
393393
}

sql/expression/interval.go

Lines changed: 16 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -238,71 +238,22 @@ const (
238238
)
239239

240240
func (td TimeDelta) apply(t time.Time, sign int64) time.Time {
241-
y := int64(t.Year())
242-
mo := int64(t.Month())
243-
d := t.Day()
244-
h := t.Hour()
245-
min := t.Minute()
246-
s := t.Second()
247-
ns := t.Nanosecond()
248-
249-
if td.Years != 0 {
250-
y += td.Years * sign
241+
// add years, months, days using AddDate (handles normalization)
242+
t = t.AddDate(
243+
int(td.Years*sign),
244+
int(td.Months*sign),
245+
int(td.Days*sign),
246+
)
247+
248+
// add hours, minutes, seconds, microseconds
249+
duration := time.Duration(td.Hours*sign)*time.Hour +
250+
time.Duration(td.Minutes*sign)*time.Minute +
251+
time.Duration(td.Seconds*sign)*time.Second +
252+
time.Duration(td.Microseconds*sign)*time.Microsecond
253+
254+
if duration != 0 {
255+
t = t.Add(duration)
251256
}
252257

253-
if td.Months != 0 {
254-
m := mo + td.Months*sign
255-
if m < 1 {
256-
mo = 12 + (m % 12)
257-
y += m/12 - 1
258-
} else if m > 12 {
259-
mo = m % 12
260-
y += m / 12
261-
} else {
262-
mo = m
263-
}
264-
265-
// Due to the operations done before, month may be zero, which means it's
266-
// december.
267-
if mo == 0 {
268-
mo = 12
269-
}
270-
}
271-
272-
if days := daysInMonth(time.Month(mo), int(y)); days < d {
273-
d = days
274-
}
275-
276-
date := time.Date(int(y), time.Month(mo), d, h, min, s, ns, t.Location())
277-
278-
if td.Days != 0 {
279-
date = date.Add(time.Duration(td.Days) * day * time.Duration(sign))
280-
}
281-
282-
if td.Hours != 0 {
283-
date = date.Add(time.Duration(td.Hours) * time.Hour * time.Duration(sign))
284-
}
285-
286-
if td.Minutes != 0 {
287-
date = date.Add(time.Duration(td.Minutes) * time.Minute * time.Duration(sign))
288-
}
289-
290-
if td.Seconds != 0 {
291-
date = date.Add(time.Duration(td.Seconds) * time.Second * time.Duration(sign))
292-
}
293-
294-
if td.Microseconds != 0 {
295-
date = date.Add(time.Duration(td.Microseconds) * time.Microsecond * time.Duration(sign))
296-
}
297-
298-
return date
299-
}
300-
301-
func daysInMonth(month time.Month, year int) int {
302-
if month == time.December {
303-
return 31
304-
}
305-
306-
date := time.Date(year, month+time.Month(1), 1, 0, 0, 0, 0, time.Local)
307-
return date.Add(-1 * day).Day()
258+
return t
308259
}

sql/types/datetime.go

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ package types
1717
import (
1818
"context"
1919
"fmt"
20-
"math"
2120
"reflect"
2221
"time"
2322

@@ -39,17 +38,21 @@ var (
3938

4039
ErrConvertingToTimeOutOfRange = errors.NewKind("value %q is outside of %v range")
4140

42-
// datetimeTypeMaxDatetime is the maximum representable Datetime/Date value.
43-
datetimeTypeMaxDatetime = time.Date(9999, 12, 31, 23, 59, 59, 999999000, time.UTC)
41+
// datetimeTypeMaxDatetime is the maximum representable Datetime/Date value. MYSQL: 9999-12-31 23:59:59.499999 (microseconds)
42+
datetimeTypeMaxDatetime = time.Date(9999, 12, 31, 23, 59, 59, 499999000, time.UTC)
4443

45-
// datetimeTypeMinDatetime is the minimum representable Datetime/Date value.
46-
datetimeTypeMinDatetime = time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC)
44+
// datetimeTypeMinDatetime is the minimum representable Datetime/Date value. MYSQL: 1000-01-01 00:00:00.000000 (microseconds)
45+
datetimeTypeMinDatetime = time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC)
4746

48-
// datetimeTypeMaxTimestamp is the maximum representable Timestamp value, which is the maximum 32-bit integer as a Unix time.
49-
datetimeTypeMaxTimestamp = time.Unix(math.MaxInt32, 999999000)
47+
// datetimeTypeMaxTimestamp is the maximum representable Timestamp value, MYSQL: 2038-01-19 03:14:07.999999 (microseconds)
48+
datetimeTypeMaxTimestamp = time.Date(2038, 1, 19, 3, 14, 7, 999999000, time.UTC)
5049

51-
// datetimeTypeMinTimestamp is the minimum representable Timestamp value, which is one second past the epoch.
52-
datetimeTypeMinTimestamp = time.Unix(1, 0)
50+
// datetimeTypeMinTimestamp is the minimum representable Timestamp value, MYSQL: 1970-01-01 00:00:01.000000 (microseconds)
51+
datetimeTypeMinTimestamp = time.Date(1970, 1, 1, 0, 0, 1, 0, time.UTC)
52+
53+
datetimeTypeMaxDate = time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC)
54+
55+
datetimeTypeMinDate = time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC)
5356

5457
DateOnlyLayouts = []string{
5558
"20060102",
@@ -206,15 +209,15 @@ func ConvertToTime(ctx context.Context, v interface{}, t datetimeType) (time.Tim
206209

207210
switch t.baseType {
208211
case sqltypes.Date:
209-
if res.Year() < 0 || res.Year() > 9999 {
212+
if validated := ValidateDate(res); validated == nil {
210213
return time.Time{}, ErrConvertingToTimeOutOfRange.New(res.Format(sql.DateLayout), t.String())
211214
}
212215
case sqltypes.Datetime:
213-
if res.Year() < 0 || res.Year() > 9999 {
216+
if validated := ValidateDatetime(res); validated == nil {
214217
return time.Time{}, ErrConvertingToTimeOutOfRange.New(res.Format(sql.TimestampDatetimeLayout), t.String())
215218
}
216219
case sqltypes.Timestamp:
217-
if res.Before(datetimeTypeMinTimestamp) || res.After(datetimeTypeMaxTimestamp) {
220+
if validated := ValidateTimestamp(res); validated == nil {
218221
return time.Time{}, ErrConvertingToTimeOutOfRange.New(res.Format(sql.TimestampDatetimeLayout), t.String())
219222
}
220223
}
@@ -470,10 +473,28 @@ func (t datetimeType) MinimumTime() time.Time {
470473
return datetimeTypeMinDatetime
471474
}
472475

473-
// ValidateTime receives a time and returns either that time or nil if it's
476+
// validateDatetime receives a time and returns either that time or nil if it's
474477
// not a valid time.
475-
func ValidateTime(t time.Time) interface{} {
476-
if t.After(time.Date(9999, time.December, 31, 23, 59, 59, 999999999, time.UTC)) {
478+
func ValidateDatetime(t time.Time) interface{} {
479+
if t.Before(datetimeTypeMinDatetime) || t.After(datetimeTypeMaxDatetime) {
480+
return nil
481+
}
482+
return t
483+
}
484+
485+
// ValidateTimestamp receives a time and returns either that time or nil if it's
486+
// not a valid timestamp.
487+
func ValidateTimestamp(t time.Time) interface{} {
488+
if t.Before(datetimeTypeMinTimestamp) || t.After(datetimeTypeMaxTimestamp) {
489+
return nil
490+
}
491+
return t
492+
}
493+
494+
// validateDate receives a time and returns either that time or nil if it's
495+
// not a valid date.
496+
func ValidateDate(t time.Time) interface{} {
497+
if t.Before(datetimeTypeMinDatetime) || t.After(datetimeTypeMaxDate) {
477498
return nil
478499
}
479500
return t

sql/types/datetime_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,3 +405,46 @@ func TestDatetimeZero(t *testing.T) {
405405
_, ok = MustCreateDatetimeType(sqltypes.Timestamp, 0).Zero().(time.Time)
406406
require.True(t, ok)
407407
}
408+
409+
func TestDatetimeOverflowUnderflow(t *testing.T) {
410+
ctx := sql.NewEmptyContext()
411+
tests := []struct {
412+
typ sql.DatetimeType
413+
val interface{}
414+
expectError bool
415+
}{
416+
// Date underflow
417+
{Date, "0999-12-31", true},
418+
// Date overflow
419+
{Date, "10000-01-01", true},
420+
// Datetime underflow
421+
{Datetime, "0999-12-31 23:59:59", true},
422+
// Datetime overflow
423+
{Datetime, "10000-01-01 00:00:00", true},
424+
// Timestamp underflow
425+
{Timestamp, "1969-12-31 23:59:59", true},
426+
// Timestamp overflow
427+
{Timestamp, "2038-01-19 03:14:08", true},
428+
// Valid edge cases
429+
{Date, Date.MinimumTime().Format("2006-01-02"), false},
430+
{Date, Date.MaximumTime().Format("2006-01-02"), false},
431+
{Datetime, Datetime.MinimumTime().Format("2006-01-02 15:04:05"), false},
432+
{Datetime, Datetime.MaximumTime().Format("2006-01-02 15:04:05"), false},
433+
{Timestamp, Timestamp.MinimumTime().Format("2006-01-02 15:04:05"), false},
434+
{Timestamp, Timestamp.MaximumTime().Format("2006-01-02 15:04:05"), false},
435+
}
436+
437+
for _, tt := range tests {
438+
t.Run(tt.typ.String()+"_"+tt.val.(string), func(t *testing.T) {
439+
_, inRange, err := tt.typ.Convert(ctx, tt.val)
440+
441+
if tt.expectError {
442+
require.True(t, err != nil || inRange == sql.OutOfRange,
443+
"expected error or out-of-range but got neither; err: %v, inRange: %v", err, inRange)
444+
} else {
445+
require.NoError(t, err)
446+
require.Equal(t, sql.InRange, inRange)
447+
}
448+
})
449+
}
450+
}

0 commit comments

Comments
 (0)