Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions enginetest/queries/function_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -1529,3 +1529,16 @@ var FunctionQueryTests = []QueryTest{
},
},
}

// BrokenFunctionQueryTests contains SQL function call queries that don't match MySQL behavior
var BrokenFunctionQueryTests = []QueryTest{
// https://github.com/dolthub/dolt/issues/9735
{
Query: "select log('10asdf', '100f')",
Expected: []sql.Row{{float64(2)}},
},
{
Query: "select log('a10asdf', 'b100f')",
Expected: []sql.Row{{nil}},
},
}
189 changes: 189 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -11439,6 +11439,195 @@ select * from t1 except (
},
},
},
{
// https://github.com/dolthub/dolt/issues/9733
// https://github.com/dolthub/dolt/issues/9739
Name: "strings cast to numbers",
SetUpScript: []string{
"create table test01(pk varchar(20) primary key)",
`insert into test01 values (' 3 12 4'),
(' 3.2 12 4'),('-3.1234'),('-3.1a'),('-5+8'),('+3.1234'),
('11d'),('11wha?'),('11'),('12'),('1a1'),('a1a1'),('11-5'),
('3. 12 4'),('5.932887e+07'),('5.932887e+07abc'),('5.932887e7'),('5.932887e7abc')`,
"create table test02(pk int primary key)",
"insert into test02 values(11),(12),(13),(14),(15)",
},
Assertions: []ScriptTestAssertion{
{
Dialect: "mysql",
Query: "select pk, cast(pk as float) from test01",
Expected: []sql.Row{
{" 3 12 4", float32(3)},
{" 3.2 12 4", float32(3.2)},
{"-3.1234", float32(-3.1234)},
{"-3.1a", float32(-3.1)},
{"-5+8", float32(-5)},
{"+3.1234", float32(3.1234)},
{"11", float32(11)},
{"11-5", float32(11)},
{"11d", float32(11)},
{"11wha?", float32(11)},
{"12", float32(12)},
{"1a1", float32(1)},
{"3. 12 4", float32(3)},
{"5.932887e+07", float32(5.932887e+07)},
{"5.932887e+07abc", float32(5.932887e+07)},
{"5.932887e7", float32(5.932887e+07)},
{"5.932887e7abc", float32(5.932887e+07)},
{"a1a1", float32(0)},
},
},
{
Dialect: "mysql",
Query: "select pk, cast(pk as double) from test01",
Expected: []sql.Row{
{" 3 12 4", 3.0},
{" 3.2 12 4", 3.2},
{"-3.1234", -3.1234},
{"-3.1a", -3.1},
{"-5+8", -5.0},
{"+3.1234", 3.1234},
{"11", 11.0},
{"11-5", 11.0},
{"11d", 11.0},
{"11wha?", 11.0},
{"12", 12.0},
{"1a1", 1.0},
{"3. 12 4", 3.0},
{"5.932887e+07", 5.932887e+07},
{"5.932887e+07abc", 5.932887e+07},
{"5.932887e7", 5.932887e+07},
{"5.932887e7abc", 5.932887e+07},
{"a1a1", 0.0},
},
},
{
Dialect: "mysql",
Query: "select pk, cast(pk as signed) from test01",
Expected: []sql.Row{
{" 3 12 4", 3},
{" 3.2 12 4", 3},
{"-3.1234", -3},
{"-3.1a", -3},
{"-5+8", -5},
{"+3.1234", 3},
{"11", 11},
{"11-5", 11},
{"11d", 11},
{"11wha?", 11},
{"12", 12},
{"1a1", 1},
{"3. 12 4", 3},
{"5.932887e+07", 5},
{"5.932887e+07abc", 5},
{"5.932887e7", 5},
{"5.932887e7abc", 5},
{"a1a1", 0},
},
},
{
Dialect: "mysql",
Query: "select pk, cast(pk as unsigned) from test01",
Expected: []sql.Row{
{" 3 12 4", uint64(3)},
{" 3.2 12 4", uint64(3)},
{"-3.1234", uint64(18446744073709551613)},
{"-3.1a", uint64(18446744073709551613)},
{"-5+8", uint64(18446744073709551611)},
{"+3.1234", uint64(3)},
{"11", uint64(11)},
{"11-5", uint64(11)},
{"11d", uint64(11)},
{"11wha?", uint64(11)},
{"12", uint64(12)},
{"1a1", uint64(1)},
{"3. 12 4", uint64(3)},
{"5.932887e+07", uint64(5)},
{"5.932887e+07abc", uint64(5)},
{"5.932887e7", uint64(5)},
{"5.932887e7abc", uint64(5)},
{"a1a1", uint64(0)},
},
},
{
Dialect: "mysql",
Query: "select pk, cast(pk as decimal(12,3)) from test01",
Expected: []sql.Row{
{" 3 12 4", "3.000"},
{" 3.2 12 4", "3.200"},
{"-3.1234", "-3.123"},
{"-3.1a", "-3.100"},
{"-5+8", "-5.000"},
{"+3.1234", "3.123"},
{"11", "11.000"},
{"11-5", "11.000"},
{"11d", "11.000"},
{"11wha?", "11.000"},
{"12", "12.000"},
{"1a1", "1.000"},
{"3. 12 4", "3.000"},
{"5.932887e+07", "59328870.000"},
{"5.932887e+07abc", "59328870.000"},
{"5.932887e7", "59328870.000"},
{"5.932887e7abc", "59328870.000"},
{"a1a1", "0.000"},
},
},
{
Query: "select * from test01 where pk in ('11')",
Expected: []sql.Row{{"11"}},
},
{
// https://github.com/dolthub/dolt/issues/9739
Skip: true,
Dialect: "mysql",
Query: "select * from test01 where pk in (11)",
Expected: []sql.Row{
{"11"},
{"11d"},
{"11wha?"},
},
},
{
// https://github.com/dolthub/dolt/issues/9739
Skip: true,
Dialect: "mysql",
Query: "select * from test01 where pk=3",
Expected: []sql.Row{
{" 3 12 4"},
{" 3. 12 4"},
{"3. 12 4"},
},
},
{
// https://github.com/dolthub/dolt/issues/9739
Skip: true,
Dialect: "mysql",
Query: "select * from test01 where pk>=3 and pk < 4",
Expected: []sql.Row{
{" 3 12 4"},
{" 3. 12 4"},
{" 3.2 12 4"},
{"+3.1234"},
{"3. 12 4"},
},
},
{
// https://github.com/dolthub/dolt/issues/9739
Skip: true,
Dialect: "mysql",
Query: "select * from test02 where pk in ('11asdf')",
Expected: []sql.Row{{"11"}},
},
{
// https://github.com/dolthub/dolt/issues/9739
Skip: true,
Dialect: "mysql",
Query: "select * from test02 where pk='11.12asdf'",
Expected: []sql.Row{},
},
},
},
}

var SpatialScriptTests = []ScriptTest{
Expand Down
49 changes: 44 additions & 5 deletions sql/expression/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strconv"
"strings"
"time"
"unicode"

"github.com/dolthub/vitess/go/sqltypes"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -342,7 +343,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
}
return d, nil
case ConvertToDecimal:
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
value, err := prepareForNumericContext(val, originType, false)
if err != nil {
return nil, err
}
Expand All @@ -353,7 +354,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
}
return d, nil
case ConvertToFloat:
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
value, err := prepareForNumericContext(val, originType, false)
if err != nil {
return nil, err
}
Expand All @@ -363,7 +364,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
}
return d, nil
case ConvertToDouble, ConvertToReal:
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
value, err := prepareForNumericContext(val, originType, false)
if err != nil {
return nil, err
}
Expand All @@ -379,7 +380,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
}
return js, nil
case ConvertToSigned:
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
value, err := prepareForNumericContext(val, originType, true)
if err != nil {
return nil, err
}
Expand All @@ -396,7 +397,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
}
return t, nil
case ConvertToUnsigned:
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
value, err := prepareForNumericContext(val, originType, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -473,6 +474,44 @@ func createConvertedDecimalType(length, scale int, logErrors bool) sql.DecimalTy
return types.InternalDecimalType
}

// prepareForNumberContext makes necessary preparations to strings and byte arrays for conversions to numbers
func prepareForNumericContext(val interface{}, originType sql.Type, isInt bool) (interface{}, error) {
if s, isString := val.(string); isString && types.IsTextOnly(originType) {
return trimStringToNumberPrefix(s, isInt), nil
}
return convertHexBlobToDecimalForNumericContext(val, originType)
}

// trimStringToNumberPrefix trims a string to the appropriate number prefix
func trimStringToNumberPrefix(s string, isInt bool) string {
if isInt {
s = strings.TrimLeft(s, types.IntCutSet)
} else {
s = strings.TrimLeft(s, types.NumericCutSet)
}

seenDigit := false
seenDot := false
seenExp := false
signIndex := 0

for i := 0; i < len(s); i++ {
char := rune(s[i])

if unicode.IsDigit(char) {
seenDigit = true
} else if char == '.' && !seenDot && !isInt {
seenDot = true
} else if (char == 'e' || char == 'E') && !seenExp && seenDigit && !isInt {
seenExp = true
signIndex = i + 1
} else if !((char == '-' || char == '+') && i == signIndex) {
return s[:i]
}
}
return s
}

// convertHexBlobToDecimalForNumericContext converts byte array value to unsigned int value if originType is BLOB type.
// This function is called when convertTo type is number type only. The hex literal values are parsed into blobs as
// binary string as default, but for numeric context, the value should be a number.
Expand Down
2 changes: 1 addition & 1 deletion sql/types/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal,
return t.ConvertToNullDecimal(decimal.NewFromFloat(value))
case string:
// TODO: implement truncation here
value = strings.Trim(value, numericCutSet)
value = strings.Trim(value, NumericCutSet)
if len(value) == 0 {
return t.ConvertToNullDecimal(decimal.NewFromInt(0))
}
Expand Down
20 changes: 10 additions & 10 deletions sql/types/number.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ var (
)

const (
// intCutSet is the set of characters that should be trimmed from the beginning and end of a string
// IntCutSet is the set of characters that should be trimmed from the beginning and end of a string
// when converting to a signed or unsigned integer
intCutSet = " \t"
IntCutSet = " \t"

// numericCutSet is the set of characters to trim from a string before converting it to a number.
numericCutSet = " \t\n\r"
// NumericCutSet is the set of characters to trim from a string before converting it to a number.
NumericCutSet = " \t\n\r"
)

type NumberTypeImpl_ struct {
Expand Down Expand Up @@ -991,7 +991,7 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange
}
return i, sql.InRange, nil
case string:
v = strings.Trim(v, intCutSet)
v = strings.Trim(v, IntCutSet)
if v == "" {
// StringType{}.Zero() returns empty string, but should represent "0" for number value
return 0, sql.InRange, nil
Expand Down Expand Up @@ -1178,7 +1178,7 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
}
return i, sql.InRange, nil
case string:
v = strings.Trim(v, intCutSet)
v = strings.Trim(v, IntCutSet)
if i, err := strconv.ParseUint(v, 10, 64); err == nil {
return i, sql.InRange, nil
} else if err == strconv.ErrRange {
Expand Down Expand Up @@ -1281,7 +1281,7 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan
}
return uint32(i), sql.InRange, nil
case string:
v = strings.Trim(v, intCutSet)
v = strings.Trim(v, IntCutSet)
if i, err := strconv.ParseUint(v, 10, 32); err == nil {
return uint32(i), sql.InRange, nil
}
Expand Down Expand Up @@ -1377,7 +1377,7 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan
}
return uint16(i), sql.InRange, nil
case string:
v = strings.Trim(v, intCutSet)
v = strings.Trim(v, IntCutSet)
if i, err := strconv.ParseUint(v, 10, 16); err == nil {
return uint16(i), sql.InRange, nil
}
Expand Down Expand Up @@ -1477,7 +1477,7 @@ func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange
}
return uint8(i), sql.InRange, nil
case string:
v = strings.Trim(v, intCutSet)
v = strings.Trim(v, IntCutSet)
if i, err := strconv.ParseUint(v, 10, 8); err == nil {
return uint8(i), sql.InRange, nil
}
Expand Down Expand Up @@ -1537,7 +1537,7 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) {
}
return float64(i), nil
case string:
v = strings.Trim(v, numericCutSet)
v = strings.Trim(v, NumericCutSet)
i, err := strconv.ParseFloat(v, 64)
if err != nil {
// parse the first longest valid numbers
Expand Down
Loading