Skip to content

Commit 1d3b1bb

Browse files
authored
Merge pull request #3176 from dolthub/angela/casting
Trim strings to number prefix when cast
2 parents a5d281d + 999ab3c commit 1d3b1bb

File tree

5 files changed

+257
-16
lines changed

5 files changed

+257
-16
lines changed

enginetest/queries/function_queries.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,3 +1529,16 @@ var FunctionQueryTests = []QueryTest{
15291529
},
15301530
},
15311531
}
1532+
1533+
// BrokenFunctionQueryTests contains SQL function call queries that don't match MySQL behavior
1534+
var BrokenFunctionQueryTests = []QueryTest{
1535+
// https://github.com/dolthub/dolt/issues/9735
1536+
{
1537+
Query: "select log('10asdf', '100f')",
1538+
Expected: []sql.Row{{float64(2)}},
1539+
},
1540+
{
1541+
Query: "select log('a10asdf', 'b100f')",
1542+
Expected: []sql.Row{{nil}},
1543+
},
1544+
}

enginetest/queries/script_queries.go

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11439,6 +11439,195 @@ select * from t1 except (
1143911439
},
1144011440
},
1144111441
},
11442+
{
11443+
// https://github.com/dolthub/dolt/issues/9733
11444+
// https://github.com/dolthub/dolt/issues/9739
11445+
Name: "strings cast to numbers",
11446+
SetUpScript: []string{
11447+
"create table test01(pk varchar(20) primary key)",
11448+
`insert into test01 values (' 3 12 4'),
11449+
(' 3.2 12 4'),('-3.1234'),('-3.1a'),('-5+8'),('+3.1234'),
11450+
('11d'),('11wha?'),('11'),('12'),('1a1'),('a1a1'),('11-5'),
11451+
('3. 12 4'),('5.932887e+07'),('5.932887e+07abc'),('5.932887e7'),('5.932887e7abc')`,
11452+
"create table test02(pk int primary key)",
11453+
"insert into test02 values(11),(12),(13),(14),(15)",
11454+
},
11455+
Assertions: []ScriptTestAssertion{
11456+
{
11457+
Dialect: "mysql",
11458+
Query: "select pk, cast(pk as float) from test01",
11459+
Expected: []sql.Row{
11460+
{" 3 12 4", float32(3)},
11461+
{" 3.2 12 4", float32(3.2)},
11462+
{"-3.1234", float32(-3.1234)},
11463+
{"-3.1a", float32(-3.1)},
11464+
{"-5+8", float32(-5)},
11465+
{"+3.1234", float32(3.1234)},
11466+
{"11", float32(11)},
11467+
{"11-5", float32(11)},
11468+
{"11d", float32(11)},
11469+
{"11wha?", float32(11)},
11470+
{"12", float32(12)},
11471+
{"1a1", float32(1)},
11472+
{"3. 12 4", float32(3)},
11473+
{"5.932887e+07", float32(5.932887e+07)},
11474+
{"5.932887e+07abc", float32(5.932887e+07)},
11475+
{"5.932887e7", float32(5.932887e+07)},
11476+
{"5.932887e7abc", float32(5.932887e+07)},
11477+
{"a1a1", float32(0)},
11478+
},
11479+
},
11480+
{
11481+
Dialect: "mysql",
11482+
Query: "select pk, cast(pk as double) from test01",
11483+
Expected: []sql.Row{
11484+
{" 3 12 4", 3.0},
11485+
{" 3.2 12 4", 3.2},
11486+
{"-3.1234", -3.1234},
11487+
{"-3.1a", -3.1},
11488+
{"-5+8", -5.0},
11489+
{"+3.1234", 3.1234},
11490+
{"11", 11.0},
11491+
{"11-5", 11.0},
11492+
{"11d", 11.0},
11493+
{"11wha?", 11.0},
11494+
{"12", 12.0},
11495+
{"1a1", 1.0},
11496+
{"3. 12 4", 3.0},
11497+
{"5.932887e+07", 5.932887e+07},
11498+
{"5.932887e+07abc", 5.932887e+07},
11499+
{"5.932887e7", 5.932887e+07},
11500+
{"5.932887e7abc", 5.932887e+07},
11501+
{"a1a1", 0.0},
11502+
},
11503+
},
11504+
{
11505+
Dialect: "mysql",
11506+
Query: "select pk, cast(pk as signed) from test01",
11507+
Expected: []sql.Row{
11508+
{" 3 12 4", 3},
11509+
{" 3.2 12 4", 3},
11510+
{"-3.1234", -3},
11511+
{"-3.1a", -3},
11512+
{"-5+8", -5},
11513+
{"+3.1234", 3},
11514+
{"11", 11},
11515+
{"11-5", 11},
11516+
{"11d", 11},
11517+
{"11wha?", 11},
11518+
{"12", 12},
11519+
{"1a1", 1},
11520+
{"3. 12 4", 3},
11521+
{"5.932887e+07", 5},
11522+
{"5.932887e+07abc", 5},
11523+
{"5.932887e7", 5},
11524+
{"5.932887e7abc", 5},
11525+
{"a1a1", 0},
11526+
},
11527+
},
11528+
{
11529+
Dialect: "mysql",
11530+
Query: "select pk, cast(pk as unsigned) from test01",
11531+
Expected: []sql.Row{
11532+
{" 3 12 4", uint64(3)},
11533+
{" 3.2 12 4", uint64(3)},
11534+
{"-3.1234", uint64(18446744073709551613)},
11535+
{"-3.1a", uint64(18446744073709551613)},
11536+
{"-5+8", uint64(18446744073709551611)},
11537+
{"+3.1234", uint64(3)},
11538+
{"11", uint64(11)},
11539+
{"11-5", uint64(11)},
11540+
{"11d", uint64(11)},
11541+
{"11wha?", uint64(11)},
11542+
{"12", uint64(12)},
11543+
{"1a1", uint64(1)},
11544+
{"3. 12 4", uint64(3)},
11545+
{"5.932887e+07", uint64(5)},
11546+
{"5.932887e+07abc", uint64(5)},
11547+
{"5.932887e7", uint64(5)},
11548+
{"5.932887e7abc", uint64(5)},
11549+
{"a1a1", uint64(0)},
11550+
},
11551+
},
11552+
{
11553+
Dialect: "mysql",
11554+
Query: "select pk, cast(pk as decimal(12,3)) from test01",
11555+
Expected: []sql.Row{
11556+
{" 3 12 4", "3.000"},
11557+
{" 3.2 12 4", "3.200"},
11558+
{"-3.1234", "-3.123"},
11559+
{"-3.1a", "-3.100"},
11560+
{"-5+8", "-5.000"},
11561+
{"+3.1234", "3.123"},
11562+
{"11", "11.000"},
11563+
{"11-5", "11.000"},
11564+
{"11d", "11.000"},
11565+
{"11wha?", "11.000"},
11566+
{"12", "12.000"},
11567+
{"1a1", "1.000"},
11568+
{"3. 12 4", "3.000"},
11569+
{"5.932887e+07", "59328870.000"},
11570+
{"5.932887e+07abc", "59328870.000"},
11571+
{"5.932887e7", "59328870.000"},
11572+
{"5.932887e7abc", "59328870.000"},
11573+
{"a1a1", "0.000"},
11574+
},
11575+
},
11576+
{
11577+
Query: "select * from test01 where pk in ('11')",
11578+
Expected: []sql.Row{{"11"}},
11579+
},
11580+
{
11581+
// https://github.com/dolthub/dolt/issues/9739
11582+
Skip: true,
11583+
Dialect: "mysql",
11584+
Query: "select * from test01 where pk in (11)",
11585+
Expected: []sql.Row{
11586+
{"11"},
11587+
{"11d"},
11588+
{"11wha?"},
11589+
},
11590+
},
11591+
{
11592+
// https://github.com/dolthub/dolt/issues/9739
11593+
Skip: true,
11594+
Dialect: "mysql",
11595+
Query: "select * from test01 where pk=3",
11596+
Expected: []sql.Row{
11597+
{" 3 12 4"},
11598+
{" 3. 12 4"},
11599+
{"3. 12 4"},
11600+
},
11601+
},
11602+
{
11603+
// https://github.com/dolthub/dolt/issues/9739
11604+
Skip: true,
11605+
Dialect: "mysql",
11606+
Query: "select * from test01 where pk>=3 and pk < 4",
11607+
Expected: []sql.Row{
11608+
{" 3 12 4"},
11609+
{" 3. 12 4"},
11610+
{" 3.2 12 4"},
11611+
{"+3.1234"},
11612+
{"3. 12 4"},
11613+
},
11614+
},
11615+
{
11616+
// https://github.com/dolthub/dolt/issues/9739
11617+
Skip: true,
11618+
Dialect: "mysql",
11619+
Query: "select * from test02 where pk in ('11asdf')",
11620+
Expected: []sql.Row{{"11"}},
11621+
},
11622+
{
11623+
// https://github.com/dolthub/dolt/issues/9739
11624+
Skip: true,
11625+
Dialect: "mysql",
11626+
Query: "select * from test02 where pk='11.12asdf'",
11627+
Expected: []sql.Row{},
11628+
},
11629+
},
11630+
},
1144211631
}
1144311632

1144411633
var SpatialScriptTests = []ScriptTest{

sql/expression/convert.go

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"strconv"
2121
"strings"
2222
"time"
23+
"unicode"
2324

2425
"github.com/dolthub/vitess/go/sqltypes"
2526
"github.com/sirupsen/logrus"
@@ -342,7 +343,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
342343
}
343344
return d, nil
344345
case ConvertToDecimal:
345-
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
346+
value, err := prepareForNumericContext(val, originType, false)
346347
if err != nil {
347348
return nil, err
348349
}
@@ -353,7 +354,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
353354
}
354355
return d, nil
355356
case ConvertToFloat:
356-
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
357+
value, err := prepareForNumericContext(val, originType, false)
357358
if err != nil {
358359
return nil, err
359360
}
@@ -363,7 +364,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
363364
}
364365
return d, nil
365366
case ConvertToDouble, ConvertToReal:
366-
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
367+
value, err := prepareForNumericContext(val, originType, false)
367368
if err != nil {
368369
return nil, err
369370
}
@@ -379,7 +380,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
379380
}
380381
return js, nil
381382
case ConvertToSigned:
382-
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
383+
value, err := prepareForNumericContext(val, originType, true)
383384
if err != nil {
384385
return nil, err
385386
}
@@ -396,7 +397,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
396397
}
397398
return t, nil
398399
case ConvertToUnsigned:
399-
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
400+
value, err := prepareForNumericContext(val, originType, true)
400401
if err != nil {
401402
return nil, err
402403
}
@@ -473,6 +474,44 @@ func createConvertedDecimalType(length, scale int, logErrors bool) sql.DecimalTy
473474
return types.InternalDecimalType
474475
}
475476

477+
// prepareForNumberContext makes necessary preparations to strings and byte arrays for conversions to numbers
478+
func prepareForNumericContext(val interface{}, originType sql.Type, isInt bool) (interface{}, error) {
479+
if s, isString := val.(string); isString && types.IsTextOnly(originType) {
480+
return trimStringToNumberPrefix(s, isInt), nil
481+
}
482+
return convertHexBlobToDecimalForNumericContext(val, originType)
483+
}
484+
485+
// trimStringToNumberPrefix trims a string to the appropriate number prefix
486+
func trimStringToNumberPrefix(s string, isInt bool) string {
487+
if isInt {
488+
s = strings.TrimLeft(s, types.IntCutSet)
489+
} else {
490+
s = strings.TrimLeft(s, types.NumericCutSet)
491+
}
492+
493+
seenDigit := false
494+
seenDot := false
495+
seenExp := false
496+
signIndex := 0
497+
498+
for i := 0; i < len(s); i++ {
499+
char := rune(s[i])
500+
501+
if unicode.IsDigit(char) {
502+
seenDigit = true
503+
} else if char == '.' && !seenDot && !isInt {
504+
seenDot = true
505+
} else if (char == 'e' || char == 'E') && !seenExp && seenDigit && !isInt {
506+
seenExp = true
507+
signIndex = i + 1
508+
} else if !((char == '-' || char == '+') && i == signIndex) {
509+
return s[:i]
510+
}
511+
}
512+
return s
513+
}
514+
476515
// convertHexBlobToDecimalForNumericContext converts byte array value to unsigned int value if originType is BLOB type.
477516
// This function is called when convertTo type is number type only. The hex literal values are parsed into blobs as
478517
// binary string as default, but for numeric context, the value should be a number.

sql/types/decimal.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal,
202202
return t.ConvertToNullDecimal(decimal.NewFromFloat(value))
203203
case string:
204204
// TODO: implement truncation here
205-
value = strings.Trim(value, numericCutSet)
205+
value = strings.Trim(value, NumericCutSet)
206206
if len(value) == 0 {
207207
return t.ConvertToNullDecimal(decimal.NewFromInt(0))
208208
}

sql/types/number.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,12 @@ var (
8888
)
8989

9090
const (
91-
// intCutSet is the set of characters that should be trimmed from the beginning and end of a string
91+
// IntCutSet is the set of characters that should be trimmed from the beginning and end of a string
9292
// when converting to a signed or unsigned integer
93-
intCutSet = " \t"
93+
IntCutSet = " \t"
9494

95-
// numericCutSet is the set of characters to trim from a string before converting it to a number.
96-
numericCutSet = " \t\n\r"
95+
// NumericCutSet is the set of characters to trim from a string before converting it to a number.
96+
NumericCutSet = " \t\n\r"
9797
)
9898

9999
type NumberTypeImpl_ struct {
@@ -991,7 +991,7 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange
991991
}
992992
return i, sql.InRange, nil
993993
case string:
994-
v = strings.Trim(v, intCutSet)
994+
v = strings.Trim(v, IntCutSet)
995995
if v == "" {
996996
// StringType{}.Zero() returns empty string, but should represent "0" for number value
997997
return 0, sql.InRange, nil
@@ -1178,7 +1178,7 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11781178
}
11791179
return i, sql.InRange, nil
11801180
case string:
1181-
v = strings.Trim(v, intCutSet)
1181+
v = strings.Trim(v, IntCutSet)
11821182
if i, err := strconv.ParseUint(v, 10, 64); err == nil {
11831183
return i, sql.InRange, nil
11841184
} else if err == strconv.ErrRange {
@@ -1281,7 +1281,7 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan
12811281
}
12821282
return uint32(i), sql.InRange, nil
12831283
case string:
1284-
v = strings.Trim(v, intCutSet)
1284+
v = strings.Trim(v, IntCutSet)
12851285
if i, err := strconv.ParseUint(v, 10, 32); err == nil {
12861286
return uint32(i), sql.InRange, nil
12871287
}
@@ -1377,7 +1377,7 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan
13771377
}
13781378
return uint16(i), sql.InRange, nil
13791379
case string:
1380-
v = strings.Trim(v, intCutSet)
1380+
v = strings.Trim(v, IntCutSet)
13811381
if i, err := strconv.ParseUint(v, 10, 16); err == nil {
13821382
return uint16(i), sql.InRange, nil
13831383
}
@@ -1477,7 +1477,7 @@ func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange
14771477
}
14781478
return uint8(i), sql.InRange, nil
14791479
case string:
1480-
v = strings.Trim(v, intCutSet)
1480+
v = strings.Trim(v, IntCutSet)
14811481
if i, err := strconv.ParseUint(v, 10, 8); err == nil {
14821482
return uint8(i), sql.InRange, nil
14831483
}
@@ -1537,7 +1537,7 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) {
15371537
}
15381538
return float64(i), nil
15391539
case string:
1540-
v = strings.Trim(v, numericCutSet)
1540+
v = strings.Trim(v, NumericCutSet)
15411541
i, err := strconv.ParseFloat(v, 64)
15421542
if err != nil {
15431543
// parse the first longest valid numbers

0 commit comments

Comments
 (0)