@@ -471,13 +471,17 @@ func (i *modifyColumnIter) Close(context *sql.Context) error {
471471
472472// rewriteTable rewrites the table given if required or requested, and returns whether it was rewritten
473473func (i * modifyColumnIter ) rewriteTable (ctx * sql.Context , rwt sql.RewritableTable ) (bool , error ) {
474- oldColIdx := i .m .TargetSchema ().IndexOfColName (i .m .Column ())
474+ targetSchema := i .m .TargetSchema ()
475+ oldColName := i .m .Column ()
476+ oldColIdx := targetSchema .IndexOfColName (oldColName )
475477 if oldColIdx == - 1 {
476478 // Should be impossible, checked in analyzer
477- return false , sql .ErrTableColumnNotFound .New (rwt .Name (), i . m . Column () )
479+ return false , sql .ErrTableColumnNotFound .New (rwt .Name (), oldColName )
478480 }
479481
480- newSch , projections , err := modifyColumnInSchema (i .m .TargetSchema (), i .m .Column (), i .m .NewColumn (), i .m .Order ())
482+ oldCol := i .m .TargetSchema ()[oldColIdx ]
483+ newCol := i .m .NewColumn ()
484+ newSch , projections , err := modifyColumnInSchema (targetSchema , oldColName , newCol , i .m .Order ())
481485 if err != nil {
482486 return false , err
483487 }
@@ -494,27 +498,33 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
494498 }
495499
496500 var renames []sql.ColumnRename
497- if i . m . Column () != i . m . NewColumn () .Name {
501+ if oldColName != newCol .Name {
498502 renames = []sql.ColumnRename {{
499- Before : i . m . Column () , After : i . m . NewColumn () .Name ,
503+ Before : oldColName , After : newCol .Name ,
500504 }}
501505 }
502506
503507 oldPkSchema := sql .SchemaToPrimaryKeySchema (rwt , rwt .Schema ())
504508 newPkSchema := sql .SchemaToPrimaryKeySchema (rwt , newSch , renames ... )
505509
506510 rewriteRequired := false
507- if i .m .TargetSchema ()[oldColIdx ].Nullable && ! i .m .NewColumn ().Nullable {
511+ if oldCol .Nullable && ! newCol .Nullable {
512+ rewriteRequired = true
513+ }
514+
515+ oldEnum , isOldEnum := oldCol .Type .(sql.EnumType )
516+ newEnum , isNewEnum := newCol .Type .(sql.EnumType )
517+ if isOldEnum && isNewEnum {
508518 rewriteRequired = true
509519 }
510520
511521 // TODO: codify rewrite requirements
512- rewriteRequested := rwt .ShouldRewriteTable (ctx , oldPkSchema , newPkSchema , i . m . TargetSchema ()[ oldColIdx ], i . m . NewColumn () )
522+ rewriteRequested := rwt .ShouldRewriteTable (ctx , oldPkSchema , newPkSchema , oldCol , newCol )
513523 if ! rewriteRequired && ! rewriteRequested {
514524 return false , nil
515525 }
516526
517- inserter , err := rwt .RewriteInserter (ctx , oldPkSchema , newPkSchema , i . m . TargetSchema ()[ oldColIdx ], i . m . NewColumn () , nil )
527+ inserter , err := rwt .RewriteInserter (ctx , oldPkSchema , newPkSchema , oldCol , newCol , nil )
518528 if err != nil {
519529 return false , err
520530 }
@@ -524,8 +534,8 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
524534 return false , err
525535 }
526536
537+ newColIdx := newSch .IndexOf (newCol .Name , newCol .Source )
527538 rowIter := sql .NewTableRowIter (ctx , rwt , partitions )
528-
529539 for {
530540 r , err := rowIter .Next (ctx )
531541 if err == io .EOF {
@@ -543,6 +553,18 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
543553 return false , err
544554 }
545555
556+ // remap old enum values to new enum values
557+ if isOldEnum && isNewEnum {
558+ oldIdx := int (newRow [newColIdx ].(uint16 ))
559+ oldStr , _ := oldEnum .At (oldIdx )
560+ newIdx := newEnum .IndexOf (oldStr )
561+ if newIdx == - 1 {
562+ // TODO: convert to truncated warning, and somehow still show old enum value
563+ return false , fmt .Errorf ("data truncated for column %s" , newCol .Name )
564+ }
565+ newRow [newColIdx ] = uint16 (newIdx )
566+ }
567+
546568 err = i .validateNullability (ctx , newSch , newRow )
547569 if err != nil {
548570 _ = inserter .DiscardChanges (ctx , err )
0 commit comments