Skip to content

Commit c5e4529

Browse files
elianddbclaude
andcommitted
Fix ALTER TABLE MODIFY COLUMN to report actual row count
- Updated modifyColumnIter.Next() to count and return actual rows affected - Modified rewriteTable() signature to return row count (bool, int, error) - Added countTableRows() helper function for accurate row counting - Handles both rewrite and non-rewrite code paths correctly Fixes dolthub/dolt#9606 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent f45e28b commit c5e4529

File tree

1 file changed

+54
-17
lines changed

1 file changed

+54
-17
lines changed

sql/rowexec/ddl_iters.go

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,12 @@ func (i *modifyColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
352352

353353
// TODO: replace with different node in analyzer
354354
if rwt, ok := i.alterable.(sql.RewritableTable); ok {
355-
rewritten, err := i.rewriteTable(ctx, rwt)
355+
rewritten, rowCount, err := i.rewriteTable(ctx, rwt)
356356
if err != nil {
357357
return nil, err
358358
}
359359
if rewritten {
360-
return sql.NewRow(types.NewOkResult(0)), nil
360+
return sql.NewRow(types.NewOkResult(rowCount)), nil
361361
}
362362
}
363363

@@ -376,7 +376,42 @@ func (i *modifyColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
376376
return nil, err
377377
}
378378
}
379-
return sql.NewRow(types.NewOkResult(0)), nil
379+
380+
rowCount, err := countTableRows(ctx, i.alterable)
381+
if err != nil {
382+
return nil, err
383+
}
384+
385+
return sql.NewRow(types.NewOkResult(rowCount)), nil
386+
}
387+
388+
// countTableRows counts the number of rows in a table for DDL result reporting
389+
func countTableRows(ctx *sql.Context, table sql.Table) (int, error) {
390+
partitions, err := table.Partitions(ctx)
391+
if err != nil {
392+
return 0, err
393+
}
394+
395+
rowIter := sql.NewTableRowIter(ctx, table, partitions)
396+
defer func() {
397+
if rowIter != nil {
398+
_ = rowIter.Close(ctx)
399+
}
400+
}()
401+
402+
var count int
403+
for {
404+
_, err := rowIter.Next(ctx)
405+
if err == io.EOF {
406+
break
407+
}
408+
if err != nil {
409+
return 0, err
410+
}
411+
count++
412+
}
413+
414+
return count, nil
380415
}
381416

382417
func handleFkColumnRename(ctx *sql.Context, fkTable sql.ForeignKeyTable, db sql.Database, oldName string, newName string) error {
@@ -481,20 +516,20 @@ func (i *modifyColumnIter) Close(context *sql.Context) error {
481516
}
482517

483518
// rewriteTable rewrites the table given if required or requested, and returns whether it was rewritten
484-
func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) (bool, error) {
519+
func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) (bool, int, error) {
485520
targetSchema := i.m.TargetSchema()
486521
oldColName := i.m.Column()
487522
oldColIdx := targetSchema.IndexOfColName(oldColName)
488523
if oldColIdx == -1 {
489524
// Should be impossible, checked in analyzer
490-
return false, sql.ErrTableColumnNotFound.New(rwt.Name(), oldColName)
525+
return false, 0, sql.ErrTableColumnNotFound.New(rwt.Name(), oldColName)
491526
}
492527

493528
oldCol := i.m.TargetSchema()[oldColIdx]
494529
newCol := i.m.NewColumn()
495530
newSch, projections, err := modifyColumnInSchema(targetSchema, oldColName, newCol, i.m.Order())
496531
if err != nil {
497-
return false, err
532+
return false, 0, err
498533
}
499534

500535
// Wrap any auto increment columns in auto increment expressions. This mirrors what happens to row sources for normal
@@ -503,7 +538,7 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
503538
if col.AutoIncrement {
504539
projections[i], err = expression.NewAutoIncrementForColumn(ctx, rwt, col, projections[i])
505540
if err != nil {
506-
return false, err
541+
return false, 0, err
507542
}
508543
}
509544
}
@@ -532,28 +567,29 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
532567
// TODO: codify rewrite requirements
533568
rewriteRequested := rwt.ShouldRewriteTable(ctx, oldPkSchema, newPkSchema, oldCol, newCol)
534569
if !rewriteRequired && !rewriteRequested {
535-
return false, nil
570+
return false, 0, nil
536571
}
537572

538573
inserter, err := rwt.RewriteInserter(ctx, oldPkSchema, newPkSchema, oldCol, newCol, nil)
539574
if err != nil {
540-
return false, err
575+
return false, 0, err
541576
}
542577

543578
partitions, err := rwt.Partitions(ctx)
544579
if err != nil {
545-
return false, err
580+
return false, 0, err
546581
}
547582

548583
rowIter := sql.NewTableRowIter(ctx, rwt, partitions)
584+
var rowCount int
549585
for {
550586
r, err := rowIter.Next(ctx)
551587
if err == io.EOF {
552588
break
553589
} else if err != nil {
554590
_ = inserter.DiscardChanges(ctx, err)
555591
_ = inserter.Close(ctx)
556-
return false, err
592+
return false, 0, err
557593
}
558594

559595
// remap old enum values to new enum values
@@ -564,7 +600,7 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
564600
oldStr, _ := oldEnum.At(oldIdx)
565601
newIdx := newEnum.IndexOf(oldStr)
566602
if newIdx == -1 {
567-
return false, types.ErrDataTruncatedForColumn.New(newCol.Name)
603+
return false, 0, types.ErrDataTruncatedForColumn.New(newCol.Name)
568604
}
569605
r[oldColIdx] = uint16(newIdx)
570606
}
@@ -574,31 +610,32 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
574610
if err != nil {
575611
_ = inserter.DiscardChanges(ctx, err)
576612
_ = inserter.Close(ctx)
577-
return false, err
613+
return false, 0, err
578614
}
579615

580616
err = i.validateNullability(ctx, newSch, newRow)
581617
if err != nil {
582618
_ = inserter.DiscardChanges(ctx, err)
583619
_ = inserter.Close(ctx)
584-
return false, err
620+
return false, 0, err
585621
}
586622

587623
err = inserter.Insert(ctx, newRow)
588624
if err != nil {
589625
_ = inserter.DiscardChanges(ctx, err)
590626
_ = inserter.Close(ctx)
591-
return false, err
627+
return false, 0, err
592628
}
629+
rowCount++
593630
}
594631

595632
// TODO: move this into iter.close, probably
596633
err = inserter.Close(ctx)
597634
if err != nil {
598-
return false, err
635+
return false, 0, err
599636
}
600637

601-
return true, nil
638+
return true, rowCount, nil
602639
}
603640

604641
// modifyColumnInSchema modifies the given column in given schema and returns the new schema, along with a set of

0 commit comments

Comments
 (0)