Skip to content

Commit 2f96704

Browse files
author
James Cor
committed
enum
1 parent 487cf93 commit 2f96704

File tree

2 files changed

+78
-9
lines changed

2 files changed

+78
-9
lines changed

enginetest/queries/script_queries.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7448,6 +7448,53 @@ where
74487448
},
74497449
},
74507450
},
7451+
{
7452+
Name: "preserve enums through alter statements",
7453+
SetUpScript: []string{
7454+
"create table t (i int primary key, e enum('a', 'b', 'c'));",
7455+
"insert into t values (1, 'a');",
7456+
"insert into t values (2, 'b');",
7457+
"insert into t values (3, 'c');",
7458+
},
7459+
Assertions: []ScriptTestAssertion{
7460+
{
7461+
Query: "select i, e, e + 0 from t;",
7462+
Expected: []sql.Row{
7463+
{1, "a", float64(1)},
7464+
{2, "b", float64(2)},
7465+
{3, "c", float64(3)},
7466+
},
7467+
},
7468+
{
7469+
Query: "alter table t modify column e enum('c', 'a', 'b');",
7470+
Expected: []sql.Row{
7471+
{types.NewOkResult(0)},
7472+
},
7473+
},
7474+
{
7475+
Query: "select i, e, e + 0 from t;",
7476+
Expected: []sql.Row{
7477+
{1, "a", float64(2)},
7478+
{2, "b", float64(3)},
7479+
{3, "c", float64(1)},
7480+
},
7481+
},
7482+
{
7483+
Query: "alter table t modify column e enum('asdf', 'a', 'b', 'c');",
7484+
Expected: []sql.Row{
7485+
{types.NewOkResult(0)},
7486+
},
7487+
},
7488+
{
7489+
Query: "select i, e, e + 0 from t;",
7490+
Expected: []sql.Row{
7491+
{1, "a", float64(2)},
7492+
{2, "b", float64(3)},
7493+
{3, "c", float64(4)},
7494+
},
7495+
},
7496+
},
7497+
},
74517498
}
74527499

74537500
var SpatialScriptTests = []ScriptTest{

sql/rowexec/ddl_iters.go

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -471,13 +471,17 @@ func (i *modifyColumnIter) Close(context *sql.Context) error {
471471

472472
// rewriteTable rewrites the table given if required or requested, and returns whether it was rewritten
473473
func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) (bool, error) {
474-
oldColIdx := i.m.TargetSchema().IndexOfColName(i.m.Column())
474+
targetSchema := i.m.TargetSchema()
475+
oldColName := i.m.Column()
476+
oldColIdx := targetSchema.IndexOfColName(oldColName)
475477
if oldColIdx == -1 {
476478
// Should be impossible, checked in analyzer
477-
return false, sql.ErrTableColumnNotFound.New(rwt.Name(), i.m.Column())
479+
return false, sql.ErrTableColumnNotFound.New(rwt.Name(), oldColName)
478480
}
479481

480-
newSch, projections, err := modifyColumnInSchema(i.m.TargetSchema(), i.m.Column(), i.m.NewColumn(), i.m.Order())
482+
oldCol := i.m.TargetSchema()[oldColIdx]
483+
newCol := i.m.NewColumn()
484+
newSch, projections, err := modifyColumnInSchema(targetSchema, oldColName, newCol, i.m.Order())
481485
if err != nil {
482486
return false, err
483487
}
@@ -494,27 +498,33 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
494498
}
495499

496500
var renames []sql.ColumnRename
497-
if i.m.Column() != i.m.NewColumn().Name {
501+
if oldColName != newCol.Name {
498502
renames = []sql.ColumnRename{{
499-
Before: i.m.Column(), After: i.m.NewColumn().Name,
503+
Before: oldColName, After: newCol.Name,
500504
}}
501505
}
502506

503507
oldPkSchema := sql.SchemaToPrimaryKeySchema(rwt, rwt.Schema())
504508
newPkSchema := sql.SchemaToPrimaryKeySchema(rwt, newSch, renames...)
505509

506510
rewriteRequired := false
507-
if i.m.TargetSchema()[oldColIdx].Nullable && !i.m.NewColumn().Nullable {
511+
if oldCol.Nullable && !newCol.Nullable {
512+
rewriteRequired = true
513+
}
514+
515+
oldEnum, isOldEnum := oldCol.Type.(sql.EnumType)
516+
newEnum, isNewEnum := newCol.Type.(sql.EnumType)
517+
if isOldEnum && isNewEnum {
508518
rewriteRequired = true
509519
}
510520

511521
// TODO: codify rewrite requirements
512-
rewriteRequested := rwt.ShouldRewriteTable(ctx, oldPkSchema, newPkSchema, i.m.TargetSchema()[oldColIdx], i.m.NewColumn())
522+
rewriteRequested := rwt.ShouldRewriteTable(ctx, oldPkSchema, newPkSchema, oldCol, newCol)
513523
if !rewriteRequired && !rewriteRequested {
514524
return false, nil
515525
}
516526

517-
inserter, err := rwt.RewriteInserter(ctx, oldPkSchema, newPkSchema, i.m.TargetSchema()[oldColIdx], i.m.NewColumn(), nil)
527+
inserter, err := rwt.RewriteInserter(ctx, oldPkSchema, newPkSchema, oldCol, newCol, nil)
518528
if err != nil {
519529
return false, err
520530
}
@@ -524,8 +534,8 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
524534
return false, err
525535
}
526536

537+
newColIdx := newSch.IndexOf(newCol.Name, newCol.Source)
527538
rowIter := sql.NewTableRowIter(ctx, rwt, partitions)
528-
529539
for {
530540
r, err := rowIter.Next(ctx)
531541
if err == io.EOF {
@@ -543,6 +553,18 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
543553
return false, err
544554
}
545555

556+
// remap old enum values to new enum values
557+
if isOldEnum && isNewEnum {
558+
oldIdx := int(newRow[newColIdx].(uint16))
559+
oldStr, _ := oldEnum.At(oldIdx)
560+
newIdx := newEnum.IndexOf(oldStr)
561+
if newIdx == -1 {
562+
// TODO: convert to truncated warning, and somehow still show old enum value
563+
return false, fmt.Errorf("data truncated for column %s", newCol.Name)
564+
}
565+
newRow[newColIdx] = uint16(newIdx)
566+
}
567+
546568
err = i.validateNullability(ctx, newSch, newRow)
547569
if err != nil {
548570
_ = inserter.DiscardChanges(ctx, err)

0 commit comments

Comments
 (0)