Skip to content

Commit 226c8d7

Browse files
committed
migrate: re-run post-step callbacks on error
This commit modifies the migration framework to re-attempt post-step callbacks if they error during a migration. Previously, if a post-step callback failed but its associated SQL migration succeeded, the post-step callback would not be re-attempted on the next migration run; instead, the run would proceed with the next SQL migration. The fix introduces the concept of a "post-step callback" migration version: a post-step callback's version is its corresponding SQL migration version offset by +1000000000. During the execution of a post-step callback, the post-step callback migration version is persisted as the database version. That way, if the post-step callback errors, the database version will be the post-step callback version on the next startup, and the post-step callback will be re-attempted before proceeding with the next SQL migration.
1 parent 35f3e3e commit 226c8d7

File tree

3 files changed

+469
-20
lines changed

3 files changed

+469
-20
lines changed

migrate.go

Lines changed: 220 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,23 @@ func (m *Migrate) Migrate(version uint) error {
241241
return m.unlockErr(ErrDirty{curVersion})
242242
}
243243

244+
// If the current version is a clean migration task version, then we
245+
// need to rerun the task for the previous version before we can
246+
// continue with any SQL migration(s).
247+
if InTaskVersionRange(curVersion) {
248+
sqlMigVersion := SQLMigrationVersion(curVersion)
249+
250+
err := m.execTaskAtMigVersion(sqlMigVersion)
251+
if err != nil {
252+
return m.unlockErr(err)
253+
}
254+
255+
curVersion, dirty, err = m.databaseDrv.Version()
256+
if err != nil {
257+
return m.unlockErr(err)
258+
}
259+
}
260+
244261
ret := make(chan interface{}, m.PrefetchMigrations)
245262
go m.read(curVersion, int(version), ret)
246263

@@ -267,6 +284,23 @@ func (m *Migrate) Steps(n int) error {
267284
return m.unlockErr(ErrDirty{curVersion})
268285
}
269286

287+
// If the current version is a clean migration task version, then we
288+
// need to rerun the task for the previous version before we can
289+
// continue with any SQL migration(s).
290+
if InTaskVersionRange(curVersion) {
291+
sqlMigVersion := SQLMigrationVersion(curVersion)
292+
293+
err := m.execTaskAtMigVersion(sqlMigVersion)
294+
if err != nil {
295+
return m.unlockErr(err)
296+
}
297+
298+
curVersion, dirty, err = m.databaseDrv.Version()
299+
if err != nil {
300+
return m.unlockErr(err)
301+
}
302+
}
303+
270304
ret := make(chan interface{}, m.PrefetchMigrations)
271305

272306
if n > 0 {
@@ -294,6 +328,23 @@ func (m *Migrate) Up() error {
294328
return m.unlockErr(ErrDirty{curVersion})
295329
}
296330

331+
// If the current version is a clean migration task version, then we
332+
// need to rerun the task for the previous version before we can
333+
// continue with any SQL migration(s).
334+
if InTaskVersionRange(curVersion) {
335+
sqlMigVersion := SQLMigrationVersion(curVersion)
336+
337+
err := m.execTaskAtMigVersion(sqlMigVersion)
338+
if err != nil {
339+
return m.unlockErr(err)
340+
}
341+
342+
curVersion, dirty, err = m.databaseDrv.Version()
343+
if err != nil {
344+
return m.unlockErr(err)
345+
}
346+
}
347+
297348
ret := make(chan interface{}, m.PrefetchMigrations)
298349

299350
go m.readUp(curVersion, -1, ret)
@@ -316,6 +367,23 @@ func (m *Migrate) Down() error {
316367
return m.unlockErr(ErrDirty{curVersion})
317368
}
318369

370+
// If the current version is a clean migration task version, then we
371+
// need to rerun the task for the previous version before we can
372+
// continue with any SQL migration(s).
373+
if InTaskVersionRange(curVersion) {
374+
sqlMigVersion := SQLMigrationVersion(curVersion)
375+
376+
err := m.execTaskAtMigVersion(sqlMigVersion)
377+
if err != nil {
378+
return m.unlockErr(err)
379+
}
380+
381+
curVersion, dirty, err = m.databaseDrv.Version()
382+
if err != nil {
383+
return m.unlockErr(err)
384+
}
385+
}
386+
319387
ret := make(chan interface{}, m.PrefetchMigrations)
320388
go m.readDown(curVersion, -1, ret)
321389
return m.unlockErr(m.runMigrations(ret))
@@ -354,6 +422,23 @@ func (m *Migrate) Run(migration ...*Migration) error {
354422
return m.unlockErr(ErrDirty{curVersion})
355423
}
356424

425+
// If the current version is a clean migration task version, then
426+
// we need to rerun the task for the previous version before we can
427+
// continue with any SQL migration(s).
428+
if InTaskVersionRange(curVersion) {
429+
sqlMigVersion := SQLMigrationVersion(curVersion)
430+
431+
err := m.execTaskAtMigVersion(sqlMigVersion)
432+
if err != nil {
433+
return m.unlockErr(err)
434+
}
435+
436+
curVersion, dirty, err = m.databaseDrv.Version()
437+
if err != nil {
438+
return m.unlockErr(err)
439+
}
440+
}
441+
357442
ret := make(chan interface{}, m.PrefetchMigrations)
358443

359444
go func() {
@@ -732,6 +817,30 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
732817
}
733818
}
734819

820+
// readSingle reads a single migration for the given version, and sends it
821+
// over the passed channel.
822+
func (m *Migrate) readSingle(ver uint, ret chan<- interface{}) {
823+
defer close(ret)
824+
825+
if err := m.versionExists(ver); err != nil {
826+
ret <- err
827+
return
828+
}
829+
830+
migr, err := m.newMigration(ver, int(ver))
831+
if err != nil {
832+
ret <- err
833+
return
834+
}
835+
836+
ret <- migr
837+
go func() {
838+
if err := migr.Buffer(); err != nil {
839+
m.logErr(err)
840+
}
841+
}()
842+
}
843+
735844
// runMigrations reads *Migration and error from a channel. Any other type
736845
// sent on this channel will result in a panic. Each migration is then
737846
// proxied to the database driver and run against the database.
@@ -752,6 +861,12 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
752861
case *Migration:
753862
migr := r
754863

864+
if migr.Version >= TaskVersionOffset {
865+
return fmt.Errorf("migration version %v is "+
866+
"invalid, must be < %v", migr.Version,
867+
TaskVersionOffset)
868+
}
869+
755870
// set version with dirty state
756871
if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
757872
return err
@@ -763,23 +878,9 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
763878
return err
764879
}
765880

766-
// If there is a task function for this
767-
// migration, run it now.
768-
cb, ok := m.opts.tasks[migr.Version]
769-
if ok {
770-
m.logVerbosePrintf("Running migration "+
771-
"task for %v\n", migr.LogString())
772-
773-
err := cb(migr, m.databaseDrv)
774-
if err != nil {
775-
return fmt.Errorf("failed to "+
776-
"execute migration "+
777-
"task: %w",
778-
err)
779-
}
780-
781-
m.logVerbosePrintf("Migration task "+
782-
"finished for %v\n", migr.LogString())
881+
err := m.execTask(migr)
882+
if err != nil {
883+
return err
783884
}
784885
}
785886

@@ -808,6 +909,108 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
808909
return nil
809910
}
810911

912+
// execTask checks if a migration task exists for the passed migration and
913+
// proceeds to execute if one exists.
914+
func (m *Migrate) execTask(migr *Migration) error {
915+
task, ok := m.opts.tasks[migr.Version]
916+
if ok {
917+
m.logVerbosePrintf("Running migration task for %v\n",
918+
migr.LogString())
919+
920+
taskVersion := int(migr.Version) + TaskVersionOffset
921+
922+
// Persist that we are in the migration task phase for this
923+
// version.
924+
if err := m.databaseDrv.SetVersion(taskVersion, true); err != nil {
925+
return err
926+
}
927+
928+
err := task(migr, m.databaseDrv)
929+
if err != nil {
930+
// Mark the database version as the taskVersion but in a
931+
// clean state, to indicate that the migration task
932+
// errored. We will therefore re-run the task on the
933+
// next migration run.
934+
if setErr := m.databaseDrv.SetVersion(taskVersion, false); setErr != nil {
935+
// Note that if we error here, the database
936+
// version will remain in a dirty state. As we
937+
// cannot know if the migration task was
938+
// executed or not in that scenario, manual
939+
// intervention is required.
940+
return fmt.Errorf("WARNING, failed to set "+
941+
"migration version after migration "+
942+
"task errored. Manual intervention "+
943+
"needed! Migration task error: %w, "+
944+
"version setting error : %w",
945+
err, setErr)
946+
}
947+
948+
return fmt.Errorf("failed to execute migration "+
949+
"task: %w", err)
950+
}
951+
952+
m.logVerbosePrintf("Migration task finished for %v\n",
953+
migr.LogString())
954+
}
955+
956+
return nil
957+
}
958+
959+
// execTaskAtMigVersion executes only the migration task for the passed SQL
960+
// migration version.
961+
// The function can be used to re-execute the task for a SQL migration version
962+
// where the SQL migration was successfully applied, but where the task failed.
963+
func (m *Migrate) execTaskAtMigVersion(sqlMigVersion int) error {
964+
var (
965+
r interface{}
966+
migRet = make(chan interface{}, m.PrefetchMigrations)
967+
err error
968+
)
969+
970+
// Fetch the migration for the specified SQL migration version.
971+
go m.readSingle(uint(sqlMigVersion), migRet)
972+
973+
select {
974+
case r = <-migRet:
975+
case <-time.After(30 * time.Second):
976+
return fmt.Errorf("timeout waiting for single migration "+
977+
"version %v", sqlMigVersion)
978+
}
979+
980+
if m.stop() {
981+
return nil
982+
}
983+
984+
switch r := r.(type) {
985+
case *Migration:
986+
// If the migration was found, execute the migration task.
987+
migr := r
988+
989+
err = m.execTask(migr)
990+
if err != nil {
991+
return err
992+
}
993+
994+
m.logVerbosePrintf("successfully re-executed migration task "+
995+
"for SQL migration version: %v\n", sqlMigVersion)
996+
997+
// set clean state
998+
if err = m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil {
999+
return err
1000+
}
1001+
1002+
return nil
1003+
1004+
case error:
1005+
return fmt.Errorf("reading SQL migration at version "+
1006+
"%v failed: %w", sqlMigVersion, r)
1007+
1008+
default:
1009+
return fmt.Errorf("unknown type: %T when reading "+
1010+
"single migration", r)
1011+
}
1012+
}
1013+
8111014
// versionExists checks the source if either the up or down migration for
8121015
// the specified migration version exists.
8131016
func (m *Migrate) versionExists(version uint) (result error) {

0 commit comments

Comments
 (0)