@@ -149,11 +149,11 @@ func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{}
149149
150150 switch t .baseType {
151151 case sqltypes .Uint8 , sqltypes .Uint16 , sqltypes .Uint24 , sqltypes .Uint32 , sqltypes .Uint64 :
152- ca , _ , err := convertToUint64 (t , a )
152+ ca , _ , err := convertToUint64 (t , a , false )
153153 if err != nil {
154154 return 0 , err
155155 }
156- cb , _ , err := convertToUint64 (t , b )
156+ cb , _ , err := convertToUint64 (t , b , false )
157157 if err != nil {
158158 return 0 , err
159159 }
@@ -320,7 +320,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
320320 case sqltypes .Int64 :
321321 return convertToInt64 (t , v , false )
322322 case sqltypes .Uint64 :
323- return convertToUint64 (t , v )
323+ return convertToUint64 (t , v , false )
324324 case sqltypes .Float32 :
325325 num , err := convertToFloat64 (t , v )
326326 if err != nil && ! sql .ErrTruncatedIncorrect .Is (err ) {
@@ -445,6 +445,8 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any,
445445 return uint32 (num ), sql .InRange , nil
446446 case sqltypes .Int64 :
447447 return convertToInt64 (t , v , true )
448+ case sqltypes .Uint64 :
449+ return convertToUint64 (t , v , true )
448450 default :
449451 return t .Convert (ctx , v )
450452 }
@@ -569,7 +571,7 @@ func (t NumberTypeImpl_) SQLInt64(ctx *sql.Context, dest []byte, v interface{})
569571}
570572
571573func (t NumberTypeImpl_ ) SQLUint8 (ctx * sql.Context , dest []byte , v interface {}) ([]byte , error ) {
572- num , _ , err := convertToUint64 (t , v )
574+ num , _ , err := convertToUint64 (t , v , false )
573575 if err != nil {
574576 return nil , err
575577 }
@@ -582,7 +584,7 @@ func (t NumberTypeImpl_) SQLUint8(ctx *sql.Context, dest []byte, v interface{})
582584}
583585
584586func (t NumberTypeImpl_ ) SQLUint16 (ctx * sql.Context , dest []byte , v interface {}) ([]byte , error ) {
585- num , _ , err := convertToUint64 (t , v )
587+ num , _ , err := convertToUint64 (t , v , false )
586588 if err != nil {
587589 return nil , err
588590 }
@@ -595,7 +597,7 @@ func (t NumberTypeImpl_) SQLUint16(ctx *sql.Context, dest []byte, v interface{})
595597}
596598
597599func (t NumberTypeImpl_ ) SQLUint24 (ctx * sql.Context , dest []byte , v interface {}) ([]byte , error ) {
598- num , _ , err := convertToUint64 (t , v )
600+ num , _ , err := convertToUint64 (t , v , false )
599601 if err != nil {
600602 return nil , err
601603 }
@@ -608,7 +610,7 @@ func (t NumberTypeImpl_) SQLUint24(ctx *sql.Context, dest []byte, v interface{})
608610}
609611
610612func (t NumberTypeImpl_ ) SQLUint32 (ctx * sql.Context , dest []byte , v interface {}) ([]byte , error ) {
611- num , _ , err := convertToUint64 (t , v )
613+ num , _ , err := convertToUint64 (t , v , false )
612614 if err != nil {
613615 return nil , err
614616 }
@@ -621,7 +623,7 @@ func (t NumberTypeImpl_) SQLUint32(ctx *sql.Context, dest []byte, v interface{})
621623}
622624
623625func (t NumberTypeImpl_ ) SQLUint64 (ctx * sql.Context , dest []byte , v interface {}) ([]byte , error ) {
624- num , _ , err := convertToUint64 (t , v )
626+ num , _ , err := convertToUint64 (t , v , false )
625627 if err != nil {
626628 return nil , err
627629 }
@@ -1227,7 +1229,7 @@ func convertValueToUint64(t NumberTypeImpl_, v sql.Value) (uint64, error) {
12271229 }
12281230}
12291231
1230- func convertToUint64 (t NumberTypeImpl_ , v interface {}) (uint64 , sql.ConvertInRange , error ) {
1232+ func convertToUint64 (t NumberTypeImpl_ , v interface {}, round bool ) (uint64 , sql.ConvertInRange , error ) {
12311233 switch v := v .(type ) {
12321234 case time.Time :
12331235 return uint64 (v .UTC ().Unix ()), sql .InRange , nil
@@ -1300,29 +1302,46 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
13001302 }
13011303 return i , sql .InRange , nil
13021304 case string :
1303- // TODO: this currently assumes we are always rounding to preserve behavior
1304- // but we should only be rounding on inserts
13051305 var err error
1306- truncStr , didTrunc := sql .TruncateStringToDouble (v )
1306+ if round {
1307+ truncStr , didTrunc := sql .TruncateStringToDouble (v )
1308+ if didTrunc {
1309+ err = sql .ErrTruncatedIncorrect .New (t , v )
1310+ }
1311+ // Parse first an integer, which allows for more values than float64
1312+ i , pErr := strconv .ParseUint (truncStr , 10 , 64 )
1313+ if pErr == nil {
1314+ return i , sql .InRange , err
1315+ }
1316+ // If that fails, try as a float
1317+ f , pErr := strconv .ParseFloat (truncStr , 64 )
1318+ if pErr != nil {
1319+ return 0 , sql .OutOfRange , sql .ErrInvalidValue .New (v , t .String ())
1320+ }
1321+ i , inRange , _ := convertToUint64 (t , f , round )
1322+ return i , inRange , err
1323+ }
1324+ truncStr , didTrunc := sql .TruncateStringToInt (v )
13071325 if didTrunc {
13081326 err = sql .ErrTruncatedIncorrect .New (t , v )
13091327 }
1328+ var neg bool
1329+ if truncStr [0 ] == '+' {
1330+ truncStr = truncStr [1 :]
1331+ } else if truncStr [0 ] == '-' {
1332+ truncStr = truncStr [1 :]
1333+ neg = true
1334+ }
13101335 // Parse first as an integer, which allows for more values than float64
13111336 i , pErr := strconv .ParseUint (truncStr , 10 , 64 )
1312- if pErr == nil {
1313- return i , sql .InRange , err
1314- }
13151337 // Number is too large for uint64, return max value and OutOfRange
1316- if errors .Is (err , strconv .ErrRange ) {
1338+ if errors .Is (pErr , strconv .ErrRange ) {
13171339 return math .MaxUint64 , sql .OutOfRange , nil
13181340 }
1319- // If that fails, try as a float
1320- f , pErr := strconv .ParseFloat (truncStr , 64 )
1321- if pErr != nil {
1322- return 0 , sql .OutOfRange , sql .ErrInvalidValue .New (v , t .String ())
1341+ if neg {
1342+ return math .MaxUint64 - i + 1 , sql .OutOfRange , err
13231343 }
1324- i , inRange , _ := convertToUint64 (t , f )
1325- return i , inRange , err
1344+ return i , sql .InRange , err
13261345 case bool :
13271346 if v {
13281347 return 1 , sql .InRange , nil
0 commit comments