Skip to content

Commit 8b0896d

Browse files
authored
Merge pull request #3076 from dolthub/elianddb/9425-enum-zero-strict-mode
dolthub/dolt#9425 - Fix enum zero validation in strict mode
2 parents 62fac20 + 24ee22c commit 8b0896d

File tree

4 files changed

+127
-11
lines changed

4 files changed

+127
-11
lines changed

enginetest/queries/script_queries.go

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8465,23 +8465,20 @@ where
84658465
ExpectedErr: sql.ErrIncompatibleDefaultType,
84668466
},
84678467
{
8468-
Skip: true,
84698468
Query: "create table bad (e enum('a') default 0);",
8470-
ExpectedErr: sql.ErrIncompatibleDefaultType,
8469+
ExpectedErr: sql.ErrInvalidColumnDefaultValue,
84718470
},
84728471
{
84738472
Query: "create table bad (e enum('a') default '');",
84748473
ExpectedErr: sql.ErrIncompatibleDefaultType,
84758474
},
84768475
{
8477-
Skip: true,
84788476
Query: "create table bad (e enum('a') default '1');",
8479-
ExpectedErr: sql.ErrIncompatibleDefaultType,
8477+
ExpectedErr: sql.ErrInvalidColumnDefaultValue,
84808478
},
84818479
{
8482-
Skip: true,
84838480
Query: "create table bad (e enum('a') default 1);",
8484-
ExpectedErr: sql.ErrIncompatibleDefaultType,
8481+
ExpectedErr: sql.ErrInvalidColumnDefaultValue,
84858482
},
84868483

84878484
{
@@ -8688,22 +8685,79 @@ where
86888685
},
86898686
{
86908687
// This is with STRICT_TRANS_TABLES or STRICT_ALL_TABLES in sql_mode
8691-
Skip: true,
86928688
Name: "enums with zero",
86938689
Dialect: "mysql",
86948690
SetUpScript: []string{
8691+
"SET sql_mode = 'STRICT_TRANS_TABLES';",
86958692
"create table t (e enum('a', 'b', 'c'));",
86968693
},
86978694
Assertions: []ScriptTestAssertion{
86988695
{
8699-
Query: "insert into t values (0);",
8700-
// TODO should be truncated error, but this is the error we throw for empty string
8701-
ExpectedErrStr: "is not valid for this Enum",
8696+
Query: "insert into t values (0);",
8697+
ExpectedErrStr: "Data truncated for column 'e' at row 1",
8698+
},
8699+
{
8700+
Query: "insert into t values ('a'), (0), ('b');",
8701+
ExpectedErrStr: "Data truncated for column 'e' at row 2",
87028702
},
87038703
{
87048704
Query: "create table tt (e enum('a', 'b', 'c') default 0)",
87058705
ExpectedErr: sql.ErrInvalidColumnDefaultValue,
87068706
},
8707+
{
8708+
Query: "create table et (e enum('a', 'b', '', 'c'));",
8709+
Expected: []sql.Row{
8710+
{types.NewOkResult(0)},
8711+
},
8712+
},
8713+
{
8714+
Query: "insert into et values (0);",
8715+
ExpectedErrStr: "Data truncated for column 'e' at row 1",
8716+
},
8717+
},
8718+
},
8719+
{
8720+
Name: "enums with zero strict all tables",
8721+
Dialect: "mysql",
8722+
SetUpScript: []string{
8723+
"SET sql_mode = 'STRICT_ALL_TABLES';",
8724+
"create table t (e enum('a', 'b', 'c'));",
8725+
},
8726+
Assertions: []ScriptTestAssertion{
8727+
{
8728+
Query: "insert into t values (0);",
8729+
ExpectedErrStr: "Data truncated for column 'e' at row 1",
8730+
},
8731+
{
8732+
Query: "insert into t values ('a'), (0), ('b');",
8733+
ExpectedErrStr: "Data truncated for column 'e' at row 2",
8734+
},
8735+
{
8736+
Query: "create table tt (e enum('a', 'b', 'c') default 0)",
8737+
ExpectedErr: sql.ErrInvalidColumnDefaultValue,
8738+
},
8739+
},
8740+
},
8741+
{
8742+
Name: "enums with zero non-strict mode",
8743+
Dialect: "mysql",
8744+
SetUpScript: []string{
8745+
"SET sql_mode = '';",
8746+
"create table t (e enum('a', 'b', 'c'));",
8747+
},
8748+
Assertions: []ScriptTestAssertion{
8749+
{
8750+
Query: "insert into t values (0);",
8751+
Expected: []sql.Row{
8752+
{types.NewOkResult(1)},
8753+
},
8754+
},
8755+
{
8756+
Query: "select * from t;",
8757+
Expected: []sql.Row{
8758+
{""},
8759+
},
8760+
},
87078761
},
87088762
},
87098763
{

sql/analyzer/resolve_column_defaults.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,46 @@ func validateColumnDefault(ctx *sql.Context, col *sql.Column, colDefault *sql.Co
278278
return err
279279
}
280280

281+
if enumType, isEnum := col.Type.(sql.EnumType); isEnum && colDefault.IsLiteral() {
282+
if err = validateEnumLiteralDefault(enumType, colDefault, col.Name, ctx); err != nil {
283+
return err
284+
}
285+
}
286+
281287
return nil
282288
}
283289

290+
// validateEnumLiteralDefault validates enum literal defaults more strictly than runtime conversions
291+
// MySQL doesn't allow numeric index references for literal enum defaults
292+
func validateEnumLiteralDefault(enumType sql.EnumType, colDefault *sql.ColumnDefaultValue, columnName string, ctx *sql.Context) error {
293+
val, err := colDefault.Expr.Eval(ctx, nil)
294+
if err != nil {
295+
return err
296+
}
297+
298+
switch v := val.(type) {
299+
case string:
300+
// For string values, check if it's a direct enum value match
301+
enumValues := enumType.Values()
302+
for _, enumVal := range enumValues {
303+
if enumVal == v {
304+
return nil // Valid enum value
305+
}
306+
}
307+
// String doesn't match any enum value, return appropriate error
308+
if v == "" {
309+
return sql.ErrIncompatibleDefaultType.New()
310+
}
311+
return sql.ErrInvalidColumnDefaultValue.New(columnName)
312+
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
313+
// MySQL doesn't allow numeric enum indices as literal defaults
314+
return sql.ErrInvalidColumnDefaultValue.New(columnName)
315+
default:
316+
// Other types not supported for enum defaults
317+
return sql.ErrIncompatibleDefaultType.New()
318+
}
319+
}
320+
284321
func stripTableNamesFromDefault(e *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) {
285322
newDefault, ok := e.Unwrap().(*sql.ColumnDefaultValue)
286323
if !ok {

sql/planbuilder/ddl.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,15 @@ func (b *Builder) tableSpecToSchema(inScope, outScope *scope, db sql.Database, t
14181418
}
14191419

14201420
for i, def := range defaults {
1421+
// Early validation for enum default 0 to catch it before conversion
1422+
if def != nil && types.IsEnum(schema[i].Type) {
1423+
if lit, ok := def.(*ast.SQLVal); ok {
1424+
if lit.Type == ast.IntVal && string(lit.Val) == "0" {
1425+
b.handleErr(sql.ErrInvalidColumnDefaultValue.New(schema[i].Name))
1426+
}
1427+
}
1428+
}
1429+
14211430
schema[i].Default = b.convertDefaultExpression(outScope, def, schema[i].Type, schema[i].Nullable)
14221431
err := validateDefaultExprs(schema[i])
14231432
if err != nil {

sql/types/enum.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ func (t EnumType) Convert(ctx context.Context, v interface{}) (interface{}, sql.
165165

166166
switch value := v.(type) {
167167
case int:
168+
// MySQL rejects 0 values in strict mode regardless of enum definition
169+
if value == 0 && t.validateScrictMode(ctx) {
170+
return nil, sql.OutOfRange, ErrConvertingToEnum.New(value)
171+
}
168172
if _, ok := t.At(value); ok {
169173
return uint16(value), sql.InRange, nil
170174
}
@@ -177,7 +181,10 @@ func (t EnumType) Convert(ctx context.Context, v interface{}) (interface{}, sql.
177181
case int16:
178182
return t.Convert(ctx, int(value))
179183
case uint16:
180-
return t.Convert(ctx, int(value))
184+
// uint16 values are stored enum indices - allow them without strict mode validation
185+
if _, ok := t.At(int(value)); ok {
186+
return value, sql.InRange, nil
187+
}
181188
case int32:
182189
return t.Convert(ctx, int(value))
183190
case uint32:
@@ -208,6 +215,15 @@ func (t EnumType) Convert(ctx context.Context, v interface{}) (interface{}, sql.
208215
return nil, sql.InRange, ErrConvertingToEnum.New(v)
209216
}
210217

218+
// validateScrictMode checks if STRICT_TRANS_TABLES or STRICT_ALL_TABLES is enabled
219+
func (t EnumType) validateScrictMode(ctx context.Context) bool {
220+
if sqlCtx, ok := ctx.(*sql.Context); ok {
221+
sqlMode := sql.LoadSqlMode(sqlCtx)
222+
return sqlMode.ModeEnabled("STRICT_TRANS_TABLES") || sqlMode.ModeEnabled("STRICT_ALL_TABLES")
223+
}
224+
return false
225+
}
226+
211227
// Equals implements the Type interface.
212228
func (t EnumType) Equals(otherType sql.Type) bool {
213229
if ot, ok := otherType.(EnumType); ok && t.collation.Equals(ot.collation) && len(t.idxToVal) == len(ot.idxToVal) {

0 commit comments

Comments
 (0)