Skip to content

Commit ca76fae

Browse files
author
James Cor
committed
add decimal truncation
1 parent d0d7b1a commit ca76fae

File tree

4 files changed

+48
-35
lines changed

4 files changed

+48
-35
lines changed

enginetest/queries/script_queries.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,11 @@ var ScriptTests = []ScriptTest{
385385
},
386386
{
387387
// TODO: 123.456 is converted to a DECIMAL by Builder.ConvertVal, when it should be a DOUBLE
388-
Skip: true,
389388
Query: "SELECT -'+123.456ABC' = -123.456",
390389
Expected: []sql.Row{{true}},
391390
ExpectedWarningsCount: 1,
392391
ExpectedWarning: mysql.ERTruncatedWrongValue,
393-
ExpectedWarningMessageSubstring: "Truncated incorrect double value: +123.456ABC",
392+
ExpectedWarningMessageSubstring: "Truncated incorrect decimal(65,30) value: +123.456ABC",
394393
},
395394
{
396395
Query: "SELECT '0xBEEF' = 0;",
@@ -457,12 +456,11 @@ var ScriptTests = []ScriptTest{
457456
},
458457
{
459458
// TODO: 123.456 is converted to a DECIMAL by Builder.ConvertVal, when it should be a DOUBLE
460-
Skip: true,
461459
Query: "SELECT '123.456ABC' in (123.456);",
462460
Expected: []sql.Row{{true}},
463461
ExpectedWarningsCount: 1,
464462
ExpectedWarning: mysql.ERTruncatedWrongValue,
465-
ExpectedWarningMessageSubstring: "Truncated incorrect double value: 123A",
463+
ExpectedWarningMessageSubstring: "Truncated incorrect decimal(65,30) value: 123.456ABC",
466464
},
467465
{
468466
Query: "SELECT '123.456e2' in (12345.6);",

sql/expression/arithmetic.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -691,9 +691,12 @@ func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
691691
}
692692

693693
if !types.IsNumber(e.Child.Type()) {
694-
child, err = decimal.NewFromString(fmt.Sprintf("%v", child))
694+
child, _, err = types.InternalDecimalType.Convert(ctx, child)
695695
if err != nil {
696-
child = 0.0
696+
if !sql.ErrTruncatedIncorrect.Is(err) {
697+
child = 0.0
698+
}
699+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
697700
}
698701
}
699702

@@ -735,7 +738,7 @@ func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
735738
case uint64:
736739
return -int64(n), nil
737740
case decimal.Decimal:
738-
return n.Neg(), err
741+
return n.Neg(), nil
739742
case string:
740743
// try getting int out of string value
741744
i, iErr := strconv.ParseInt(n, 10, 64)

sql/expression/convert.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,10 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
356356
dt := createConvertedDecimalType(typeLength, typeScale, false)
357357
d, _, err := dt.Convert(ctx, value)
358358
if err != nil {
359-
return dt.Zero(), nil
359+
if !sql.ErrTruncatedIncorrect.Is(err) {
360+
return dt.Zero(), nil
361+
}
362+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
360363
}
361364
return d, nil
362365
case ConvertToFloat:
@@ -376,11 +379,10 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
376379
}
377380
d, _, err := types.Float64.Convert(ctx, value)
378381
if err != nil {
379-
if sql.ErrTruncatedIncorrect.Is(err) {
380-
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
381-
return d, nil
382+
if !sql.ErrTruncatedIncorrect.Is(err) {
383+
return types.Float64.Zero(), nil
382384
}
383-
return types.Float64.Zero(), nil
385+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
384386
}
385387
return d, nil
386388
case ConvertToJSON:

sql/types/decimal.go

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,14 @@ package types
1717
import (
1818
"context"
1919
"fmt"
20-
"math/big"
21-
"reflect"
22-
"strings"
23-
20+
"github.com/dolthub/go-mysql-server/sql"
2421
"github.com/dolthub/vitess/go/sqltypes"
2522
"github.com/dolthub/vitess/go/vt/proto/query"
2623
"github.com/shopspring/decimal"
2724
"gopkg.in/src-d/go-errors.v1"
28-
29-
"github.com/dolthub/go-mysql-server/sql"
25+
"math/big"
26+
"reflect"
27+
"strings"
3028
)
3129

3230
const (
@@ -141,13 +139,17 @@ func (t DecimalType_) Compare(s context.Context, a interface{}, b interface{}) (
141139
// Convert implements Type interface.
142140
func (t DecimalType_) Convert(c context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) {
143141
dec, err := t.ConvertToNullDecimal(v)
144-
if err != nil {
142+
if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
145143
return nil, sql.OutOfRange, err
146144
}
147145
if !dec.Valid {
148146
return nil, sql.InRange, nil
149147
}
150-
return t.BoundsCheck(dec.Decimal)
148+
res, inRange, cErr := t.BoundsCheck(dec.Decimal)
149+
if cErr != nil {
150+
return nil, sql.OutOfRange, cErr
151+
}
152+
return res, inRange, err
151153
}
152154

153155
func (t DecimalType_) ConvertNoBoundsCheck(v interface{}) (decimal.Decimal, error) {
@@ -201,25 +203,33 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal,
201203
case float64:
202204
return t.ConvertToNullDecimal(decimal.NewFromFloat(value))
203205
case string:
204-
// TODO: implement truncation here
205-
value = strings.Trim(value, sql.NumericCutSet)
206-
if len(value) == 0 {
207-
return t.ConvertToNullDecimal(decimal.NewFromInt(0))
208-
}
209206
var err error
210-
res, err = decimal.NewFromString(value)
211-
if err != nil {
212-
// The decimal library cannot handle all of the different formats
213-
bf, _, err := new(big.Float).SetPrec(217).Parse(value, 0)
214-
if err != nil {
215-
return decimal.NullDecimal{}, err
216-
}
207+
truncStr := strings.Trim(value, sql.NumericCutSet)
208+
res, err = decimal.NewFromString(truncStr)
209+
if err == nil {
210+
return t.ConvertToNullDecimal(res)
211+
}
212+
// The decimal library cannot handle all the different formats
213+
bf, _, err := new(big.Float).SetPrec(217).Parse(truncStr, 0)
214+
if err == nil {
217215
res, err = decimal.NewFromString(bf.Text('f', -1))
218-
if err != nil {
219-
return decimal.NullDecimal{}, err
216+
if err == nil {
217+
return t.ConvertToNullDecimal(res)
220218
}
221219
}
222-
return t.ConvertToNullDecimal(res)
220+
truncStr, didTrunc := sql.TruncateStringToDouble(value)
221+
if truncStr == "0" {
222+
return t.ConvertToNullDecimal(decimal.NewFromInt(0))
223+
}
224+
res, _ = decimal.NewFromString(truncStr)
225+
nullDec, cErr := t.ConvertToNullDecimal(res)
226+
if cErr != nil {
227+
return decimal.NullDecimal{}, cErr
228+
}
229+
if didTrunc {
230+
err = sql.ErrTruncatedIncorrect.New(t, value)
231+
}
232+
return nullDec, err
223233
case *big.Float:
224234
return t.ConvertToNullDecimal(value.Text('f', -1))
225235
case *big.Int:

0 commit comments

Comments
 (0)