@@ -232,15 +232,11 @@ func (m *Migrate) Migrate(version uint) error {
232232 return err
233233 }
234234
235- curVersion , dirty , err := m .databaseDrv . Version ()
235+ curVersion , err := m .ensureCleanCurrentSQLVersion ()
236236 if err != nil {
237237 return m .unlockErr (err )
238238 }
239239
240- if dirty {
241- return m .unlockErr (ErrDirty {curVersion })
242- }
243-
244240 ret := make (chan interface {}, m .PrefetchMigrations )
245241 go m .read (curVersion , int (version ), ret )
246242
@@ -258,15 +254,11 @@ func (m *Migrate) Steps(n int) error {
258254 return err
259255 }
260256
261- curVersion , dirty , err := m .databaseDrv . Version ()
257+ curVersion , err := m .ensureCleanCurrentSQLVersion ()
262258 if err != nil {
263259 return m .unlockErr (err )
264260 }
265261
266- if dirty {
267- return m .unlockErr (ErrDirty {curVersion })
268- }
269-
270262 ret := make (chan interface {}, m .PrefetchMigrations )
271263
272264 if n > 0 {
@@ -285,15 +277,11 @@ func (m *Migrate) Up() error {
285277 return err
286278 }
287279
288- curVersion , dirty , err := m .databaseDrv . Version ()
280+ curVersion , err := m .ensureCleanCurrentSQLVersion ()
289281 if err != nil {
290282 return m .unlockErr (err )
291283 }
292284
293- if dirty {
294- return m .unlockErr (ErrDirty {curVersion })
295- }
296-
297285 ret := make (chan interface {}, m .PrefetchMigrations )
298286
299287 go m .readUp (curVersion , - 1 , ret )
@@ -307,15 +295,11 @@ func (m *Migrate) Down() error {
307295 return err
308296 }
309297
310- curVersion , dirty , err := m .databaseDrv . Version ()
298+ curVersion , err := m .ensureCleanCurrentSQLVersion ()
311299 if err != nil {
312300 return m .unlockErr (err )
313301 }
314302
315- if dirty {
316- return m .unlockErr (ErrDirty {curVersion })
317- }
318-
319303 ret := make (chan interface {}, m .PrefetchMigrations )
320304 go m .readDown (curVersion , - 1 , ret )
321305 return m .unlockErr (m .runMigrations (ret ))
@@ -345,15 +329,11 @@ func (m *Migrate) Run(migration ...*Migration) error {
345329 return err
346330 }
347331
348- curVersion , dirty , err := m .databaseDrv . Version ()
332+ _ , err := m .ensureCleanCurrentSQLVersion ()
349333 if err != nil {
350334 return m .unlockErr (err )
351335 }
352336
353- if dirty {
354- return m .unlockErr (ErrDirty {curVersion })
355- }
356-
357337 ret := make (chan interface {}, m .PrefetchMigrations )
358338
359339 go func () {
@@ -542,6 +522,54 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) {
542522 }
543523}
544524
525+ // ensureCleanCurrentSQLVersion returns the database's current SQL migration
526+ // version in a clean (non-dirty) state. If the database is dirty, it returns
527+ // ErrDirty.
528+ //
529+ // If the current version when executing this function is a clean migrate task
530+ // version (meaning a migration task previously failed after the SQL migration
531+ // applied), this method re-executes the task for the associated SQL migration
532+ // version. If successful, the function normalizes the recorded version to the
533+ // SQL target version in a clean state so subsequent migrations can proceed.
534+ //
535+ // NOTE: The caller must hold the lock when calling this method.
536+ func (m * Migrate ) ensureCleanCurrentSQLVersion () (int , error ) {
537+ curVersion , dirty , err := m .databaseDrv .Version ()
538+ if err != nil {
539+ return curVersion , err
540+ }
541+
542+ if dirty {
543+ return curVersion , ErrDirty {curVersion }
544+ }
545+
546+ // If the current version is a clean migration task version, then we
547+ // need to rerun the task for the previous version before we can
548+ // continue with any SQL migration(s). We can be certain here that the
549+ // task was attempted to be run before, but errored. This is since
550+ // the migration function only sets the version to a **clean** (i.e. not
551+ // dirty) **task** version if the task errored on the last attempt.
552+ if InTaskVersionRange (curVersion ) {
553+ sqlMigVersion := SQLMigrationVersion (curVersion )
554+
555+ err = m .execTaskAtMigVersion (sqlMigVersion )
556+ if err != nil {
557+ return curVersion , err
558+ }
559+
560+ curVersion , dirty , err = m .databaseDrv .Version ()
561+ if err != nil {
562+ return curVersion , err
563+ }
564+
565+ if dirty {
566+ return curVersion , ErrDirty {curVersion }
567+ }
568+ }
569+
570+ return curVersion , nil
571+ }
572+
545573// readUp reads up migrations from `from` limited by `limit`.
546574// limit can be -1, implying no limit and reading until there are no more migrations.
547575// Each migration is then written to the ret channel.
@@ -732,6 +760,30 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
732760 }
733761}
734762
763+ // readSingle reads a single migration for the given version, and sends it
764+ // over the passed channel.
765+ func (m * Migrate ) readSingle (ver uint , ret chan <- interface {}) {
766+ defer close (ret )
767+
768+ if err := m .versionExists (ver ); err != nil {
769+ ret <- err
770+ return
771+ }
772+
773+ migr , err := m .newMigration (ver , int (ver ))
774+ if err != nil {
775+ ret <- err
776+ return
777+ }
778+
779+ ret <- migr
780+ go func () {
781+ if err := migr .Buffer (); err != nil {
782+ m .logErr (err )
783+ }
784+ }()
785+ }
786+
735787// runMigrations reads *Migration and error from a channel. Any other type
736788// sent on this channel will result in a panic. Each migration is then
737789// proxied to the database driver and run against the database.
@@ -752,6 +804,12 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
752804 case * Migration :
753805 migr := r
754806
807+ if migr .Version >= TaskVersionOffset {
808+ return fmt .Errorf ("migration version %v is " +
809+ "invalid, must be < %v" , migr .Version ,
810+ TaskVersionOffset )
811+ }
812+
755813 // set version with dirty state
756814 if err := m .databaseDrv .SetVersion (migr .TargetVersion , true ); err != nil {
757815 return err
@@ -763,23 +821,10 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
763821 return err
764822 }
765823
766- // If there is a task function for this
767- // migration, run it now.
768- cb , ok := m .opts .tasks [migr .Version ]
769- if ok {
770- m .logVerbosePrintf ("Running migration " +
771- "task for %v\n " , migr .LogString ())
772-
773- err := cb (migr , m .databaseDrv )
774- if err != nil {
775- return fmt .Errorf ("failed to " +
776- "execute migration " +
777- "task: %w" ,
778- err )
779- }
780-
781- m .logVerbosePrintf ("Migration task " +
782- "finished for %v\n " , migr .LogString ())
824+ err := m .execTask (migr )
825+ if err != nil {
826+ return fmt .Errorf ("migration task " +
827+ "error: %w" , err )
783828 }
784829 }
785830
@@ -808,6 +853,117 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
808853 return nil
809854}
810855
856+ // execTask checks if a migration task exists for the passed migration and
857+ // proceeds to execute if one exists.
858+ func (m * Migrate ) execTask (migr * Migration ) error {
859+ task , ok := m .opts .tasks [migr .Version ]
860+ if ! ok {
861+ m .logVerbosePrintf ("No migration task set for %v\n " ,
862+ migr .LogString ())
863+
864+ return nil
865+ }
866+
867+ m .logVerbosePrintf ("Running migration task for %v\n " , migr .LogString ())
868+
869+ taskVersion := int (migr .Version ) + TaskVersionOffset
870+
871+ // Persist that we are in the migration task phase for this version.
872+ if err := m .databaseDrv .SetVersion (taskVersion , true ); err != nil {
873+ return err
874+ }
875+
876+ err := task (migr , m .databaseDrv )
877+ if err != nil {
878+ // Mark the database version as the taskVersion but in a clean
879+ // state, to indicate that the migration task errored. We will
880+ // therefore re-run the task on the next migration run.
881+ // The definition for the migration task version is that
882+ // the database version is only ever set to a migration task
883+ // version in a clean state if the task errored during its
884+ // execution. We therefore mark the state as clean here, so that
885+ // the migration task will be re-executed until it succeeds.
886+ setErr := m .databaseDrv .SetVersion (taskVersion , false )
887+ if setErr != nil {
888+ // Note that if we error here, the database version will
889+ // remain in a dirty state. As we cannot know if the
890+ // migration task was executed or not in that scenario,
891+ // manual intervention is required.
892+ return fmt .Errorf ("WARNING, failed to set migration " +
893+ "version after migration task errored. Manual " +
894+ "intervention needed! Migration task error: " +
895+ "%w, version setting error : %w" , err , setErr )
896+ }
897+
898+ return fmt .Errorf ("failed to execute migration task: %w" , err )
899+ }
900+
901+ m .logVerbosePrintf ("Migration task finished for %v\n " , migr .LogString ())
902+
903+ return nil
904+ }
905+
906+ // execTaskAtMigVersion executes only the migration task for the passed SQL
907+ // migration version.
908+ // The function can be used to re-execute the task for a SQL migration version
909+ // where the SQL migration was successfully applied, but where the task failed.
910+ func (m * Migrate ) execTaskAtMigVersion (sqlMigVersion int ) error {
911+ var (
912+ r interface {}
913+ migRet = make (chan interface {}, m .PrefetchMigrations )
914+ err error
915+ )
916+
917+ // Fetch the migration for the specified SQL migration version.
918+ go m .readSingle (uint (sqlMigVersion ), migRet )
919+
920+ select {
921+ case r = <- migRet :
922+ case <- time .After (DefaultSingleMigReadTimeout ):
923+ return fmt .Errorf ("timeout waiting for single migration " +
924+ "version %v" , sqlMigVersion )
925+ }
926+
927+ if m .stop () {
928+ return nil
929+ }
930+
931+ switch r := r .(type ) {
932+ case * Migration :
933+ // If the migration was found, execute the migration task.
934+ migr := r
935+
936+ err = m .execTask (migr )
937+ if err != nil {
938+ return fmt .Errorf ("execution of migration task for " +
939+ "SQL migration version %d failed: %w" ,
940+ sqlMigVersion , err )
941+ }
942+
943+ m .logVerbosePrintf ("successfully re-executed migration task " +
944+ "for SQL migration version: %v\n " , sqlMigVersion )
945+
946+ // After the migration task has been executed successfully, we
947+ // set the db version to the SQL migration target version with a
948+ // clean state, as we can now proceed with the next migrations,
949+ // if any.
950+ err = m .databaseDrv .SetVersion (migr .TargetVersion , false )
951+ if err != nil {
952+ return err
953+ }
954+
955+ return nil
956+
957+ case error :
958+ return fmt .Errorf ("reading SQL migration at version " +
959+ "%v failed: %w" , sqlMigVersion , r )
960+
961+ default :
962+ return fmt .Errorf ("unknown type: %T when reading " +
963+ "single migration" , r )
964+ }
965+ }
966+
811967// versionExists checks the source if either the up or down migration for
812968// the specified migration version exists.
813969func (m * Migrate ) versionExists (version uint ) (result error ) {
0 commit comments