@@ -352,12 +352,12 @@ func (i *modifyColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
352352
353353 // TODO: replace with different node in analyzer
354354 if rwt , ok := i .alterable .(sql.RewritableTable ); ok {
355- rewritten , err := i .rewriteTable (ctx , rwt )
355+ rewritten , rowCount , err := i .rewriteTable (ctx , rwt )
356356 if err != nil {
357357 return nil , err
358358 }
359359 if rewritten {
360- return sql .NewRow (types .NewOkResult (0 )), nil
360+ return sql .NewRow (types .NewOkResult (rowCount )), nil
361361 }
362362 }
363363
@@ -376,7 +376,42 @@ func (i *modifyColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
376376 return nil , err
377377 }
378378 }
379- return sql .NewRow (types .NewOkResult (0 )), nil
379+
380+ rowCount , err := countTableRows (ctx , i .alterable )
381+ if err != nil {
382+ return nil , err
383+ }
384+
385+ return sql .NewRow (types .NewOkResult (rowCount )), nil
386+ }
387+
388+ // countTableRows counts the number of rows in a table for DDL result reporting
389+ func countTableRows (ctx * sql.Context , table sql.Table ) (int , error ) {
390+ partitions , err := table .Partitions (ctx )
391+ if err != nil {
392+ return 0 , err
393+ }
394+
395+ rowIter := sql .NewTableRowIter (ctx , table , partitions )
396+ defer func () {
397+ if rowIter != nil {
398+ _ = rowIter .Close (ctx )
399+ }
400+ }()
401+
402+ var count int
403+ for {
404+ _ , err := rowIter .Next (ctx )
405+ if err == io .EOF {
406+ break
407+ }
408+ if err != nil {
409+ return 0 , err
410+ }
411+ count ++
412+ }
413+
414+ return count , nil
380415}
381416
382417func handleFkColumnRename (ctx * sql.Context , fkTable sql.ForeignKeyTable , db sql.Database , oldName string , newName string ) error {
@@ -481,20 +516,20 @@ func (i *modifyColumnIter) Close(context *sql.Context) error {
481516}
482517
483518// rewriteTable rewrites the table given if required or requested, and returns whether it was rewritten
484- func (i * modifyColumnIter ) rewriteTable (ctx * sql.Context , rwt sql.RewritableTable ) (bool , error ) {
519+ func (i * modifyColumnIter ) rewriteTable (ctx * sql.Context , rwt sql.RewritableTable ) (bool , int , error ) {
485520 targetSchema := i .m .TargetSchema ()
486521 oldColName := i .m .Column ()
487522 oldColIdx := targetSchema .IndexOfColName (oldColName )
488523 if oldColIdx == - 1 {
489524 // Should be impossible, checked in analyzer
490- return false , sql .ErrTableColumnNotFound .New (rwt .Name (), oldColName )
525+ return false , 0 , sql .ErrTableColumnNotFound .New (rwt .Name (), oldColName )
491526 }
492527
493528 oldCol := i .m .TargetSchema ()[oldColIdx ]
494529 newCol := i .m .NewColumn ()
495530 newSch , projections , err := modifyColumnInSchema (targetSchema , oldColName , newCol , i .m .Order ())
496531 if err != nil {
497- return false , err
532+ return false , 0 , err
498533 }
499534
500535 // Wrap any auto increment columns in auto increment expressions. This mirrors what happens to row sources for normal
@@ -503,7 +538,7 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
503538 if col .AutoIncrement {
504539 projections [i ], err = expression .NewAutoIncrementForColumn (ctx , rwt , col , projections [i ])
505540 if err != nil {
506- return false , err
541+ return false , 0 , err
507542 }
508543 }
509544 }
@@ -532,28 +567,29 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
532567 // TODO: codify rewrite requirements
533568 rewriteRequested := rwt .ShouldRewriteTable (ctx , oldPkSchema , newPkSchema , oldCol , newCol )
534569 if ! rewriteRequired && ! rewriteRequested {
535- return false , nil
570+ return false , 0 , nil
536571 }
537572
538573 inserter , err := rwt .RewriteInserter (ctx , oldPkSchema , newPkSchema , oldCol , newCol , nil )
539574 if err != nil {
540- return false , err
575+ return false , 0 , err
541576 }
542577
543578 partitions , err := rwt .Partitions (ctx )
544579 if err != nil {
545- return false , err
580+ return false , 0 , err
546581 }
547582
548583 rowIter := sql .NewTableRowIter (ctx , rwt , partitions )
584+ var rowCount int
549585 for {
550586 r , err := rowIter .Next (ctx )
551587 if err == io .EOF {
552588 break
553589 } else if err != nil {
554590 _ = inserter .DiscardChanges (ctx , err )
555591 _ = inserter .Close (ctx )
556- return false , err
592+ return false , 0 , err
557593 }
558594
559595 // remap old enum values to new enum values
@@ -564,7 +600,7 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
564600 oldStr , _ := oldEnum .At (oldIdx )
565601 newIdx := newEnum .IndexOf (oldStr )
566602 if newIdx == - 1 {
567- return false , types .ErrDataTruncatedForColumn .New (newCol .Name )
603+ return false , 0 , types .ErrDataTruncatedForColumn .New (newCol .Name )
568604 }
569605 r [oldColIdx ] = uint16 (newIdx )
570606 }
@@ -574,31 +610,32 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
574610 if err != nil {
575611 _ = inserter .DiscardChanges (ctx , err )
576612 _ = inserter .Close (ctx )
577- return false , err
613+ return false , 0 , err
578614 }
579615
580616 err = i .validateNullability (ctx , newSch , newRow )
581617 if err != nil {
582618 _ = inserter .DiscardChanges (ctx , err )
583619 _ = inserter .Close (ctx )
584- return false , err
620+ return false , 0 , err
585621 }
586622
587623 err = inserter .Insert (ctx , newRow )
588624 if err != nil {
589625 _ = inserter .DiscardChanges (ctx , err )
590626 _ = inserter .Close (ctx )
591- return false , err
627+ return false , 0 , err
592628 }
629+ rowCount ++
593630 }
594631
595632 // TODO: move this into iter.close, probably
596633 err = inserter .Close (ctx )
597634 if err != nil {
598- return false , err
635+ return false , 0 , err
599636 }
600637
601- return true , nil
638+ return true , rowCount , nil
602639}
603640
604641// modifyColumnInSchema modifies the given column in given schema and returns the new schema, along with a set of
0 commit comments