Skip to content

Commit 10711e0

Browse files
committed
add auto-increment overflow protection for all integer types
1 parent 873af2a commit 10711e0

File tree

7 files changed

+103
-30
lines changed

7 files changed

+103
-30
lines changed

enginetest/queries/script_queries.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10644,7 +10644,6 @@ where
1064410644
},
1064510645
Assertions: []ScriptTestAssertion{
1064610646
{
10647-
Skip: true,
1064810647
Query: "insert into tinyint_tbl values (999)",
1064910648
ExpectedErr: sql.ErrValueOutOfRange,
1065010649
},
@@ -10668,7 +10667,6 @@ where
1066810667
},
1066910668

1067010669
{
10671-
Skip: true,
1067210670
Query: "insert into smallint_tbl values (99999);",
1067310671
ExpectedErr: sql.ErrValueOutOfRange,
1067410672
},
@@ -10692,12 +10690,10 @@ where
1069210690
},
1069310691

1069410692
{
10695-
Skip: true,
1069610693
Query: "insert into mediumint_tbl values (99999999);",
1069710694
ExpectedErr: sql.ErrValueOutOfRange,
1069810695
},
1069910696
{
10700-
Skip: true,
1070110697
Query: "insert into mediumint_tbl values (8388607);",
1070210698
Expected: []sql.Row{
1070310699
{types.OkResult{
@@ -10707,7 +10703,6 @@ where
1070710703
},
1070810704
},
1070910705
{
10710-
Skip: true,
1071110706
Query: "show create table mediumint_tbl;",
1071210707
Expected: []sql.Row{
1071310708
{"mediumint_tbl", "CREATE TABLE `mediumint_tbl` (\n" +
@@ -10718,7 +10713,6 @@ where
1071810713
},
1071910714

1072010715
{
10721-
Skip: true,
1072210716
Query: "insert into int_tbl values (99999999999)",
1072310717
ExpectedErr: sql.ErrValueOutOfRange,
1072410718
},
@@ -10742,7 +10736,6 @@ where
1074210736
},
1074310737

1074410738
{
10745-
Skip: true,
1074610739
Query: "insert into bigint_tbl values (99999999999999999999);",
1074710740
ExpectedErr: sql.ErrValueOutOfRange,
1074810741
},
@@ -10779,7 +10772,6 @@ where
1077910772
},
1078010773
Assertions: []ScriptTestAssertion{
1078110774
{
10782-
Skip: true,
1078310775
Query: "insert into tinyint_tbl values (999)",
1078410776
ExpectedErr: sql.ErrValueOutOfRange,
1078510777
},
@@ -10803,7 +10795,6 @@ where
1080310795
},
1080410796

1080510797
{
10806-
Skip: true,
1080710798
Query: "insert into smallint_tbl values (99999);",
1080810799
ExpectedErr: sql.ErrValueOutOfRange,
1080910800
},
@@ -10827,7 +10818,6 @@ where
1082710818
},
1082810819

1082910820
{
10830-
Skip: true,
1083110821
Query: "insert into mediumint_tbl values (999999999);",
1083210822
ExpectedErr: sql.ErrValueOutOfRange,
1083310823
},
@@ -10851,7 +10841,6 @@ where
1085110841
},
1085210842

1085310843
{
10854-
Skip: true,
1085510844
Query: "insert into int_tbl values (99999999999)",
1085610845
ExpectedErr: sql.ErrValueOutOfRange,
1085710846
},
@@ -10875,7 +10864,6 @@ where
1087510864
},
1087610865

1087710866
{
10878-
Skip: true,
1087910867
Query: "insert into bigint_tbl values (999999999999999999999);",
1088010868
ExpectedErr: sql.ErrValueOutOfRange,
1088110869
},

memory/table.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,7 @@ func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) {
11501150
// GetNextAutoIncrementValue gets the next auto increment value for the memory table the increment.
11511151
func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) {
11521152
data := t.sessionTableData(ctx)
1153+
11531154

11541155
cmp, err := types.Uint64.Compare(ctx, insertVal, data.autoIncVal)
11551156
if err != nil {

memory/table_editor.go

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package memory
1616

1717
import (
1818
"fmt"
19+
"math"
1920
"reflect"
2021
"strings"
2122

@@ -186,26 +187,37 @@ func (t *tableEditor) Insert(ctx *sql.Context, row sql.Row) error {
186187
return err
187188
}
188189
if cmp > 0 {
190+
// Only update auto-increment if the value is valid for the column type
191+
if _, inRange, err := autoCol.Type.Convert(ctx, row[idx]); err != nil || inRange != sql.InRange {
192+
return nil // Don't update auto-increment for invalid values
193+
}
189194
v, _, err := types.Uint64.Convert(ctx, row[idx])
190195
if err != nil {
191196
return err
192197
}
193-
t.ea.TableData().autoIncVal = v.(uint64)
194-
nextVal := v.(uint64) + 1
195-
if _, inRange, err := autoCol.Type.Convert(ctx, nextVal); err == nil && inRange == sql.InRange {
196-
t.ea.TableData().autoIncVal = nextVal
197-
}
198+
currentVal := v.(uint64)
199+
t.ea.TableData().autoIncVal = currentVal
200+
t.incrementAutoIncrementValue(ctx, autoCol, currentVal)
198201
} else if cmp == 0 {
199-
nextVal := t.ea.TableData().autoIncVal + 1
200-
if _, inRange, err := autoCol.Type.Convert(ctx, nextVal); err == nil && inRange == sql.InRange {
201-
t.ea.TableData().autoIncVal = nextVal
202-
}
202+
currentVal := t.ea.TableData().autoIncVal
203+
t.incrementAutoIncrementValue(ctx, autoCol, currentVal)
203204
}
204205
}
205206

206207
return nil
207208
}
208209

210+
// incrementAutoIncrementValue increments the auto-increment value if it won't cause overflow or exceed type bounds
211+
func (t *tableEditor) incrementAutoIncrementValue(ctx *sql.Context, autoCol *sql.Column, currentVal uint64) {
212+
if currentVal == math.MaxUint64 {
213+
return // Can't increment further
214+
}
215+
nextVal := currentVal + 1
216+
if _, inRange, err := autoCol.Type.Convert(ctx, nextVal); err == nil && inRange == sql.InRange {
217+
t.ea.TableData().autoIncVal = nextVal
218+
}
219+
}
220+
209221
// Delete the given row from the table.
210222
func (t *tableEditor) Delete(ctx *sql.Context, row sql.Row) error {
211223
if err := checkRow(ctx, t.editedTable.Schema(), row); err != nil {

sql/rowexec/insert.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error)
120120
ctxWithValues = context.WithValue(ctxWithValues, types.RowNumberKey, i.rowNumber)
121121
ctxWithColumnInfo := ctx.WithContext(ctxWithValues)
122122
converted, inRange, cErr := col.Type.Convert(ctxWithColumnInfo, row[idx])
123-
if cErr == nil && !inRange {
123+
if cErr == nil && inRange == sql.OutOfRange {
124124
cErr = sql.ErrValueOutOfRange.New(row[idx], col.Type)
125125
}
126126
if cErr != nil {

sql/rowexec/insert_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func TestInsertIgnoreConversions(t *testing.T) {
6565
name: "inserting a negative into an unsigned int results in 0",
6666
colType: types.Uint64,
6767
value: int64(-1),
68-
expected: uint64(1<<64 - 1),
68+
expected: uint64(0),
6969
valueType: types.Uint64,
7070
err: true,
7171
},

sql/rowexec/update_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func TestUpdateIgnoreConversions(t *testing.T) {
6161
name: "inserting a negative into an unsigned int results in 0",
6262
colType: types.Uint64,
6363
value: int64(-1),
64-
expected: uint64(1<<64 - 1),
64+
expected: uint64(0),
6565
valueType: types.Uint64,
6666
},
6767
}

sql/types/number.go

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -260,41 +260,80 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
260260
}
261261
}
262262

263+
// Check if we're in strict mode for error handling
264+
strictMode := false
265+
if sqlCtx, ok := ctx.(*sql.Context); ok {
266+
strictMode = sql.ValidateStrictMode(sqlCtx)
267+
}
268+
269+
263270
switch t.baseType {
264271
case sqltypes.Int8:
265272
num, _, err := convertToInt64(t, v)
266273
if err != nil {
267274
return nil, sql.OutOfRange, err
268275
}
269276
if num > math.MaxInt8 {
277+
if strictMode {
278+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
279+
}
270280
return int8(math.MaxInt8), sql.OutOfRange, nil
271281
} else if num < math.MinInt8 {
282+
if strictMode {
283+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
284+
}
272285
return int8(math.MinInt8), sql.OutOfRange, nil
273286
}
274287
return int8(num), sql.InRange, nil
275288
case sqltypes.Uint8:
276-
return convertToUint8(t, v)
289+
val, inRange, err := convertToUint8(t, v)
290+
if err != nil {
291+
return nil, sql.OutOfRange, err
292+
}
293+
if inRange == sql.OutOfRange && strictMode {
294+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
295+
}
296+
return val, inRange, nil
277297
case sqltypes.Int16:
278298
num, _, err := convertToInt64(t, v)
279299
if err != nil {
280300
return nil, sql.OutOfRange, err
281301
}
282302
if num > math.MaxInt16 {
303+
if strictMode {
304+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
305+
}
283306
return int16(math.MaxInt16), sql.OutOfRange, nil
284307
} else if num < math.MinInt16 {
308+
if strictMode {
309+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
310+
}
285311
return int16(math.MinInt16), sql.OutOfRange, nil
286312
}
287313
return int16(num), sql.InRange, nil
288314
case sqltypes.Uint16:
289-
return convertToUint16(t, v)
315+
val, inRange, err := convertToUint16(t, v)
316+
if err != nil {
317+
return nil, sql.OutOfRange, err
318+
}
319+
if inRange == sql.OutOfRange && strictMode {
320+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
321+
}
322+
return val, inRange, nil
290323
case sqltypes.Int24:
291324
num, _, err := convertToInt64(t, v)
292325
if err != nil {
293326
return nil, sql.OutOfRange, err
294327
}
295328
if num > (1<<23 - 1) {
329+
if strictMode {
330+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
331+
}
296332
return int32(1<<23 - 1), sql.OutOfRange, nil
297333
} else if num < (-1 << 23) {
334+
if strictMode {
335+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
336+
}
298337
return int32(-1 << 23), sql.OutOfRange, nil
299338
}
300339
return int32(num), sql.InRange, nil
@@ -304,8 +343,14 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
304343
return nil, sql.OutOfRange, err
305344
}
306345
if num >= (1 << 24) {
346+
if strictMode {
347+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
348+
}
307349
return uint32(1<<24 - 1), sql.OutOfRange, nil
308350
} else if num < 0 {
351+
if strictMode {
352+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
353+
}
309354
return uint32(1<<24 - int32(-num)), sql.OutOfRange, nil
310355
}
311356
return uint32(num), sql.InRange, nil
@@ -315,17 +360,44 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
315360
return nil, sql.OutOfRange, err
316361
}
317362
if num > math.MaxInt32 {
363+
if strictMode {
364+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
365+
}
318366
return int32(math.MaxInt32), sql.OutOfRange, nil
319367
} else if num < math.MinInt32 {
368+
if strictMode {
369+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
370+
}
320371
return int32(math.MinInt32), sql.OutOfRange, nil
321372
}
322373
return int32(num), sql.InRange, nil
323374
case sqltypes.Uint32:
324-
return convertToUint32(t, v)
375+
val, inRange, err := convertToUint32(t, v)
376+
if err != nil {
377+
return nil, sql.OutOfRange, err
378+
}
379+
if inRange == sql.OutOfRange && strictMode {
380+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
381+
}
382+
return val, inRange, nil
325383
case sqltypes.Int64:
326-
return convertToInt64(t, v)
384+
val, inRange, err := convertToInt64(t, v)
385+
if err != nil {
386+
return nil, sql.OutOfRange, err
387+
}
388+
if inRange == sql.OutOfRange && strictMode {
389+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
390+
}
391+
return val, inRange, nil
327392
case sqltypes.Uint64:
328-
return convertToUint64(t, v)
393+
val, inRange, err := convertToUint64(t, v)
394+
if err != nil {
395+
return nil, sql.OutOfRange, err
396+
}
397+
if inRange == sql.OutOfRange && strictMode {
398+
return nil, sql.OutOfRange, sql.ErrValueOutOfRange.New(v, t.String())
399+
}
400+
return val, inRange, nil
329401
case sqltypes.Float32:
330402
num, err := convertToFloat64(t, v)
331403
if err != nil {
@@ -1163,7 +1235,7 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11631235
return uint64(math.Round(v)), sql.InRange, nil
11641236
case decimal.Decimal:
11651237
if v.GreaterThan(dec_uint64_max) {
1166-
return math.MaxUint64, sql.InRange, nil
1238+
return math.MaxUint64, sql.OutOfRange, nil
11671239
} else if v.LessThan(dec_zero) {
11681240
ret, _ := dec_uint64_max.Sub(v).Float64()
11691241
return uint64(math.Round(ret)), sql.OutOfRange, nil

0 commit comments

Comments
 (0)