@@ -17,12 +17,12 @@ package types
1717import (
1818 "context"
1919 "encoding/hex"
20+ "errors"
2021 "fmt"
2122 "math"
2223 "reflect"
2324 "regexp"
2425 "strconv"
25- "strings"
2626 "time"
2727
2828 "github.com/dolthub/vitess/go/sqltypes"
@@ -1001,13 +1001,13 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange
10011001 if pErr == nil {
10021002 return i , sql .InRange , err
10031003 }
1004- // If that fails, try as a float and round it to integral
1004+ // If that fails, try as a float
10051005 f , pErr := strconv .ParseFloat (truncStr , 64 )
1006- if pErr == nil {
1007- f = math .Round (f )
1008- return int64 (f ), sql .InRange , err
1006+ if pErr != nil {
1007+ return 0 , sql .OutOfRange , sql .ErrInvalidValue .New (v , t .String ())
10091008 }
1010- return 0 , sql .OutOfRange , sql .ErrInvalidValue .New (v , t .String ())
1009+ i , inRange , _ := convertToInt64 (t , f )
1010+ return i , inRange , err
10111011 case bool :
10121012 if v {
10131013 return 1 , sql .InRange , nil
@@ -1152,15 +1152,15 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11521152 return math .MaxUint64 , sql .OutOfRange , nil
11531153 }
11541154 if v < 0 {
1155- return uint64 (math .MaxUint64 - v ), sql .OutOfRange , nil
1155+ return uint64 (math .MaxUint64 - uint ( - v - 1 ) ), sql .OutOfRange , nil
11561156 }
11571157 return uint64 (math .Round (float64 (v ))), sql .InRange , nil
11581158 case float64 :
11591159 if v >= float64 (math .MaxUint64 ) {
11601160 return math .MaxUint64 , sql .OutOfRange , nil
11611161 }
11621162 if v < 0 {
1163- return uint64 (math .MaxUint64 - v ), sql .OutOfRange , nil
1163+ return uint64 (math .MaxUint64 - uint ( - v - 1 ) ), sql .OutOfRange , nil
11641164 }
11651165 return uint64 (math .Round (v )), sql .InRange , nil
11661166 case decimal.Decimal :
@@ -1181,19 +1181,29 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11811181 }
11821182 return i , sql .InRange , nil
11831183 case string :
1184- v = strings .Trim (v , sql .IntCutSet )
1185- if i , err := strconv .ParseUint (v , 10 , 64 ); err == nil {
1186- return i , sql .InRange , nil
1187- } else if err == strconv .ErrRange {
1188- // Number is too large for uint64, return max value and OutOfRange
1184+ // TODO: this currently assumes we are always rounding to preserve behavior
1185+ // but we should only be rounding on inserts
1186+ var err error
1187+ truncStr , didTrunc := sql .TruncateStringToDouble (v )
1188+ if didTrunc {
1189+ err = sql .ErrTruncatedIncorrect .New (t , v )
1190+ }
1191+ // Parse first as an integer, which allows for more values than float64
1192+ i , pErr := strconv .ParseUint (truncStr , 10 , 64 )
1193+ if pErr == nil {
1194+ return i , sql .InRange , err
1195+ }
1196+ // Number is too large for uint64, return max value and OutOfRange
1197+ if errors .Is (err , strconv .ErrRange ) {
11891198 return math .MaxUint64 , sql .OutOfRange , nil
11901199 }
1191- if f , err := strconv . ParseFloat ( v , 64 ); err == nil {
1192- if val , inRange , err := convertToUint64 ( t , f ); err == nil && inRange {
1193- return val , inRange , err
1194- }
1200+ // If that fails, try as a float
1201+ f , pErr := strconv . ParseFloat ( truncStr , 64 )
1202+ if pErr != nil {
1203+ return 0 , sql . OutOfRange , sql . ErrInvalidValue . New ( v , t . String ())
11951204 }
1196- return 0 , sql .OutOfRange , sql .ErrInvalidValue .New (v , t .String ())
1205+ i , inRange , _ := convertToUint64 (t , f )
1206+ return i , inRange , err
11971207 case bool :
11981208 if v {
11991209 return 1 , sql .InRange , nil
@@ -1244,16 +1254,13 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) {
12441254 }
12451255 return float64 (i ), nil
12461256 case string :
1247- // TODO: proper truncation and rounding behavior
1248- v = strings .Trim (v , sql .NumericCutSet )
1249- i , err := strconv .ParseFloat (v , 64 )
1250- if err != nil {
1251- // parse the first longest valid numbers
1252- s := numre .FindString (v )
1253- i , _ = strconv .ParseFloat (s , 64 )
1254- return i , sql .ErrTruncatedIncorrect .New (t .String (), v )
1257+ var err error
1258+ truncStr , didTrunc := sql .TruncateStringToDouble (v )
1259+ if didTrunc {
1260+ err = sql .ErrTruncatedIncorrect .New (t , v )
12551261 }
1256- return i , nil
1262+ f , _ := strconv .ParseFloat (truncStr , 64 )
1263+ return f , err
12571264 case bool :
12581265 if v {
12591266 return 1 , nil
0 commit comments