diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 51f3592a84..c9965a80c6 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7448,6 +7448,57 @@ where }, }, }, + { + Name: "preserve enums through alter statements", + SetUpScript: []string{ + "create table t (i int primary key, e enum('a', 'b', 'c'));", + "insert into t values (1, 'a');", + "insert into t values (2, 'b');", + "insert into t values (3, 'c');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {1, "a", float64(1)}, + {2, "b", float64(2)}, + {3, "c", float64(3)}, + }, + }, + { + Query: "alter table t modify column e enum('c', 'a', 'b');", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {1, "a", float64(2)}, + {2, "b", float64(3)}, + {3, "c", float64(1)}, + }, + }, + { + Query: "alter table t modify column e enum('asdf', 'a', 'b', 'c');", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {1, "a", float64(2)}, + {2, "b", float64(3)}, + {3, "c", float64(4)}, + }, + }, + { + Query: "alter table t modify column e enum('abc');", + ExpectedErrStr: "value 2 is not valid for this Enum", + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index a23dd9b168..90307a9320 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -471,13 +471,17 @@ func (i *modifyColumnIter) Close(context *sql.Context) error { // rewriteTable rewrites the table given if required or requested, and returns whether it was rewritten func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) (bool, error) { - oldColIdx := i.m.TargetSchema().IndexOfColName(i.m.Column()) + targetSchema := i.m.TargetSchema() + oldColName := i.m.Column() + oldColIdx := targetSchema.IndexOfColName(oldColName) if oldColIdx == -1 { // Should be impossible, checked in analyzer - return false, sql.ErrTableColumnNotFound.New(rwt.Name(), i.m.Column()) + return false, sql.ErrTableColumnNotFound.New(rwt.Name(), oldColName) } - newSch, projections, err := modifyColumnInSchema(i.m.TargetSchema(), i.m.Column(), i.m.NewColumn(), i.m.Order()) + oldCol := i.m.TargetSchema()[oldColIdx] + newCol := i.m.NewColumn() + newSch, projections, err := modifyColumnInSchema(targetSchema, oldColName, newCol, i.m.Order()) if err != nil { return false, err } @@ -494,9 +498,9 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl } var renames []sql.ColumnRename - if i.m.Column() != i.m.NewColumn().Name { + if oldColName != newCol.Name { renames = []sql.ColumnRename{{ - Before: i.m.Column(), After: i.m.NewColumn().Name, + Before: oldColName, After: newCol.Name, }} } @@ -504,17 +508,23 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl newPkSchema := sql.SchemaToPrimaryKeySchema(rwt, newSch, renames...) rewriteRequired := false - if i.m.TargetSchema()[oldColIdx].Nullable && !i.m.NewColumn().Nullable { + if oldCol.Nullable && !newCol.Nullable { + rewriteRequired = true + } + + oldEnum, isOldEnum := oldCol.Type.(sql.EnumType) + newEnum, isNewEnum := newCol.Type.(sql.EnumType) + if isOldEnum && isNewEnum && !oldEnum.Equals(newEnum) { rewriteRequired = true } // TODO: codify rewrite requirements - rewriteRequested := rwt.ShouldRewriteTable(ctx, oldPkSchema, newPkSchema, i.m.TargetSchema()[oldColIdx], i.m.NewColumn()) + rewriteRequested := rwt.ShouldRewriteTable(ctx, oldPkSchema, newPkSchema, oldCol, newCol) if !rewriteRequired && !rewriteRequested { return false, nil } - inserter, err := rwt.RewriteInserter(ctx, oldPkSchema, newPkSchema, i.m.TargetSchema()[oldColIdx], i.m.NewColumn(), nil) + inserter, err := rwt.RewriteInserter(ctx, oldPkSchema, newPkSchema, oldCol, newCol, nil) if err != nil { return false, err } @@ -524,8 +534,8 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl return false, err } + newColIdx := newSch.IndexOf(newCol.Name, newCol.Source) rowIter := sql.NewTableRowIter(ctx, rwt, partitions) - for { r, err := rowIter.Next(ctx) if err == io.EOF { @@ -543,6 +553,17 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl return false, err } + // remap old enum values to new enum values + if isOldEnum && isNewEnum { + oldIdx := int(newRow[newColIdx].(uint16)) + oldStr, _ := oldEnum.At(oldIdx) + newIdx := newEnum.IndexOf(oldStr) + if newIdx == -1 { + return false, fmt.Errorf("data truncated for column %s", newCol.Name) + } + newRow[newColIdx] = uint16(newIdx) + } + err = i.validateNullability(ctx, newSch, newRow) if err != nil { _ = inserter.DiscardChanges(ctx, err)