Skip to content

Commit 7736f6b

Browse files
elianddb and claude committed
Fix string-to-number conversion to match MySQL behavior
* Update regex pattern to support scientific notation and signs
* Implement MySQL-compatible string truncation for all numeric types
* Convert invalid strings to 0 instead of throwing errors
* Add comprehensive tests for string truncation edge cases
* Fix existing test to match MySQL behavior

Fixes dolthub/dolt#7128

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 028d9ca commit 7736f6b

File tree

2 files changed

+118
-10
lines changed

2 files changed

+118
-10
lines changed

sql/types/number.go

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ var (
8484
numberFloat32ValueType = reflect.TypeOf(float32(0))
8585
numberFloat64ValueType = reflect.TypeOf(float64(0))
8686

87-
numre = regexp.MustCompile(`^[ ]*[0-9]*\.?[0-9]+`)
87+
numre = regexp.MustCompile(`^[ \t\n\r]*[+-]?([0-9]+\.?[0-9]*|\.[0-9]+)([eE][+-]?[0-9]+)?`)
8888
)
8989

9090
const (
@@ -1004,7 +1004,15 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange
10041004
// If that fails, try as a float and truncate it to integral
10051005
f, err := strconv.ParseFloat(v, 64)
10061006
if err != nil {
1007-
return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String())
1007+
// Use same truncation logic as float conversion for MySQL compatibility
1008+
s := numre.FindString(v)
1009+
if s != "" {
1010+
f, _ = strconv.ParseFloat(s, 64)
1011+
f = math.Round(f)
1012+
return int64(f), sql.InRange, nil
1013+
}
1014+
// If no valid number found, return 0 (MySQL behavior for pure non-numeric strings)
1015+
return 0, sql.InRange, nil
10081016
}
10091017
f = math.Round(f)
10101018
return int64(f), sql.InRange, nil
@@ -1190,7 +1198,17 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11901198
return val, inRange, err
11911199
}
11921200
}
1193-
return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String())
1201+
// Use same truncation logic as float conversion for MySQL compatibility
1202+
s := numre.FindString(v)
1203+
if s != "" {
1204+
if f, err := strconv.ParseFloat(s, 64); err == nil {
1205+
if val, inRange, err := convertToUint64(t, f); err == nil {
1206+
return val, inRange, err
1207+
}
1208+
}
1209+
}
1210+
// If no valid number found, return 0 (MySQL behavior for pure non-numeric strings)
1211+
return 0, sql.InRange, nil
11941212
case bool:
11951213
if v {
11961214
return 1, sql.InRange, nil
@@ -1290,7 +1308,17 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan
12901308
return val, inRange, err
12911309
}
12921310
}
1293-
return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String())
1311+
// Use same truncation logic as float conversion for MySQL compatibility
1312+
s := numre.FindString(v)
1313+
if s != "" {
1314+
if f, err := strconv.ParseFloat(s, 64); err == nil {
1315+
if val, inRange, err := convertToUint32(t, f); err == nil {
1316+
return val, inRange, err
1317+
}
1318+
}
1319+
}
1320+
// If no valid number found, return 0 (MySQL behavior for pure non-numeric strings)
1321+
return 0, sql.InRange, nil
12941322
case bool:
12951323
if v {
12961324
return 1, sql.InRange, nil
@@ -1386,7 +1414,17 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan
13861414
return val, inRange, err
13871415
}
13881416
}
1389-
return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String())
1417+
// Use same truncation logic as float conversion for MySQL compatibility
1418+
s := numre.FindString(v)
1419+
if s != "" {
1420+
if f, err := strconv.ParseFloat(s, 64); err == nil {
1421+
if val, inRange, err := convertToUint16(t, f); err == nil {
1422+
return val, inRange, err
1423+
}
1424+
}
1425+
}
1426+
// If no valid number found, return 0 (MySQL behavior for pure non-numeric strings)
1427+
return 0, sql.InRange, nil
13901428
case bool:
13911429
if v {
13921430
return 1, sql.InRange, nil
@@ -1486,7 +1524,17 @@ func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange
14861524
return val, inRange, err
14871525
}
14881526
}
1489-
return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String())
1527+
// Use same truncation logic as float conversion for MySQL compatibility
1528+
s := numre.FindString(v)
1529+
if s != "" {
1530+
if f, err := strconv.ParseFloat(s, 64); err == nil {
1531+
if val, inRange, err := convertToUint8(t, f); err == nil {
1532+
return val, inRange, err
1533+
}
1534+
}
1535+
}
1536+
// If no valid number found, return 0 (MySQL behavior for pure non-numeric strings)
1537+
return 0, sql.InRange, nil
14901538
case bool:
14911539
if v {
14921540
return 1, sql.InRange, nil
@@ -1542,8 +1590,12 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) {
15421590
if err != nil {
15431591
// parse the first longest valid numbers
15441592
s := numre.FindString(v)
1545-
i, _ = strconv.ParseFloat(s, 64)
1546-
return i, sql.ErrInvalidValue.New(v, t.String())
1593+
if s != "" {
1594+
i, _ = strconv.ParseFloat(s, 64)
1595+
return i, nil
1596+
}
1597+
// If no valid number found, return 0 (MySQL behavior for pure non-numeric strings)
1598+
return 0, nil
15471599
}
15481600
return i, nil
15491601
case bool:

sql/types/number_test.go

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,61 @@ func TestNumberConvert(t *testing.T) {
236236
}
237237
}
238238

239+
func TestFloat64StringTruncation(t *testing.T) {
240+
ctx := sql.NewEmptyContext()
241+
tests := []struct {
242+
name string
243+
input interface{}
244+
expected float64
245+
err bool
246+
inRange sql.ConvertInRange
247+
}{
248+
// Basic truncation cases
249+
{name: "numeric with invalid suffix", input: "123.456abc", expected: 123.456, err: false, inRange: sql.InRange},
250+
{name: "integer with invalid suffix", input: "123abc", expected: 123, err: false, inRange: sql.InRange},
251+
{name: "negative with invalid suffix", input: "-123.456abc", expected: -123.456, err: false, inRange: sql.InRange},
252+
{name: "positive sign with invalid suffix", input: "+123.456abc", expected: 123.456, err: false, inRange: sql.InRange},
253+
254+
// Scientific notation cases
255+
{name: "scientific notation with suffix", input: "1.5e2abc", expected: 150, err: false, inRange: sql.InRange},
256+
{name: "scientific notation negative exponent", input: "1e-4", expected: 0.0001, err: false, inRange: sql.InRange},
257+
{name: "uppercase E notation", input: "1.5E2abc", expected: 150, err: false, inRange: sql.InRange},
258+
{name: "positive exponent with suffix", input: "2.5e+3xyz", expected: 2500, err: false, inRange: sql.InRange},
259+
260+
// Edge cases that become 0
261+
{name: "pure non-numeric", input: "abc", expected: 0, err: false, inRange: sql.InRange},
262+
{name: "single letter", input: "a", expected: 0, err: false, inRange: sql.InRange},
263+
{name: "empty string", input: "", expected: 0, err: false, inRange: sql.InRange},
264+
265+
// Whitespace handling
266+
{name: "leading spaces", input: " 123.456abc", expected: 123.456, err: false, inRange: sql.InRange},
267+
{name: "leading tabs", input: "\t123.456abc", expected: 123.456, err: false, inRange: sql.InRange},
268+
{name: "mixed whitespace", input: " \t\n\r123.456abc", expected: 123.456, err: false, inRange: sql.InRange},
269+
{name: "only whitespace", input: " \t\n\r", expected: 0, err: false, inRange: sql.InRange},
270+
271+
// Decimal point variations
272+
{name: "decimal without leading digit", input: ".5abc", expected: 0.5, err: false, inRange: sql.InRange},
273+
{name: "decimal without trailing digits", input: "123.abc", expected: 123, err: false, inRange: sql.InRange},
274+
275+
// Multiple decimal points (should stop at first invalid)
276+
{name: "multiple decimal points", input: "1.2.3abc", expected: 1.2, err: false, inRange: sql.InRange},
277+
}
278+
279+
for _, test := range tests {
280+
t.Run(test.name, func(t *testing.T) {
281+
val, inRange, err := Float64.Convert(ctx, test.input)
282+
if test.err {
283+
assert.Error(t, err)
284+
} else {
285+
require.NoError(t, err)
286+
assert.Equal(t, test.expected, val)
287+
assert.Equal(t, test.inRange, inRange)
288+
assert.Equal(t, Float64.ValueType(), reflect.TypeOf(val))
289+
}
290+
})
291+
}
292+
}
293+
239294
func TestNumberSQL_BooleanFromBoolean(t *testing.T) {
240295
val, err := Boolean.SQL(sql.NewEmptyContext(), nil, true)
241296
require.NoError(t, err)
@@ -247,13 +302,14 @@ func TestNumberSQL_BooleanFromBoolean(t *testing.T) {
247302
}
248303

249304
func TestNumberSQL_NumberFromString(t *testing.T) {
305+
// MySQL converts invalid strings to 0 when used in numeric contexts
250306
val, err := Int64.SQL(sql.NewEmptyContext(), nil, "not a number")
251307
require.NoError(t, err)
252-
assert.Equal(t, "not a number", val.ToString())
308+
assert.Equal(t, "0", val.ToString())
253309

254310
val, err = Float64.SQL(sql.NewEmptyContext(), nil, "also not a number")
255311
require.NoError(t, err)
256-
assert.Equal(t, "also not a number", val.ToString())
312+
assert.Equal(t, "0", val.ToString())
257313
}
258314

259315
func TestNumberString(t *testing.T) {

0 commit comments

Comments
 (0)