Skip to content

Commit 2254e0d

Browse files
authored
Merge pull request #3118 from dolthub/angela/timezone
Validate time zone before setting it
2 parents 9929479 + 210a23c commit 2254e0d

File tree

6 files changed

+37
-10
lines changed

6 files changed

+37
-10
lines changed

enginetest/queries/time_queries.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,8 @@ var TimeQueryTests = []ScriptTest{
5454
Expected: []sql.Row{{time.Date(2025, time.July, 23, 6, 43, 21, 0, time.UTC)}},
5555
},
5656
{
57-
// https://github.com/dolthub/dolt/issues/9559
58-
Skip: true,
59-
Query: "set time_zone='invalid time zone",
60-
// update to actual error or error string
61-
ExpectedErrStr: "Unknown of incorrect time zone: 'invalid time zone'",
57+
Query: "set time_zone='invalid time zone'",
58+
ExpectedErr: sql.ErrInvalidTimeZone,
6259
},
6360
},
6461
},

sql/errors.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ var (
4242
// the execution tree.
4343
ErrInvalidType = errors.NewKind("invalid type: %s")
4444

45+
// ErrInvalidTimeZone is thrown when an invalid time zone is found
46+
ErrInvalidTimeZone = errors.NewKind("Unknown or incorrect time zone: %s")
47+
4548
// ErrTableAlreadyExists is thrown when someone tries to create a
4649
// table with a name of an existing one
4750
ErrTableAlreadyExists = errors.NewKind("table with name %s already exists")

sql/events.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,13 +401,13 @@ func GetTimeValueFromStringInput(field, t string) (time.Time, error) {
401401
datetimeVal := fmt.Sprintf("%4d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second)
402402
tVal, err := time.Parse(EventDateSpaceTimeFormat, datetimeVal)
403403
if err != nil {
404-
return time.Time{}, fmt.Errorf("invalid time zone: %s", sessTz)
404+
return time.Time{}, ErrInvalidTimeZone.New(sessTz)
405405
}
406406

407407
// convert the time value to the session timezone for display and storage
408408
tVal, ok = ConvertTimeZone(tVal, inputTz, sessTz)
409409
if !ok {
410-
return time.Time{}, fmt.Errorf("invalid time zone: %s", sessTz)
410+
return time.Time{}, ErrInvalidTimeZone.New(sessTz)
411411
}
412412
return tVal, nil
413413
} else {

sql/expression/function/time.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,7 @@ func (n *Now) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
10171017
if n.prec == nil {
10181018
t, ok := sql.ConvertTimeZone(currentTime, sql.SystemTimezoneOffset(), sessionTimeZone)
10191019
if !ok {
1020-
return nil, fmt.Errorf("invalid time zone: %s", sessionTimeZone)
1020+
return nil, sql.ErrInvalidTimeZone.New(sessionTimeZone)
10211021
}
10221022
tt := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), 0, time.UTC)
10231023
return tt, nil
@@ -1057,7 +1057,7 @@ func (n *Now) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
10571057
// Get the timestamp
10581058
t, ok := sql.ConvertTimeZone(currentTime, sql.SystemTimezoneOffset(), sessionTimeZone)
10591059
if !ok {
1060-
return nil, fmt.Errorf("invalid time zone: %s", sessionTimeZone)
1060+
return nil, sql.ErrInvalidTimeZone.New(sessionTimeZone)
10611061
}
10621062

10631063
// Calculate precision

sql/rowexec/rel_iters.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"errors"
1919
"io"
2020
"strings"
21+
"time"
2122

2223
"github.com/dolthub/go-mysql-server/sql"
2324
"github.com/dolthub/go-mysql-server/sql/expression"
@@ -392,6 +393,10 @@ func setSystemVar(ctx *sql.Context, sysVar *expression.SystemVar, right sql.Expr
392393
if err != nil {
393394
return err
394395
}
396+
err = validateSystemVariableValue(sysVar.Name, val)
397+
if err != nil {
398+
return err
399+
}
395400
err = sysVar.Scope.SetValue(ctx, sysVar.Name, val)
396401
if err != nil {
397402
return err
@@ -464,6 +469,24 @@ func setSystemVar(ctx *sql.Context, sysVar *expression.SystemVar, right sql.Expr
464469
return nil
465470
}
466471

472+
func validateSystemVariableValue(sysVarName string, val interface{}) error {
473+
switch strings.ToLower(sysVarName) {
474+
case "time_zone":
475+
valStr, ok := val.(string)
476+
if !ok {
477+
return sql.ErrInvalidTimeZone.New(val)
478+
}
479+
_, err := time.LoadLocation(valStr)
480+
if err == nil {
481+
return nil
482+
}
483+
if !sql.ValidTimeOffset(valStr) {
484+
return sql.ErrInvalidTimeZone.New(valStr)
485+
}
486+
}
487+
return nil
488+
}
489+
467490
// Applies the update expressions given to the row given, returning the new resultant row.
468491
func applyUpdateExpressions(ctx *sql.Context, updateExprs []sql.Expression, row sql.Row) (sql.Row, error) {
469492
var ok bool

sql/time.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ func ConvertTimeZone(datetime time.Time, fromLocation string, toLocation string)
4747
return datetime.Add(delta), true
4848
}
4949

50+
func ValidTimeOffset(str string) bool {
51+
return offsetRegex.MatchString(str)
52+
}
53+
5054
// MySQLOffsetToDuration takes in a MySQL timezone offset (e.g. "+01:00") and returns it as a time.Duration.
5155
// If any problems are encountered, an error is returned.
5256
func MySQLOffsetToDuration(d string) (time.Duration, error) {
@@ -113,7 +117,7 @@ func ConvertTimeToLocation(datetime time.Time, location string) (time.Time, erro
113117
return getCopy(datetime, time.UTC).Add(-1 * duration), nil
114118
}
115119

116-
return time.Time{}, errors.New(fmt.Sprintf("error: unable to parse timezone '%s'", location))
120+
return time.Time{}, ErrInvalidTimeZone.New(location)
117121
}
118122

119123
// getCopy recreates the time t in the wanted timezone.

0 commit comments

Comments
 (0)