Skip to content
2 changes: 1 addition & 1 deletion memory/table_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (td TableData) partition(row sql.Row) (int, error) {

t, isStringType := td.schema.Schema[keyColumns[i]].Type.(sql.StringType)
if isStringType && v != nil {
v, err = types.ConvertToString(v, t)
v, err = types.ConvertToString(v, t, nil)
if err == nil {
err = t.Collation().WriteWeightString(hash, v.(string))
}
Expand Down
2 changes: 1 addition & 1 deletion sql/expression/function/inet_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (i *InetAton) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
}

// Expect to receive an IP address, so convert val into string
ipstr, err := types.ConvertToString(val, types.LongText)
ipstr, err := types.ConvertToString(val, types.LongText, nil)
if err != nil {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(val).String())
}
Expand Down
4 changes: 2 additions & 2 deletions sql/expression/function/locks.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (nl *NamedLockFunction) GetLockName(ctx *sql.Context, row sql.Row) (*string
if !ok {
return nil, ErrIllegalLockNameArgType.New(nl.Child.Type().String(), nl.funcName)
}
lockName, err := types.ConvertToString(val, s)
lockName, err := types.ConvertToString(val, s, nil)
if err != nil {
return nil, fmt.Errorf("%w; %s", ErrIllegalLockNameArgType.New(nl.Child.Type().String(), nl.funcName), err)
}
Expand Down Expand Up @@ -328,7 +328,7 @@ func (gl *GetLock) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, ErrIllegalLockNameArgType.New(gl.LeftChild.Type().String(), gl.FunctionName())
}

lockName, err := types.ConvertToString(leftVal, s)
lockName, err := types.ConvertToString(leftVal, s, nil)
if err != nil {
return nil, fmt.Errorf("%w; %s", ErrIllegalLockNameArgType.New(gl.LeftChild.Type().String(), gl.FunctionName()), err)
}
Expand Down
2 changes: 1 addition & 1 deletion sql/rowexec/agg.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func groupingKey(

t, isStringType := expr.Type().(sql.StringType)
if isStringType && v != nil {
v, err = types.ConvertToString(v, t)
v, err = types.ConvertToString(v, t, nil)
if err == nil {
err = t.Collation().WriteWeightString(hash, v.(string))
}
Expand Down
49 changes: 26 additions & 23 deletions sql/types/datetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ var (
// versions of common cases and dates that use non common separators.
//
// https://github.com/MariaDB/server/blob/mysql-5.5.36/sql-common/my_time.c#L124
TimestampDatetimeLayouts = append(DateOnlyLayouts, []string{
TimestampDatetimeLayouts = append([]string{
time.RFC3339,
time.RFC3339Nano,
"2006-01-02 15:4",
"2006-01-02 15:04",
"2006-01-02 15:04:",
"2006-01-02 15:04:.",
"2006-01-02 15:04:05.",
"2006-01-02 15:04:05.999999",
"2006-1-2 15:4:5.999999",
time.RFC3339,
time.RFC3339Nano,
"2006-01-02T15:04:05",
"20060102150405",
"2006-01-02 15:04:05.999999999 -0700 MST", // represents standard Time.time.UTC()
}...)
}, DateOnlyLayouts...)

// zeroTime is 0000-01-01 00:00:00 UTC which is the closest Go can get to 0000-00-00 00:00:00
zeroTime = time.Unix(-62167219200, 0).UTC()
Expand Down Expand Up @@ -233,14 +233,8 @@ func (t datetimeType) ConvertWithoutRangeCheck(v interface{}) (time.Time, error)
return zeroTime, nil
}
// TODO: consider not using time.Parse if we want to match MySQL exactly ('2010-06-03 11:22.:.:.:.:' is a valid timestamp)
parsed := false
for _, fmt := range TimestampDatetimeLayouts {
if t, err := time.Parse(fmt, value); err == nil {
res = t.UTC()
parsed = true
break
}
}
var parsed bool
res, parsed = parseDatetime(value)
if !parsed {
return zeroTime, ErrConvertingToTime.New(v)
}
Expand Down Expand Up @@ -337,6 +331,15 @@ func (t datetimeType) ConvertWithoutRangeCheck(v interface{}) (time.Time, error)
return res, nil
}

func parseDatetime(value string) (time.Time, bool) {
for _, fmt := range TimestampDatetimeLayouts {
if t, err := time.Parse(fmt, value); err == nil {
return t.UTC(), true
}
}
return time.Time{}, false
}

func (t datetimeType) MustConvert(v interface{}) interface{} {
value, _, err := t.Convert(v)
if err != nil {
Expand Down Expand Up @@ -373,42 +376,42 @@ func (t datetimeType) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes.
return sqltypes.NULL, nil
}

v, _, err := t.Convert(v)
vt, err := ConvertToTime(v, t)
if err != nil {
return sqltypes.Value{}, err
}
vt := v.(time.Time)

var typ query.Type
var val string
var val []byte

start := len(dest)
switch t.baseType {
case sqltypes.Date:
typ = sqltypes.Date
if vt.Equal(zeroTime) {
val = vt.Format(ZeroDateStr)
val = vt.AppendFormat(dest, ZeroDateStr)
} else {
val = vt.Format(sql.DateLayout)
val = vt.AppendFormat(dest, sql.DateLayout)
}
case sqltypes.Datetime:
typ = sqltypes.Datetime
if vt.Equal(zeroTime) {
val = vt.Format(ZeroTimestampDatetimeStr)
val = vt.AppendFormat(dest, ZeroTimestampDatetimeStr)
} else {
val = vt.Format(sql.TimestampDatetimeLayout)
val = vt.AppendFormat(dest, sql.TimestampDatetimeLayout)
}
case sqltypes.Timestamp:
typ = sqltypes.Timestamp
if vt.Equal(zeroTime) {
val = vt.Format(ZeroTimestampDatetimeStr)
val = vt.AppendFormat(dest, ZeroTimestampDatetimeStr)
} else {
val = vt.Format(sql.TimestampDatetimeLayout)
val = vt.AppendFormat(dest, sql.TimestampDatetimeLayout)
}
default:
panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime"))
return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime")
}

valBytes := AppendAndSliceString(dest, val)
valBytes := val[start:]

return sqltypes.MakeTrusted(typ, valBytes), nil
}
Expand Down
12 changes: 7 additions & 5 deletions sql/types/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"gopkg.in/src-d/go-errors.v1"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/encodings"
)

const (
Expand Down Expand Up @@ -226,7 +227,7 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal,
case *big.Rat:
return t.ConvertToNullDecimal(new(big.Float).SetRat(value))
case decimal.Decimal:
if t.definesColumn {
if t.definesColumn && value.Exponent() != int32(t.scale) {
val, err := decimal.NewFromString(value.StringFixed(int32(t.scale)))
if err != nil {
return decimal.NullDecimal{}, err
Expand Down Expand Up @@ -311,9 +312,7 @@ func (t DecimalType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltype
if err != nil {
return sqltypes.Value{}, err
}

val := AppendAndSliceString(dest, t.DecimalValueStringFixed(value.Decimal))

val := encodings.StringToBytes(t.DecimalValueStringFixed(value.Decimal))
return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil
}

Expand Down Expand Up @@ -365,7 +364,10 @@ func (t DecimalType_) Scale() uint8 {
// it should use scale defined by the column. Otherwise, the result value should use its own precision and scale.
func (t DecimalType_) DecimalValueStringFixed(v decimal.Decimal) string {
if t.definesColumn {
return v.StringFixed(int32(t.scale))
if int32(t.scale) != v.Exponent() {
return v.StringFixed(int32(t.scale))
}
return v.String()
} else {
return v.StringFixed(v.Exponent() * -1)
}
Expand Down
15 changes: 11 additions & 4 deletions sql/types/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ var (
type EnumType struct {
collation sql.CollationID
hashedValToIndex map[uint64]int
valToIdx map[string]int
indexToVal []string
maxResponseByteLength uint32
}
Expand All @@ -70,7 +71,8 @@ func CreateEnumType(values []string, collation sql.CollationID) (sql.EnumType, e
// including accounting for multibyte character representations.
var maxResponseByteLength uint32
maxCharLength := collation.Collation().CharacterSet.MaxLength()
valToIndex := make(map[uint64]int)
hashedValToIndex := make(map[uint64]int)
valToIdx := make(map[string]int)
for i, value := range values {
if !collation.Equals(sql.Collation_binary) {
// Trailing spaces are automatically deleted from ENUM member values in the table definition when a table
Expand All @@ -82,11 +84,12 @@ func CreateEnumType(values []string, collation sql.CollationID) (sql.EnumType, e
if err != nil {
return nil, err
}
if _, ok := valToIndex[hashedVal]; ok {
if _, ok := hashedValToIndex[hashedVal]; ok {
return nil, fmt.Errorf("duplicate entry: %v", value)
}
// The elements listed in the column specification are assigned index numbers, beginning with 1.
valToIndex[hashedVal] = i + 1
hashedValToIndex[hashedVal] = i + 1
valToIdx[value] = i + 1

byteLength := uint32(utf8.RuneCountInString(value) * int(maxCharLength))
if byteLength > maxResponseByteLength {
Expand All @@ -95,8 +98,9 @@ func CreateEnumType(values []string, collation sql.CollationID) (sql.EnumType, e
}
return EnumType{
collation: collation,
hashedValToIndex: valToIndex,
hashedValToIndex: hashedValToIndex,
indexToVal: values,
valToIdx: valToIdx,
maxResponseByteLength: maxResponseByteLength,
}, nil
}
Expand Down Expand Up @@ -309,6 +313,9 @@ func (t EnumType) Collation() sql.CollationID {

// IndexOf implements EnumType interface.
func (t EnumType) IndexOf(v string) int {
if idx, ok := t.valToIdx[v]; ok {
return idx
}
hashedVal, err := t.collation.HashToUint(v)
if err == nil {
if index, ok := t.hashedValToIndex[hashedVal]; ok {
Expand Down
Loading
Loading