Skip to content

Commit ed696d2

Browse files
committed
fix week, modify getDate to use evaluated value
1 parent 76cb5d0 commit ed696d2

File tree

2 files changed

+43
-24
lines changed

2 files changed

+43
-24
lines changed

sql/expression/function/extract.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,11 @@ func (td *Extract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
128128
case "MONTH":
129129
return int(dateTime.Month()), nil
130130
case "WEEK":
131-
date, err := getDate(ctx, expression.UnaryExpression{Child: td.RightChild}, row)
131+
dateVal, err := td.RightChild.Eval(ctx, row)
132+
if err != nil {
133+
return nil, err
134+
}
135+
date, err := getDate(ctx, dateVal)
132136
if err != nil {
133137
return nil, err
134138
}

sql/expression/function/time.go

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,7 @@ var ErrUnknownType = errors.NewKind("function '%s' encountered unknown type %T")
3636

3737
var ErrTooHighPrecision = errors.NewKind("Too-big precision %d for '%s'. Maximum is %d.")
3838

39-
func getDate(ctx *sql.Context,
40-
u expression.UnaryExpression,
41-
row sql.Row) (interface{}, error) {
42-
43-
val, err := u.Child.Eval(ctx, row)
44-
if err != nil {
45-
return nil, err
46-
}
47-
39+
func getDate(ctx *sql.Context, val interface{}) (interface{}, error) {
4840
if val == nil {
4941
return nil, nil
5042
}
@@ -62,8 +54,12 @@ func getDatePart(ctx *sql.Context,
6254
u expression.UnaryExpression,
6355
row sql.Row,
6456
f func(interface{}) interface{}) (interface{}, error) {
57+
val, err := u.Child.Eval(ctx, row)
58+
if err != nil {
59+
return nil, err
60+
}
6561

66-
date, err := getDate(ctx, u, row)
62+
date, err := getDate(ctx, val)
6763
if err != nil {
6864
return nil, err
6965
}
@@ -602,22 +598,26 @@ func (*YearWeek) CollationCoercibility(ctx *sql.Context) (collation sql.Collatio
602598

603599
// Eval implements the Expression interface.
604600
func (d *YearWeek) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
605-
date, err := getDate(ctx, expression.UnaryExpression{Child: d.date}, row)
601+
dateVal, err := d.date.Eval(ctx, row)
602+
if err != nil {
603+
return nil, err
604+
}
605+
date, err := getDate(ctx, dateVal)
606606
if err != nil {
607607
return nil, err
608608
}
609609
if date == nil {
610610
return nil, nil
611611
}
612-
yyyy, ok := year(date).(int32)
612+
yyyy, ok := year(date).(int)
613613
if !ok {
614614
return nil, sql.ErrInvalidArgumentDetails.New("YEARWEEK", "invalid year")
615615
}
616-
mm, ok := month(date).(int32)
616+
mm, ok := month(date).(int)
617617
if !ok {
618618
return nil, sql.ErrInvalidArgumentDetails.New("YEARWEEK", "invalid month")
619619
}
620-
dd, ok := day(date).(int32)
620+
dd, ok := day(date).(int)
621621
if !ok {
622622
return nil, sql.ErrInvalidArgumentDetails.New("YEARWEEK", "invalid day")
623623
}
@@ -634,9 +634,9 @@ func (d *YearWeek) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
634634
}
635635
}
636636
}
637-
yyyy, week := calcWeek(yyyy, mm, dd, weekMode(mode)|weekBehaviourYear)
637+
yr, week := calcWeek(int32(yyyy), int32(mm), int32(dd), weekMode(mode)|weekBehaviourYear)
638638

639-
return (yyyy * 100) + week, nil
639+
return (yr * 100) + week, nil
640640
}
641641

642642
// Resolved implements the Expression interface.
@@ -710,20 +710,34 @@ func (*Week) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID,
710710

711711
// Eval implements the Expression interface.
712712
func (d *Week) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
713-
date, err := getDate(ctx, expression.UnaryExpression{Child: d.date}, row)
713+
dateVal, err := d.date.Eval(ctx, row)
714+
if err != nil {
715+
return nil, err
716+
}
717+
718+
date, err := getDate(ctx, dateVal)
714719
if err != nil {
715720
return nil, err
716721
}
722+
if date == nil {
723+
return nil, nil
724+
}
725+
726+
dateTime, ok := date.(time.Time)
727+
if !ok || dateTime.Equal(types.ZeroTime) {
728+
ctx.Warn(1292, "%s", types.ErrConvertingToTime.New(dateVal).Error())
729+
return nil, nil
730+
}
717731

718-
yyyy, ok := year(date).(int32)
732+
yyyy, ok := year(date).(int)
719733
if !ok {
720734
return nil, sql.ErrInvalidArgumentDetails.New("WEEK", "invalid year")
721735
}
722-
mm, ok := month(date).(int32)
736+
mm, ok := month(date).(int)
723737
if !ok {
724738
return nil, sql.ErrInvalidArgumentDetails.New("WEEK", "invalid month")
725739
}
726-
dd, ok := day(date).(int32)
740+
dd, ok := day(date).(int)
727741
if !ok {
728742
return nil, sql.ErrInvalidArgumentDetails.New("WEEK", "invalid day")
729743
}
@@ -741,11 +755,12 @@ func (d *Week) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
741755
}
742756
}
743757

744-
yearForWeek, week := calcWeek(yyyy, mm, dd, weekMode(mode)|weekBehaviourYear)
758+
yr := int32(yyyy)
759+
yearForWeek, week := calcWeek(yr, int32(mm), int32(dd), weekMode(mode)|weekBehaviourYear)
745760

746-
if yearForWeek < yyyy {
761+
if yearForWeek < yr {
747762
week = 0
748-
} else if yearForWeek > yyyy {
763+
} else if yearForWeek > yr {
749764
week = 53
750765
}
751766

0 commit comments

Comments
 (0)