Skip to content

Commit 6919952

Browse files
committed
amend overflow detection logic
1 parent 35d5291 commit 6919952

File tree

5 files changed

+102
-43
lines changed

5 files changed

+102
-43
lines changed

enginetest/queries/script_queries.go

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10481,7 +10481,6 @@ where
1048110481
},
1048210482
Assertions: []ScriptTestAssertion{
1048310483
{
10484-
Skip: true,
1048510484
Query: "insert into tinyint_tbl values (999)",
1048610485
ExpectedErr: sql.ErrValueOutOfRange,
1048710486
},
@@ -10505,7 +10504,6 @@ where
1050510504
},
1050610505

1050710506
{
10508-
Skip: true,
1050910507
Query: "insert into smallint_tbl values (99999);",
1051010508
ExpectedErr: sql.ErrValueOutOfRange,
1051110509
},
@@ -10529,12 +10527,10 @@ where
1052910527
},
1053010528

1053110529
{
10532-
Skip: true,
1053310530
Query: "insert into mediumint_tbl values (99999999);",
1053410531
ExpectedErr: sql.ErrValueOutOfRange,
1053510532
},
1053610533
{
10537-
Skip: true,
1053810534
Query: "insert into mediumint_tbl values (8388607);",
1053910535
Expected: []sql.Row{
1054010536
{types.OkResult{
@@ -10544,7 +10540,6 @@ where
1054410540
},
1054510541
},
1054610542
{
10547-
Skip: true,
1054810543
Query: "show create table mediumint_tbl;",
1054910544
Expected: []sql.Row{
1055010545
{"mediumint_tbl", "CREATE TABLE `mediumint_tbl` (\n" +
@@ -10555,7 +10550,6 @@ where
1055510550
},
1055610551

1055710552
{
10558-
Skip: true,
1055910553
Query: "insert into int_tbl values (99999999999)",
1056010554
ExpectedErr: sql.ErrValueOutOfRange,
1056110555
},
@@ -10579,7 +10573,6 @@ where
1057910573
},
1058010574

1058110575
{
10582-
Skip: true,
1058310576
Query: "insert into bigint_tbl values (99999999999999999999);",
1058410577
ExpectedErr: sql.ErrValueOutOfRange,
1058510578
},
@@ -10616,12 +10609,10 @@ where
1061610609
},
1061710610
Assertions: []ScriptTestAssertion{
1061810611
{
10619-
Skip: true,
1062010612
Query: "insert into tinyint_tbl values (999)",
1062110613
ExpectedErr: sql.ErrValueOutOfRange,
1062210614
},
1062310615
{
10624-
Skip: true,
1062510616
Query: "insert into tinyint_tbl values (255)",
1062610617
Expected: []sql.Row{
1062710618
{types.OkResult{
@@ -10634,19 +10625,17 @@ where
1063410625
Query: "show create table tinyint_tbl;",
1063510626
Expected: []sql.Row{
1063610627
{"tinyint_tbl", "CREATE TABLE `tinyint_tbl` (\n" +
10637-
" `i` tinyint NOT NULL AUTO_INCREMENT,\n" +
10628+
" `i` tinyint unsigned NOT NULL AUTO_INCREMENT,\n" +
1063810629
" PRIMARY KEY (`i`)\n" +
1063910630
") ENGINE=InnoDB AUTO_INCREMENT=255 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
1064010631
},
1064110632
},
1064210633

1064310634
{
10644-
Skip: true,
1064510635
Query: "insert into smallint_tbl values (99999);",
1064610636
ExpectedErr: sql.ErrValueOutOfRange,
1064710637
},
1064810638
{
10649-
Skip: true,
1065010639
Query: "insert into smallint_tbl values (65535);",
1065110640
Expected: []sql.Row{
1065210641
{types.OkResult{
@@ -10659,19 +10648,17 @@ where
1065910648
Query: "show create table smallint_tbl;",
1066010649
Expected: []sql.Row{
1066110650
{"smallint_tbl", "CREATE TABLE `smallint_tbl` (\n" +
10662-
" `i` smallint NOT NULL AUTO_INCREMENT,\n" +
10651+
" `i` smallint unsigned NOT NULL AUTO_INCREMENT,\n" +
1066310652
" PRIMARY KEY (`i`)\n" +
1066410653
") ENGINE=InnoDB AUTO_INCREMENT=65535 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
1066510654
},
1066610655
},
1066710656

1066810657
{
10669-
Skip: true,
1067010658
Query: "insert into mediumint_tbl values (999999999);",
1067110659
ExpectedErr: sql.ErrValueOutOfRange,
1067210660
},
1067310661
{
10674-
Skip: true,
1067510662
Query: "insert into mediumint_tbl values (16777215);",
1067610663
Expected: []sql.Row{
1067710664
{types.OkResult{
@@ -10684,19 +10671,17 @@ where
1068410671
Query: "show create table mediumint_tbl;",
1068510672
Expected: []sql.Row{
1068610673
{"mediumint_tbl", "CREATE TABLE `mediumint_tbl` (\n" +
10687-
" `i` mediumint NOT NULL AUTO_INCREMENT,\n" +
10674+
" `i` mediumint unsigned NOT NULL AUTO_INCREMENT,\n" +
1068810675
" PRIMARY KEY (`i`)\n" +
1068910676
") ENGINE=InnoDB AUTO_INCREMENT=16777215 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
1069010677
},
1069110678
},
1069210679

1069310680
{
10694-
Skip: true,
1069510681
Query: "insert into int_tbl values (99999999999)",
1069610682
ExpectedErr: sql.ErrValueOutOfRange,
1069710683
},
1069810684
{
10699-
Skip: true,
1070010685
Query: "insert into int_tbl values (4294967295)",
1070110686
Expected: []sql.Row{
1070210687
{types.OkResult{
@@ -10709,19 +10694,17 @@ where
1070910694
Query: "show create table int_tbl;",
1071010695
Expected: []sql.Row{
1071110696
{"int_tbl", "CREATE TABLE `int_tbl` (\n" +
10712-
" `i` int NOT NULL AUTO_INCREMENT,\n" +
10697+
" `i` int unsigned NOT NULL AUTO_INCREMENT,\n" +
1071310698
" PRIMARY KEY (`i`)\n" +
1071410699
") ENGINE=InnoDB AUTO_INCREMENT=4294967295 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
1071510700
},
1071610701
},
1071710702

1071810703
{
10719-
Skip: true,
1072010704
Query: "insert into bigint_tbl values (999999999999999999999);",
1072110705
ExpectedErr: sql.ErrValueOutOfRange,
1072210706
},
1072310707
{
10724-
Skip: true,
1072510708
Query: "insert into bigint_tbl values (18446744073709551615);",
1072610709
Expected: []sql.Row{
1072710710
{types.OkResult{
@@ -10734,7 +10717,7 @@ where
1073410717
Query: "show create table bigint_tbl;",
1073510718
Expected: []sql.Row{
1073610719
{"bigint_tbl", "CREATE TABLE `bigint_tbl` (\n" +
10737-
" `i` bigint NOT NULL AUTO_INCREMENT,\n" +
10720+
" `i` bigint unsigned NOT NULL AUTO_INCREMENT,\n" +
1073810721
" PRIMARY KEY (`i`)\n" +
1073910722
") ENGINE=InnoDB AUTO_INCREMENT=18446744073709551615 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
1074010723
},

memory/table.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"encoding/gob"
2020
"fmt"
2121
"io"
22+
"math"
2223
"sort"
2324
"strconv"
2425
"strings"
@@ -1144,9 +1145,36 @@ func (t *Table) Insert(ctx *sql.Context, row sql.Row) error {
11441145
func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) {
11451146
data := t.sessionTableData(ctx)
11461147

1148+
// Find the auto increment column to validate the current value
1149+
autoCol := t.getAutoIncrementColumn()
1150+
if autoCol == nil {
1151+
return data.autoIncVal, nil
1152+
}
1153+
1154+
// If the current auto increment value is out of range for the column type,
1155+
// return the maximum valid value instead
1156+
if _, inRange, err := autoCol.Type.Convert(ctx, data.autoIncVal); err == nil && inRange == sql.OutOfRange {
1157+
// When auto increment overflowed to 0, show the previous valid value
1158+
if data.autoIncVal == 0 {
1159+
return math.MaxUint64, nil
1160+
}
1161+
return data.autoIncVal - 1, nil
1162+
}
1163+
11471164
return data.autoIncVal, nil
11481165
}
11491166

1167+
// getAutoIncrementColumn returns the auto increment column for this table, or nil if none exists.
1168+
// Only one auto increment column is allowed per table.
1169+
func (t *Table) getAutoIncrementColumn() *sql.Column {
1170+
for _, col := range t.Schema() {
1171+
if col.AutoIncrement {
1172+
return col
1173+
}
1174+
}
1175+
return nil
1176+
}
1177+
11501178
// GetNextAutoIncrementValue gets the next auto increment value for the memory table the increment.
11511179
func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) {
11521180
data := t.sessionTableData(ctx)

memory/table_editor.go

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package memory
1616

1717
import (
1818
"fmt"
19+
"math"
1920
"reflect"
2021
"strings"
2122

@@ -188,20 +189,15 @@ func (t *tableEditor) Insert(ctx *sql.Context, row sql.Row) error {
188189
return err
189190
}
190191
if cmp > 0 {
191-
v, _, err := types.Uint64.Convert(ctx, row[idx])
192+
insertedVal, _, err := types.Uint64.Convert(ctx, row[idx])
192193
if err != nil {
193194
return err
194195
}
195-
t.ea.TableData().autoIncVal = v.(uint64)
196-
nextVal := v.(uint64) + 1
197-
if _, inRange, err := autoCol.Type.Convert(ctx, nextVal); err == nil && inRange == sql.InRange {
198-
t.ea.TableData().autoIncVal = nextVal
199-
}
196+
currentAutoIncVal := insertedVal.(uint64)
197+
t.ea.TableData().autoIncVal = t.updateAutoIncrementValue(ctx, autoCol, currentAutoIncVal)
200198
} else if cmp == 0 {
201-
nextVal := t.ea.TableData().autoIncVal + 1
202-
if _, inRange, err := autoCol.Type.Convert(ctx, nextVal); err == nil && inRange == sql.InRange {
203-
t.ea.TableData().autoIncVal = nextVal
204-
}
199+
currentAutoIncVal := t.ea.TableData().autoIncVal
200+
t.ea.TableData().autoIncVal = t.updateAutoIncrementValue(ctx, autoCol, currentAutoIncVal)
205201
}
206202
}
207203

@@ -894,3 +890,22 @@ func verifyRowTypes(row sql.Row, schema sql.Schema) error {
894890
}
895891
return nil
896892
}
893+
894+
// updateAutoIncrementValue safely increments the auto_increment value, handling overflow
895+
// by ensuring it doesn't exceed the column type's maximum value or wrap around.
896+
// It returns the next valid auto_increment value for the given column type.
897+
func (t *tableEditor) updateAutoIncrementValue(ctx *sql.Context, autoCol *sql.Column, currentVal uint64) uint64 {
898+
// Check for arithmetic overflow before adding 1
899+
if currentVal == math.MaxUint64 {
900+
// At maximum uint64 value, can't increment further
901+
return currentVal
902+
}
903+
904+
nextVal := currentVal + 1
905+
if _, inRange, err := autoCol.Type.Convert(ctx, nextVal); err == nil && inRange == sql.InRange {
906+
return nextVal
907+
}
908+
909+
// If next value would be out of range for the column type, stay at current value
910+
return currentVal
911+
}

sql/expression/auto_increment.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ func (i *AutoIncrement) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
139139
given = seq
140140
}
141141

142-
ret, _, err := i.Type().Convert(ctx, given)
142+
ret, inRange, err := i.Type().Convert(ctx, given)
143+
if err == nil && !inRange {
144+
err = sql.ErrValueOutOfRange.New(given, i.Type())
145+
}
143146
if err != nil {
144147
return nil, err
145148
}

sql/types/number.go

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,7 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11631163
return uint64(math.Round(v)), sql.InRange, nil
11641164
case decimal.Decimal:
11651165
if v.GreaterThan(dec_uint64_max) {
1166-
return math.MaxUint64, sql.InRange, nil
1166+
return math.MaxUint64, sql.OutOfRange, nil
11671167
} else if v.LessThan(dec_zero) {
11681168
ret, _ := dec_uint64_max.Sub(v).Float64()
11691169
return uint64(math.Round(ret)), sql.OutOfRange, nil
@@ -1181,6 +1181,9 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan
11811181
v = strings.Trim(v, intCutSet)
11821182
if i, err := strconv.ParseUint(v, 10, 64); err == nil {
11831183
return i, sql.InRange, nil
1184+
} else if err == strconv.ErrRange {
1185+
// Number is too large for uint64, return max value and OutOfRange
1186+
return math.MaxUint64, sql.OutOfRange, nil
11841187
}
11851188
if f, err := strconv.ParseFloat(v, 64); err == nil {
11861189
if val, inRange, err := convertToUint64(t, f); err == nil && inRange {
@@ -1238,15 +1241,15 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan
12381241
}
12391242
return uint32(v), sql.InRange, nil
12401243
case uint:
1241-
return uint32(v), sql.InRange, nil
1244+
return convertUintToUint32(uint64(v))
12421245
case uint8:
12431246
return uint32(v), sql.InRange, nil
12441247
case uint16:
12451248
return uint32(v), sql.InRange, nil
12461249
case uint32:
12471250
return v, sql.InRange, nil
12481251
case uint64:
1249-
return uint32(v), sql.InRange, nil
1252+
return convertUintToUint32(v)
12501253
case float64:
12511254
if float32(v) > float32(math.MaxInt32) {
12521255
return math.MaxUint32, sql.OutOfRange, nil
@@ -1334,13 +1337,13 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan
13341337
}
13351338
return uint16(v), sql.InRange, nil
13361339
case uint:
1337-
return uint16(v), sql.InRange, nil
1340+
return convertUintToUint16(uint64(v))
13381341
case uint8:
13391342
return uint16(v), sql.InRange, nil
13401343
case uint64:
1341-
return uint16(v), sql.InRange, nil
1344+
return convertUintToUint16(v)
13421345
case uint32:
1343-
return uint16(v), sql.InRange, nil
1346+
return convertUintToUint16(uint64(v))
13441347
case uint16:
13451348
return v, sql.InRange, nil
13461349
case float32:
@@ -1434,13 +1437,13 @@ func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange
14341437
}
14351438
return uint8(v), sql.InRange, nil
14361439
case uint:
1437-
return uint8(v), sql.InRange, nil
1440+
return convertUintToUint8(uint64(v))
14381441
case uint16:
1439-
return uint8(v), sql.InRange, nil
1442+
return convertUintToUint8(uint64(v))
14401443
case uint64:
1441-
return uint8(v), sql.InRange, nil
1444+
return convertUintToUint8(v)
14421445
case uint32:
1443-
return uint8(v), sql.InRange, nil
1446+
return convertUintToUint8(uint64(v))
14441447
case uint8:
14451448
return v, sql.InRange, nil
14461449
case float32:
@@ -1719,3 +1722,30 @@ func CoalesceInt(val interface{}) (int, bool) {
17191722
return 0, false
17201723
}
17211724
}
1725+
1726+
// convertUintToUint8 converts a uint64 value to uint8 with overflow checking.
1727+
// Returns the converted value, range status, and any error.
1728+
func convertUintToUint8(v uint64) (uint8, sql.ConvertInRange, error) {
1729+
if v > math.MaxUint8 {
1730+
return uint8(math.MaxUint8), sql.OutOfRange, nil
1731+
}
1732+
return uint8(v), sql.InRange, nil
1733+
}
1734+
1735+
// convertUintToUint16 converts a uint64 value to uint16 with overflow checking.
1736+
// Returns the converted value, range status, and any error.
1737+
func convertUintToUint16(v uint64) (uint16, sql.ConvertInRange, error) {
1738+
if v > math.MaxUint16 {
1739+
return uint16(math.MaxUint16), sql.OutOfRange, nil
1740+
}
1741+
return uint16(v), sql.InRange, nil
1742+
}
1743+
1744+
// convertUintToUint32 converts a uint64 value to uint32 with overflow checking.
1745+
// Returns the converted value, range status, and any error.
1746+
func convertUintToUint32(v uint64) (uint32, sql.ConvertInRange, error) {
1747+
if v > math.MaxUint32 {
1748+
return uint32(math.MaxUint32), sql.OutOfRange, nil
1749+
}
1750+
return uint32(v), sql.InRange, nil
1751+
}

0 commit comments

Comments
 (0)