Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Julien Schmidt <go-sql-driver at julienschmidt.com>
Justin Li <jli at j-li.net>
Justin Nuß <nuss.justin at gmail.com>
Kamil Dziedzic <kamil at klecza.pl>
Kei Kamikawa <x00.x7f.x86 at gmail.com>
Kevin Malachowski <kevin at chowski.com>
Kieron Woodhouse <kieron.woodhouse at infosum.com>
Lennart Rudolph <lrudolph at hmc.edu>
Expand Down
4 changes: 2 additions & 2 deletions nulltime.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Time, err = parseDateTime([]byte(v), time.UTC)
nt.Valid = (err == nil)
return
}
Expand Down
2 changes: 1 addition & 1 deletion packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
string(dest[i].([]byte)),
dest[i].([]byte),
mc.cfg.Loc,
)
if err == nil {
Expand Down
141 changes: 131 additions & 10 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package mysql

import (
"bytes"
"crypto/tls"
"database/sql"
"database/sql/driver"
Expand Down Expand Up @@ -106,21 +107,141 @@ func readBool(input string) (value bool, valid bool) {
* Time related utils *
******************************************************************************/

func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
base := "0000-00-00 00:00:00.0000000"
switch len(str) {
var (
nullTimeBaseStr = "0000-00-00 00:00:00.000000"
nullTimeBaseByte = []byte(nullTimeBaseStr)
)

func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
switch len(b) {
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
if str == base[:len(str)] {
return
if bytes.Compare(b, nullTimeBaseByte[:len(b)]) == 0 {
return time.Time{}, nil
}

year, err := parseByteYear(b)
if err != nil {
return time.Time{}, err
}
if year <= 0 {
year = 1
}
if loc == time.UTC {
return time.Parse(timeFormat[:len(str)], str)

if b[4] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
}

m, err := parseByte2Digits(b[5], b[6])
if err != nil {
return time.Time{}, err
}
if m <= 0 {
m = 1
}
month := time.Month(m)

if b[7] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7])
}

day, err := parseByte2Digits(b[8], b[9])
if err != nil {
return time.Time{}, err
}
if day <= 0 {
day = 1
}
if len(b) == 10 {
return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
}

if b[10] != ' ' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10])
}

hour, err := parseByte2Digits(b[11], b[12])
if err != nil {
return time.Time{}, err
}
if b[13] != ':' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13])
}

min, err := parseByte2Digits(b[14], b[15])
if err != nil {
return time.Time{}, err
}
return time.ParseInLocation(timeFormat[:len(str)], str, loc)
if b[16] != ':' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16])
}

sec, err := parseByte2Digits(b[17], b[18])
if err != nil {
return time.Time{}, err
}
if len(b) == 19 {
return time.Date(year, month, day, hour, min, sec, 0, loc), nil
}

if b[19] != '.' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19])
}
nsec, err := parseByteNanoSec(b)
if err != nil {
return time.Time{}, err
}
return time.Date(year, month, day, hour, min, sec, nsec, loc), nil
default:
err = fmt.Errorf("invalid time string: %s", str)
return
return time.Time{}, fmt.Errorf("invalid time bytes: %s", b)
}
}

func parseByteYear(b []byte) (int, error) {
year, n := 0, 1000
for i := 0; i < 4; i++ {
v, err := bToi(b[i])
if err != nil {
return 0, err
}
year += v * n
n = n / 10
}
return year, nil
}

func parseByte2Digits(b1, b2 byte) (int, error) {
d2, err := bToi(b1)
if err != nil {
return 0, err
}
d1, err := bToi(b2)
if err != nil {
return 0, err
}
return d2*10 + d1, nil
}

func parseByteNanoSec(b []byte) (int, error) {
l := len(b)
ns, digit := 0, 100000 // max is 6-digits
for i := 20; i < l; i++ {
v, err := bToi(b[i])
if err != nil {
return 0, err
}
ns += v * digit
digit /= 10
}
// nanoseconds has 10-digits. (needs to scale digits)
// 10 - 6 = 4, so we have to multiple 1000.
return ns * 1000, nil
}

func bToi(b byte) (int, error) {
if b < '0' || b > '9' {
return 0, errors.New("not [0-9]")
}
return int(b - '0'), nil
}

func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
Expand Down
Loading