Skip to content

Commit 7551888

Browse files
committed
move string trimming to expression/convert.go
1 parent c641501 commit 7551888

File tree

3 files changed

+52
-41
lines changed

3 files changed

+52
-41
lines changed

sql/expression/convert.go

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"strconv"
2121
"strings"
2222
"time"
23+
"unicode"
2324

2425
"github.com/dolthub/vitess/go/sqltypes"
2526
"github.com/sirupsen/logrus"
@@ -353,7 +354,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
353354
}
354355
return d, nil
355356
case ConvertToFloat:
356-
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
357+
value, err := prepareForNumericContext(val, originType, false)
357358
if err != nil {
358359
return nil, err
359360
}
@@ -363,7 +364,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
363364
}
364365
return d, nil
365366
case ConvertToDouble, ConvertToReal:
366-
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
367+
value, err := prepareForNumericContext(val, originType, false)
367368
if err != nil {
368369
return nil, err
369370
}
@@ -379,7 +380,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
379380
}
380381
return js, nil
381382
case ConvertToSigned:
382-
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
383+
value, err := prepareForNumericContext(val, originType, true)
383384
if err != nil {
384385
return nil, err
385386
}
@@ -396,7 +397,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
396397
}
397398
return t, nil
398399
case ConvertToUnsigned:
399-
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
400+
value, err := prepareForNumericContext(val, originType, true)
400401
if err != nil {
401402
return nil, err
402403
}
@@ -473,6 +474,42 @@ func createConvertedDecimalType(length, scale int, logErrors bool) sql.DecimalTy
473474
return types.InternalDecimalType
474475
}
475476

477+
func prepareForNumericContext(val interface{}, originType sql.Type, isInt bool) (interface{}, error) {
478+
if s, isString := val.(string); isString && types.IsTextOnly(originType) {
479+
return trimStringToNumberPrefix(s, isInt), nil
480+
}
481+
return convertHexBlobToDecimalForNumericContext(val, originType)
482+
}
483+
484+
func trimStringToNumberPrefix(s string, isInt bool) string {
485+
if isInt {
486+
s = strings.Trim(s, types.IntCutSet)
487+
} else {
488+
s = strings.Trim(s, types.NumericCutSet)
489+
}
490+
491+
seenDigit := false
492+
seenDot := false
493+
seenExp := false
494+
signIndex := 0
495+
496+
for i := 0; i < len(s); i++ {
497+
char := rune(s[i])
498+
499+
if unicode.IsDigit(char) {
500+
seenDigit = true
501+
} else if char == '.' && !seenDot && !isInt {
502+
seenDot = true
503+
} else if (char == 'e' || char == 'E') && !seenExp && seenDigit && !isInt {
504+
seenExp = true
505+
signIndex = i + 1
506+
} else if !((char == '-' || char == '+') && i == signIndex) {
507+
return s[:i]
508+
}
509+
}
510+
return s
511+
}
512+
476513
// convertHexBlobToDecimalForNumericContext converts byte array value to unsigned int value if originType is BLOB type.
477514
// This function is called when convertTo type is number type only. The hex literal values are parsed into blobs as
478515
// binary string as default, but for numeric context, the value should be a number.

sql/types/decimal.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal,
202202
return t.ConvertToNullDecimal(decimal.NewFromFloat(value))
203203
case string:
204204
// TODO: implement truncation here
205-
value = strings.Trim(value, numericCutSet)
205+
value = strings.Trim(value, NumericCutSet)
206206
if len(value) == 0 {
207207
return t.ConvertToNullDecimal(decimal.NewFromInt(0))
208208
}

sql/types/number.go

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
"strconv"
2525
"strings"
2626
"time"
27-
"unicode"
2827

2928
"github.com/dolthub/vitess/go/sqltypes"
3029
"github.com/dolthub/vitess/go/vt/proto/query"
@@ -89,12 +88,12 @@ var (
8988
)
9089

9190
const (
92-
// intCutSet is the set of characters that should be trimmed from the beginning and end of a string
91+
// IntCutSet is the set of characters that should be trimmed from the beginning and end of a string
9392
// when converting to a signed or unsigned integer
94-
intCutSet = " \t"
93+
IntCutSet = " \t"
9594

96-
// numericCutSet is the set of characters to trim from a string before converting it to a number.
97-
numericCutSet = " \t\n\r"
95+
// NumericCutSet is the set of characters to trim from a string before converting it to a number.
96+
NumericCutSet = " \t\n\r"
9897
)
9998

10099
type NumberTypeImpl_ struct {
@@ -992,7 +991,7 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange
992991
}
993992
return i, sql.InRange, nil
994993
case string:
995-
v = trimStringToNumberPrefix(v, false)
994+
v = strings.Trim(v, IntCutSet)
996995
if v == "" {
997996
// StringType{}.Zero() returns empty string, but should represent "0" for number value
998997
return 0, sql.InRange, nil
@@ -1179,7 +1178,7 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11791178
}
11801179
return i, sql.InRange, nil
11811180
case string:
1182-
v = trimStringToNumberPrefix(v, false)
1181+
v = strings.Trim(v, IntCutSet)
11831182
if i, err := strconv.ParseUint(v, 10, 64); err == nil {
11841183
return i, sql.InRange, nil
11851184
} else if err == strconv.ErrRange {
@@ -1282,7 +1281,7 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan
12821281
}
12831282
return uint32(i), sql.InRange, nil
12841283
case string:
1285-
v = strings.Trim(v, intCutSet)
1284+
v = strings.Trim(v, IntCutSet)
12861285
if i, err := strconv.ParseUint(v, 10, 32); err == nil {
12871286
return uint32(i), sql.InRange, nil
12881287
}
@@ -1378,7 +1377,7 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan
13781377
}
13791378
return uint16(i), sql.InRange, nil
13801379
case string:
1381-
v = strings.Trim(v, intCutSet)
1380+
v = strings.Trim(v, IntCutSet)
13821381
if i, err := strconv.ParseUint(v, 10, 16); err == nil {
13831382
return uint16(i), sql.InRange, nil
13841383
}
@@ -1478,7 +1477,7 @@ func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange
14781477
}
14791478
return uint8(i), sql.InRange, nil
14801479
case string:
1481-
v = strings.Trim(v, intCutSet)
1480+
v = strings.Trim(v, IntCutSet)
14821481
if i, err := strconv.ParseUint(v, 10, 8); err == nil {
14831482
return uint8(i), sql.InRange, nil
14841483
}
@@ -1538,7 +1537,7 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) {
15381537
}
15391538
return float64(i), nil
15401539
case string:
1541-
v = trimStringToNumberPrefix(v, true)
1540+
v = strings.Trim(v, NumericCutSet)
15421541
i, err := strconv.ParseFloat(v, 64)
15431542
if err != nil {
15441543
// parse the first longest valid numbers
@@ -1750,28 +1749,3 @@ func convertUintToUint32(v uint64) (uint32, sql.ConvertInRange, error) {
17501749
}
17511750
return uint32(v), sql.InRange, nil
17521751
}
1753-
1754-
func trimStringToNumberPrefix(s string, isFloat bool) string {
1755-
s = strings.TrimSpace(s)
1756-
1757-
seenDigit := false
1758-
seenDot := false
1759-
seenExp := false
1760-
signIndex := 0
1761-
1762-
for i := 0; i < len(s); i++ {
1763-
char := rune(s[i])
1764-
1765-
if unicode.IsDigit(char) {
1766-
seenDigit = true
1767-
} else if char == '.' && !seenDot && isFloat {
1768-
seenDot = true
1769-
} else if (char == 'e' || char == 'E') && !seenExp && seenDigit && isFloat {
1770-
seenExp = true
1771-
signIndex = i + 1
1772-
} else if !((char == '-' || char == '+') && i == signIndex) {
1773-
return s[:i]
1774-
}
1775-
}
1776-
return s
1777-
}

0 commit comments

Comments
 (0)