@@ -17,16 +17,14 @@ package types
1717import (
1818 "context"
1919 "fmt"
20- "math/big"
21- "reflect"
22- "strings"
23-
20+ "github.com/dolthub/go-mysql-server/sql"
2421 "github.com/dolthub/vitess/go/sqltypes"
2522 "github.com/dolthub/vitess/go/vt/proto/query"
2623 "github.com/shopspring/decimal"
2724 "gopkg.in/src-d/go-errors.v1"
28-
29- "github.com/dolthub/go-mysql-server/sql"
25+ "math/big"
26+ "reflect"
27+ "strings"
3028)
3129
3230const (
@@ -141,13 +139,17 @@ func (t DecimalType_) Compare(s context.Context, a interface{}, b interface{}) (
141139// Convert implements Type interface.
142140func (t DecimalType_ ) Convert (c context.Context , v interface {}) (interface {}, sql.ConvertInRange , error ) {
143141 dec , err := t .ConvertToNullDecimal (v )
144- if err != nil {
142+ if err != nil && ! sql . ErrTruncatedIncorrect . Is ( err ) {
145143 return nil , sql .OutOfRange , err
146144 }
147145 if ! dec .Valid {
148146 return nil , sql .InRange , nil
149147 }
150- return t .BoundsCheck (dec .Decimal )
148+ res , inRange , cErr := t .BoundsCheck (dec .Decimal )
149+ if cErr != nil {
150+ return nil , sql .OutOfRange , cErr
151+ }
152+ return res , inRange , err
151153}
152154
153155func (t DecimalType_ ) ConvertNoBoundsCheck (v interface {}) (decimal.Decimal , error ) {
@@ -201,25 +203,33 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal,
201203 case float64 :
202204 return t .ConvertToNullDecimal (decimal .NewFromFloat (value ))
203205 case string :
204- // TODO: implement truncation here
205- value = strings .Trim (value , sql .NumericCutSet )
206- if len (value ) == 0 {
207- return t .ConvertToNullDecimal (decimal .NewFromInt (0 ))
208- }
209206 var err error
210- res , err = decimal .NewFromString (value )
211- if err != nil {
212- // The decimal library cannot handle all of the different formats
213- bf , _ , err := new (big.Float ).SetPrec (217 ).Parse (value , 0 )
214- if err != nil {
215- return decimal.NullDecimal {}, err
216- }
207+ truncStr := strings .Trim (value , sql .NumericCutSet )
208+ res , err = decimal .NewFromString (truncStr )
209+ if err == nil {
210+ return t .ConvertToNullDecimal (res )
211+ }
212+ // The decimal library cannot handle all the different formats
213+ bf , _ , err := new (big.Float ).SetPrec (217 ).Parse (truncStr , 0 )
214+ if err == nil {
217215 res , err = decimal .NewFromString (bf .Text ('f' , - 1 ))
218- if err ! = nil {
219- return decimal. NullDecimal {}, err
216+ if err = = nil {
217+ return t . ConvertToNullDecimal ( res )
220218 }
221219 }
222- return t .ConvertToNullDecimal (res )
220+ truncStr , didTrunc := sql .TruncateStringToDouble (value )
221+ if truncStr == "0" {
222+ return t .ConvertToNullDecimal (decimal .NewFromInt (0 ))
223+ }
224+ res , _ = decimal .NewFromString (truncStr )
225+ nullDec , cErr := t .ConvertToNullDecimal (res )
226+ if cErr != nil {
227+ return decimal.NullDecimal {}, cErr
228+ }
229+ if didTrunc {
230+ err = sql .ErrTruncatedIncorrect .New (t , value )
231+ }
232+ return nullDec , err
223233 case * big.Float :
224234 return t .ConvertToNullDecimal (value .Text ('f' , - 1 ))
225235 case * big.Int :
0 commit comments