Skip to content

Commit b2047a3

Browse files
committed
migrate: re-run migration tasks callbacks on error
This commit modifies the migration framework to re-attempt migration tasks if they error during a migration, on the next run. Previously, if a migration task failed, but their associated SQL migration succeeded, the database version would be set to a dirty state, and require manual intervention in order to reset the SQL migration and re-attempt it + the migration task. The new re-attempt mechanism is achieved by ensuring that a migration can only be either an SQL migration or a migration task, but not both. This way, if a migration task errors, the database version will be reset to the previous version prior to executing the migration task, and the migration task will be re-attempted on the next run.
1 parent 35f3e3e commit b2047a3

File tree

2 files changed

+447
-45
lines changed

2 files changed

+447
-45
lines changed

migrate.go

Lines changed: 129 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
package migrate
66

77
import (
8+
"bytes"
89
"errors"
910
"fmt"
11+
"io"
1012
"os"
13+
"regexp"
14+
"strings"
1115
"sync"
1216
"time"
1317

@@ -732,6 +736,34 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
732736
}
733737
}
734738

739+
// hasSQLMigration checks if the passed data contains executable statements,
740+
// meaning that the data doesn't only contain comments/whitespace or semicolons.
741+
func (m *Migrate) hasSQLMigration(data []byte) (bool, error) {
742+
s := string(data)
743+
744+
// Remove Byte Order Mark (BOM) if present in the migration file.
745+
s = strings.TrimPrefix(s, "\uFEFF")
746+
747+
// Strip block comments /* ... */ (non-greedy, across lines).
748+
reBlock := regexp.MustCompile(`(?s)/\*.*?\*/`)
749+
s = reBlock.ReplaceAllString(s, "")
750+
751+
// Strip line comments -- ... (to end of line).
752+
reLine := regexp.MustCompile(`(?m)--[^\n\r]*`)
753+
s = reLine.ReplaceAllString(s, "")
754+
755+
// Trim whitespaces.
756+
s = strings.TrimSpace(s)
757+
758+
// Remove any semicolons, newlines, tabs, or spaces from the beginning
759+
// and end of the string.
760+
s = strings.Trim(s, ";\r\n\t ")
761+
762+
// If the string still contains any characters, the data likely
763+
// contains executable statements.
764+
return len(s) > 0, nil
765+
}
766+
735767
// runMigrations reads *Migration and error from a channel. Any other type
736768
// sent on this channel will result in a panic. Each migration is then
737769
// proxied to the database driver and run against the database.
@@ -752,34 +784,58 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
752784
case *Migration:
753785
migr := r
754786

755-
// set version with dirty state
756-
if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
757-
return err
758-
}
759-
760787
if migr.Body != nil {
761-
m.logVerbosePrintf("Read and execute %v\n", migr.LogString())
762-
if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
788+
// Read the body so we can inspect and (re)use it.
789+
data, err := io.ReadAll(migr.BufferedBody)
790+
if err != nil {
791+
return fmt.Errorf("read migration body: %w", err)
792+
}
793+
794+
// Reset the reader so the driver can read it
795+
migr.BufferedBody = bytes.NewReader(data)
796+
797+
// Check if the migration contains an SQL
798+
// migration.
799+
hasSqlMig, err := m.hasSQLMigration(data)
800+
if err != nil {
763801
return err
764802
}
765803

766-
// If there is a task function for this
767-
// migration, run it now.
768-
cb, ok := m.opts.tasks[migr.Version]
769-
if ok {
770-
m.logVerbosePrintf("Running migration "+
771-
"task for %v\n", migr.LogString())
804+
// Check if the migration contains a migration
805+
// task.
806+
_, hasMigTask := m.opts.tasks[migr.Version]
807+
808+
// Execute the SQL migration or the migration
809+
// task.
810+
switch {
811+
case hasSqlMig && hasMigTask:
812+
return fmt.Errorf("migration has both " +
813+
"a SQL migration and a " +
814+
"migration task set")
815+
816+
case hasSqlMig:
817+
if err = m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
818+
return err
819+
}
772820

773-
err := cb(migr, m.databaseDrv)
821+
m.logVerbosePrintf("Read and execute %v\n", migr.LogString())
822+
if err = m.databaseDrv.Run(migr.BufferedBody); err != nil {
823+
return err
824+
}
825+
826+
case hasMigTask:
827+
err = m.execTask(migr)
774828
if err != nil {
775-
return fmt.Errorf("failed to "+
776-
"execute migration "+
777-
"task: %w",
778-
err)
829+
return fmt.Errorf("migration "+
830+
"task execution "+
831+
"failed: %w", err)
779832
}
780833

781-
m.logVerbosePrintf("Migration task "+
782-
"finished for %v\n", migr.LogString())
834+
default:
835+
// When the migration contains no SQL
836+
// migration or migration task, we
837+
// continue and set the version to the
838+
// migr.TargetVersion.
783839
}
784840
}
785841

@@ -808,6 +864,59 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
808864
return nil
809865
}
810866

867+
// execTask checks if a migration task exists for the passed migration and
868+
// proceeds to execute if one exists. If the migration task fails, the function
869+
// will reset the database version to the version it was set to before
870+
// attempting to execute the migration task.
871+
func (m *Migrate) execTask(migr *Migration) error {
872+
m.logVerbosePrintf("Running migration task for %v\n", migr.LogString())
873+
874+
task, ok := m.opts.tasks[migr.Version]
875+
if !ok {
876+
return fmt.Errorf("no migration task set for %v",
877+
migr.LogString())
878+
}
879+
880+
// Get the current database version before executing the migration task.
881+
curVersion, dirty, err := m.databaseDrv.Version()
882+
if err != nil {
883+
return fmt.Errorf("unable to get current version: %w", err)
884+
}
885+
886+
if dirty {
887+
return ErrDirty{curVersion}
888+
}
889+
890+
// Persist that we are at the migration version of the migration task.
891+
if err = m.databaseDrv.SetVersion(int(migr.Version), true); err != nil {
892+
return err
893+
}
894+
895+
err = task(migr, m.databaseDrv)
896+
if err != nil {
897+
// Reset the version to the version set before executing the
898+
// migration task. Therefore, the migration task will be
899+
// re-executed on nnext startup until it succeeds.
900+
setErr := m.databaseDrv.SetVersion(curVersion, false)
901+
if setErr != nil {
902+
// Note that if we error here, the database version will
903+
// remain in a dirty state. As we cannot know if the
904+
// migration task was executed or not in that scenario,
905+
// manual intervention is required.
906+
return fmt.Errorf("WARNING, failed to set migration "+
907+
"version after migration task errored. Manual "+
908+
"intervention needed! Migration task error: "+
909+
"%w, version setting error : %w", err, setErr)
910+
}
911+
912+
return fmt.Errorf("failed to execute migration task: %w", err)
913+
}
914+
915+
m.logVerbosePrintf("Migration task finished for %v\n", migr.LogString())
916+
917+
return nil
918+
}
919+
811920
// versionExists checks the source if either the up or down migration for
812921
// the specified migration version exists.
813922
func (m *Migrate) versionExists(version uint) (result error) {

0 commit comments

Comments
 (0)