Skip to content

Commit 4c093bb

Browse files
committed
Do not allow inserting NaN and Inf values into numeric type columns
1 parent bbd0659 commit 4c093bb

File tree

5 files changed

+149
-21
lines changed

5 files changed

+149
-21
lines changed

memory/table_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
package memory_test
1616

1717
import (
18+
"context"
1819
"fmt"
1920
"io"
21+
"math"
2022
"testing"
2123

2224
"github.com/stretchr/testify/require"
@@ -62,6 +64,73 @@ func TestTableString(t *testing.T) {
6264
require.Equal("foo", table.String())
6365
}
6466

67+
func TestTableInsert(t *testing.T) {
68+
testCases := []struct {
69+
name string
70+
colType sql.Type
71+
value interface{}
72+
err bool
73+
}{
74+
{
75+
name: "inserting NaN into float results in error",
76+
colType: types.Float64,
77+
value: math.NaN(),
78+
err: true,
79+
},
80+
{
81+
name: "inserting NaN into int results in error",
82+
colType: types.Int64,
83+
value: math.NaN(),
84+
err: true,
85+
},
86+
{
87+
name: "inserting NaN into Decimal results in error",
88+
colType: types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale),
89+
value: math.NaN(),
90+
err: true,
91+
},
92+
{
93+
name: "inserting Infinity into float results in error",
94+
colType: types.Float64,
95+
value: math.Inf(1),
96+
err: true,
97+
},
98+
{
99+
name: "inserting Infinity into int results in error",
100+
colType: types.Int64,
101+
value: math.Inf(1),
102+
err: true,
103+
},
104+
{
105+
name: "inserting Infinity into Decimal results in error",
106+
colType: types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale),
107+
value: math.Inf(1),
108+
err: true,
109+
},
110+
}
111+
112+
for _, tc := range testCases {
113+
t.Run(tc.name, func(t *testing.T) {
114+
db := memory.NewDatabase("db")
115+
116+
provider := memory.NewDBProvider(db)
117+
session := memory.NewSession(sql.NewBaseSession(), provider)
118+
ctx := sql.NewContext(context.Background(), sql.WithSession(session))
119+
120+
table := memory.NewTable(db, "test", sql.NewPrimaryKeySchema(sql.Schema{
121+
{Name: "col1", Type: tc.colType, Nullable: false},
122+
}), nil)
123+
124+
err := table.Insert(ctx, sql.NewRow(tc.value))
125+
if tc.err {
126+
require.Error(t, err)
127+
} else {
128+
require.NoError(t, err)
129+
}
130+
})
131+
}
132+
}
133+
65134
type indexKeyValue struct {
66135
key sql.Row
67136
value *memory.IndexValue

sql/rowexec/insert.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,12 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error)
161161
continue
162162
} else {
163163
// Fill in error with information
164-
if types.ErrLengthBeyondLimit.Is(cErr) {
164+
switch {
165+
case types.ErrLengthBeyondLimit.Is(cErr):
165166
cErr = types.ErrLengthBeyondLimit.New(row[idx], col.Name)
166-
} else if sql.ErrNotMatchingSRID.Is(cErr) {
167+
case sql.ErrNotMatchingSRID.Is(cErr):
167168
cErr = sql.ErrNotMatchingSRIDWithColName.New(col.Name, cErr)
168-
} else if types.ErrConvertingToEnum.Is(cErr) {
169-
cErr = types.ErrDataTruncatedForColumnAtRow.New(col.Name, i.rowNumber)
170-
} else if sql.ErrInvalidSetValue.Is(cErr) || sql.ErrConvertingToSet.Is(cErr) {
169+
case types.ErrConvertingToEnum.Is(cErr), sql.ErrInvalidSetValue.Is(cErr), sql.ErrConvertingToSet.Is(cErr):
171170
cErr = types.ErrDataTruncatedForColumnAtRow.New(col.Name, i.rowNumber)
172171
}
173172
return nil, sql.NewWrappedInsertError(origRow, cErr)

sql/rowexec/insert_test.go

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package rowexec
1616

1717
import (
18+
"math"
1819
"testing"
1920
"time"
2021

@@ -28,46 +29,88 @@ import (
2829
"github.com/dolthub/go-mysql-server/sql/types"
2930
)
3031

31-
func TestInsertIgnoreConversions(t *testing.T) {
32+
func TestInsert(t *testing.T) {
3233
testCases := []struct {
3334
name string
3435
colType sql.Type
3536
value interface{}
3637
valueType sql.Type
3738
expected interface{}
39+
warning bool
40+
ignore bool
3841
err bool
3942
}{
4043
{
41-
name: "inserting a string into a integer defaults to a 0",
44+
name: "inserting a string into a integer defaults to a 0 (with ignore)",
4245
colType: types.Int64,
4346
value: "dadasd",
4447
valueType: types.Text,
4548
expected: int64(0),
46-
err: true,
49+
warning: true,
50+
ignore: true,
4751
},
4852
{
49-
name: "string too long gets truncated",
53+
name: "string too long gets truncated (with ignore)",
5054
colType: types.MustCreateStringWithDefaults(sqltypes.VarChar, 2),
5155
value: "dadsa",
5256
valueType: types.Text,
5357
expected: "da",
54-
err: true,
58+
warning: true,
59+
ignore: true,
5560
},
5661
{
57-
name: "inserting a string into a datetime results in 0 time",
62+
name: "inserting a string into a datetime results in 0 time (with ignore)",
5863
colType: types.Datetime,
5964
value: "dadasd",
6065
valueType: types.Text,
6166
expected: time.Unix(-62167219200, 0).UTC(),
62-
err: true,
67+
warning: true,
68+
ignore: true,
6369
},
6470
{
65-
name: "inserting a negative into an unsigned int results in 0",
71+
name: "inserting a negative into an unsigned int results in 0 (with ignore)",
6672
colType: types.Uint64,
6773
value: int64(-1),
6874
expected: uint64(1<<64 - 1),
6975
valueType: types.Uint64,
70-
err: true,
76+
warning: true,
77+
ignore: true,
78+
},
79+
{
80+
name: "inserting NaN into float results in error",
81+
colType: types.Float64,
82+
value: math.NaN(),
83+
err: true,
84+
},
85+
{
86+
name: "inserting NaN into int results in error",
87+
colType: types.Int64,
88+
value: math.NaN(),
89+
err: true,
90+
},
91+
{
92+
name: "inserting NaN into Decimal results in error",
93+
colType: types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale),
94+
value: math.NaN(),
95+
err: true,
96+
},
97+
{
98+
name: "inserting Infinity into float results in error",
99+
colType: types.Float64,
100+
value: math.Inf(1),
101+
err: true,
102+
},
103+
{
104+
name: "inserting Infinity into int results in error",
105+
colType: types.Int64,
106+
value: math.Inf(1),
107+
err: true,
108+
},
109+
{
110+
name: "inserting Infinity into Decimal results in error",
111+
colType: types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale),
112+
value: math.Inf(1),
113+
err: true,
71114
},
72115
}
73116

@@ -83,21 +126,25 @@ func TestInsertIgnoreConversions(t *testing.T) {
83126

84127
insertPlan := plan.NewInsertInto(sql.UnresolvedDatabase(""), plan.NewResolvedTable(table, nil, nil), plan.NewValues([][]sql.Expression{{
85128
expression.NewLiteral(tc.value, tc.valueType),
86-
}}), false, []string{"c1"}, []sql.Expression{}, true)
129+
}}), false, []string{"c1"}, []sql.Expression{}, tc.ignore)
87130

88131
ri, err := DefaultBuilder.Build(ctx, insertPlan, nil)
89132
require.NoError(t, err)
90133

91134
row, err := ri.Next(ctx)
92-
require.NoError(t, err)
135+
if tc.err {
136+
require.Error(t, err)
137+
} else {
138+
require.NoError(t, err)
93139

94-
require.Equal(t, sql.Row{tc.expected}, row)
140+
require.Equal(t, sql.Row{tc.expected}, row)
95141

96-
var warningCnt int
97-
if tc.err {
98-
warningCnt = 1
142+
var warningCnt int
143+
if tc.warning {
144+
warningCnt = 1
145+
}
146+
require.Equal(t, ctx.WarningCount(), uint16(warningCnt))
99147
}
100-
require.Equal(t, ctx.WarningCount(), uint16(warningCnt))
101148
})
102149
}
103150
}

sql/types/decimal.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package types
1717
import (
1818
"context"
1919
"fmt"
20+
"math"
2021
"math/big"
2122
"reflect"
2223
"strings"
@@ -203,6 +204,9 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal,
203204
case float32:
204205
return t.ConvertToNullDecimal(decimal.NewFromFloat32(value))
205206
case float64:
207+
if math.IsInf(value, 0) || math.IsNaN(value) {
208+
return decimal.NullDecimal{}, ErrConvertingToDecimal.New(v)
209+
}
206210
return t.ConvertToNullDecimal(decimal.NewFromFloat(value))
207211
case string:
208212
var err error

sql/types/number.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,9 @@ func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertIn
10931093
if v < float64(math.MinInt64) {
10941094
return math.MinInt64, sql.OutOfRange, nil
10951095
}
1096+
if math.IsInf(v, 0) || math.IsNaN(v) {
1097+
return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String())
1098+
}
10961099
return int64(math.Round(v)), sql.InRange, nil
10971100
case decimal.Decimal:
10981101
if v.GreaterThan(dec_int64_max) {
@@ -1291,6 +1294,9 @@ func convertToUint64(t NumberTypeImpl_, v any, round Round) (uint64, sql.Convert
12911294
if v < 0 {
12921295
return uint64(math.MaxUint64 - uint(-v-1)), sql.OutOfRange, nil
12931296
}
1297+
if math.IsInf(v, 0) || math.IsNaN(v) {
1298+
return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String())
1299+
}
12941300
return uint64(math.Round(v)), sql.InRange, nil
12951301
case decimal.Decimal:
12961302
if v.GreaterThan(dec_uint64_max) {
@@ -1389,6 +1395,9 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) {
13891395
case float32:
13901396
return float64(v), nil
13911397
case float64:
1398+
if math.IsInf(v, 0) || math.IsNaN(v) {
1399+
return 0, sql.ErrInvalidValue.New(v, t.String())
1400+
}
13921401
return v, nil
13931402
case decimal.Decimal:
13941403
f, _ := v.Float64()

0 commit comments

Comments
 (0)