From 2f96704812ccb4d2bcbec26d1f7bcb36e1bce40e Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 9 Oct 2024 17:15:13 -0700 Subject: [PATCH 1/4] enum --- enginetest/queries/script_queries.go | 47 ++++++++++++++++++++++++++++ sql/rowexec/ddl_iters.go | 40 +++++++++++++++++------ 2 files changed, 78 insertions(+), 9 deletions(-) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 51f3592a84..c403ba0efd 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7448,6 +7448,53 @@ where }, }, }, + { + Name: "preserve enums through alter statements", + SetUpScript: []string{ + "create table t (i int primary key, e enum('a', 'b', 'c'));", + "insert into t values (1, 'a');", + "insert into t values (2, 'b');", + "insert into t values (3, 'c');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {1, "a", float64(1)}, + {2, "b", float64(2)}, + {3, "c", float64(3)}, + }, + }, + { + Query: "alter table t modify column e enum('c', 'a', 'b');", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {1, "a", float64(2)}, + {2, "b", float64(3)}, + {3, "c", float64(1)}, + }, + }, + { + Query: "alter table t modify column e enum('asdf', 'a', 'b', 'c');", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "select i, e, e + 0 from t;", + Expected: []sql.Row{ + {1, "a", float64(2)}, + {2, "b", float64(3)}, + {3, "c", float64(4)}, + }, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index a23dd9b168..aefdcee20a 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -471,13 +471,17 @@ func (i *modifyColumnIter) Close(context *sql.Context) error { // rewriteTable rewrites the table given if required or requested, and returns whether it was rewritten func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) (bool, error) { - oldColIdx := i.m.TargetSchema().IndexOfColName(i.m.Column()) + targetSchema := i.m.TargetSchema() + oldColName := i.m.Column() + oldColIdx := targetSchema.IndexOfColName(oldColName) if oldColIdx == -1 { // Should be impossible, checked in analyzer - return false, sql.ErrTableColumnNotFound.New(rwt.Name(), i.m.Column()) + return false, sql.ErrTableColumnNotFound.New(rwt.Name(), oldColName) } - newSch, projections, err := modifyColumnInSchema(i.m.TargetSchema(), i.m.Column(), i.m.NewColumn(), i.m.Order()) + oldCol := i.m.TargetSchema()[oldColIdx] + newCol := i.m.NewColumn() + newSch, projections, err := modifyColumnInSchema(targetSchema, oldColName, newCol, i.m.Order()) if err != nil { return false, err } @@ -494,9 +498,9 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl } var renames []sql.ColumnRename - if i.m.Column() != i.m.NewColumn().Name { + if oldColName != newCol.Name { renames = []sql.ColumnRename{{ - Before: i.m.Column(), After: i.m.NewColumn().Name, + Before: oldColName, After: newCol.Name, }} } @@ -504,17 +508,23 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl newPkSchema := sql.SchemaToPrimaryKeySchema(rwt, newSch, renames...) rewriteRequired := false - if i.m.TargetSchema()[oldColIdx].Nullable && !i.m.NewColumn().Nullable { + if oldCol.Nullable && !newCol.Nullable { + rewriteRequired = true + } + + oldEnum, isOldEnum := oldCol.Type.(sql.EnumType) + newEnum, isNewEnum := newCol.Type.(sql.EnumType) + if isOldEnum && isNewEnum { rewriteRequired = true } // TODO: codify rewrite requirements - rewriteRequested := rwt.ShouldRewriteTable(ctx, oldPkSchema, newPkSchema, i.m.TargetSchema()[oldColIdx], i.m.NewColumn()) + rewriteRequested := rwt.ShouldRewriteTable(ctx, oldPkSchema, newPkSchema, oldCol, newCol) if !rewriteRequired && !rewriteRequested { return false, nil } - inserter, err := rwt.RewriteInserter(ctx, oldPkSchema, newPkSchema, i.m.TargetSchema()[oldColIdx], i.m.NewColumn(), nil) + inserter, err := rwt.RewriteInserter(ctx, oldPkSchema, newPkSchema, oldCol, newCol, nil) if err != nil { return false, err } @@ -524,8 +534,8 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl return false, err } + newColIdx := newSch.IndexOf(newCol.Name, newCol.Source) rowIter := sql.NewTableRowIter(ctx, rwt, partitions) - for { r, err := rowIter.Next(ctx) if err == io.EOF { @@ -543,6 +553,18 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl return false, err } + // remap old enum values to new enum values + if isOldEnum && isNewEnum { + oldIdx := int(newRow[newColIdx].(uint16)) + oldStr, _ := oldEnum.At(oldIdx) + newIdx := newEnum.IndexOf(oldStr) + if newIdx == -1 { + // TODO: convert to truncated warning, and somehow still show old enum value + return false, fmt.Errorf("data truncated for column %s", newCol.Name) + } + newRow[newColIdx] = uint16(newIdx) + } + err = i.validateNullability(ctx, newSch, newRow) if err != nil { _ = inserter.DiscardChanges(ctx, err) From 64cad7e0db08c293227e62bc8c0ad5cb2d2721bb Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 10 Oct 2024 10:52:11 -0700 Subject: [PATCH 2/4] error test --- enginetest/queries/script_queries.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index c403ba0efd..18c91abb60 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7493,6 +7493,11 @@ where {3, "c", float64(4)}, }, }, + { + // TODO: MySQL preserves the original enum val and enum string, and throws a warning. + Query: "alter table t modify column e enum('abc');", + ExpectedErrStr: "value 2 is not valid for this Enum", + }, }, }, } From 0c9dc4871c4d0998e9c1d307aa374bdadec282ce Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 10 Oct 2024 17:53:33 +0000 Subject: [PATCH 3/4] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/queries/script_queries.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 18c91abb60..85458cd072 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7495,7 +7495,7 @@ where }, { // TODO: MySQL preserves the original enum val and enum string, and throws a warning. - Query: "alter table t modify column e enum('abc');", + Query: "alter table t modify column e enum('abc');", ExpectedErrStr: "value 2 is not valid for this Enum", }, }, From 8e1c71c8a846237d02a9a7ab3a07e1982a6b297a Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 10 Oct 2024 14:37:23 -0700 Subject: [PATCH 4/4] feedback --- enginetest/queries/script_queries.go | 1 - sql/rowexec/ddl_iters.go | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 18c91abb60..c05352712e 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -7494,7 +7494,6 @@ where }, }, { - // TODO: MySQL preserves the original enum val and enum string, and throws a warning. Query: "alter table t modify column e enum('abc');", ExpectedErrStr: "value 2 is not valid for this Enum", }, diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index aefdcee20a..90307a9320 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -514,7 +514,7 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl oldEnum, isOldEnum := oldCol.Type.(sql.EnumType) newEnum, isNewEnum := newCol.Type.(sql.EnumType) - if isOldEnum && isNewEnum { + if isOldEnum && isNewEnum && !oldEnum.Equals(newEnum) { rewriteRequired = true } @@ -559,7 +559,6 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl oldStr, _ := oldEnum.At(oldIdx) newIdx := newEnum.IndexOf(oldStr) if newIdx == -1 { - // TODO: convert to truncated warning, and somehow still show old enum value return false, fmt.Errorf("data truncated for column %s", newCol.Name) } newRow[newColIdx] = uint16(newIdx)