@@ -241,6 +241,23 @@ func (m *Migrate) Migrate(version uint) error {
241241 return m .unlockErr (ErrDirty {curVersion })
242242 }
243243
244+ // If the current version is a clean migration task version, then we
245+ // need to rerun the task for the previous version before we can
246+ // continue with any SQL migration(s).
247+ if InTaskVersionRange (curVersion ) {
248+ sqlMigVersion := SQLMigrationVersion (curVersion )
249+
250+ err := m .execTaskAtMigVersion (sqlMigVersion )
251+ if err != nil {
252+ return m .unlockErr (err )
253+ }
254+
255+ curVersion , dirty , err = m .databaseDrv .Version ()
256+ if err != nil {
257+ return m .unlockErr (err )
258+ }
259+ }
260+
244261 ret := make (chan interface {}, m .PrefetchMigrations )
245262 go m .read (curVersion , int (version ), ret )
246263
@@ -267,6 +284,23 @@ func (m *Migrate) Steps(n int) error {
267284 return m .unlockErr (ErrDirty {curVersion })
268285 }
269286
287+ // If the current version is a clean migration task version, then we
288+ // need to rerun the task for the previous version before we can
289+ // continue with any SQL migration(s).
290+ if InTaskVersionRange (curVersion ) {
291+ sqlMigVersion := SQLMigrationVersion (curVersion )
292+
293+ err := m .execTaskAtMigVersion (sqlMigVersion )
294+ if err != nil {
295+ return m .unlockErr (err )
296+ }
297+
298+ curVersion , dirty , err = m .databaseDrv .Version ()
299+ if err != nil {
300+ return m .unlockErr (err )
301+ }
302+ }
303+
270304 ret := make (chan interface {}, m .PrefetchMigrations )
271305
272306 if n > 0 {
@@ -294,6 +328,23 @@ func (m *Migrate) Up() error {
294328 return m .unlockErr (ErrDirty {curVersion })
295329 }
296330
331+ // If the current version is a clean migration task version, then we
332+ // need to rerun the task for the previous version before we can
333+ // continue with any SQL migration(s).
334+ if InTaskVersionRange (curVersion ) {
335+ sqlMigVersion := SQLMigrationVersion (curVersion )
336+
337+ err := m .execTaskAtMigVersion (sqlMigVersion )
338+ if err != nil {
339+ return m .unlockErr (err )
340+ }
341+
342+ curVersion , dirty , err = m .databaseDrv .Version ()
343+ if err != nil {
344+ return m .unlockErr (err )
345+ }
346+ }
347+
297348 ret := make (chan interface {}, m .PrefetchMigrations )
298349
299350 go m .readUp (curVersion , - 1 , ret )
@@ -316,6 +367,23 @@ func (m *Migrate) Down() error {
316367 return m .unlockErr (ErrDirty {curVersion })
317368 }
318369
370+ // If the current version is a clean migration task version, then we
371+ // need to rerun the task for the previous version before we can
372+ // continue with any SQL migration(s).
373+ if InTaskVersionRange (curVersion ) {
374+ sqlMigVersion := SQLMigrationVersion (curVersion )
375+
376+ err := m .execTaskAtMigVersion (sqlMigVersion )
377+ if err != nil {
378+ return m .unlockErr (err )
379+ }
380+
381+ curVersion , dirty , err = m .databaseDrv .Version ()
382+ if err != nil {
383+ return m .unlockErr (err )
384+ }
385+ }
386+
319387 ret := make (chan interface {}, m .PrefetchMigrations )
320388 go m .readDown (curVersion , - 1 , ret )
321389 return m .unlockErr (m .runMigrations (ret ))
@@ -354,6 +422,23 @@ func (m *Migrate) Run(migration ...*Migration) error {
354422 return m .unlockErr (ErrDirty {curVersion })
355423 }
356424
425+ // If the current version is a clean migration task version, then
426+ // we need to rerun the task for the previous version before we can
427+ // continue with any SQL migration(s).
428+ if InTaskVersionRange (curVersion ) {
429+ sqlMigVersion := SQLMigrationVersion (curVersion )
430+
431+ err := m .execTaskAtMigVersion (sqlMigVersion )
432+ if err != nil {
433+ return m .unlockErr (err )
434+ }
435+
436+ curVersion , dirty , err = m .databaseDrv .Version ()
437+ if err != nil {
438+ return m .unlockErr (err )
439+ }
440+ }
441+
357442 ret := make (chan interface {}, m .PrefetchMigrations )
358443
359444 go func () {
@@ -732,6 +817,30 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
732817 }
733818}
734819
820+ // readSingle reads a single migration for the given version, and sends it
821+ // over the passed channel.
822+ func (m * Migrate ) readSingle (ver uint , ret chan <- interface {}) {
823+ defer close (ret )
824+
825+ if err := m .versionExists (ver ); err != nil {
826+ ret <- err
827+ return
828+ }
829+
830+ migr , err := m .newMigration (ver , int (ver ))
831+ if err != nil {
832+ ret <- err
833+ return
834+ }
835+
836+ ret <- migr
837+ go func () {
838+ if err := migr .Buffer (); err != nil {
839+ m .logErr (err )
840+ }
841+ }()
842+ }
843+
735844// runMigrations reads *Migration and error from a channel. Any other type
736845// sent on this channel will result in a panic. Each migration is then
737846// proxied to the database driver and run against the database.
@@ -752,6 +861,12 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
752861 case * Migration :
753862 migr := r
754863
864+ if migr .Version >= TaskVersionOffset {
865+ return fmt .Errorf ("migration version %v is " +
866+ "invalid, must be < %v" , migr .Version ,
867+ TaskVersionOffset )
868+ }
869+
755870 // set version with dirty state
756871 if err := m .databaseDrv .SetVersion (migr .TargetVersion , true ); err != nil {
757872 return err
@@ -763,23 +878,9 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
763878 return err
764879 }
765880
766- // If there is a task function for this
767- // migration, run it now.
768- cb , ok := m .opts .tasks [migr .Version ]
769- if ok {
770- m .logVerbosePrintf ("Running migration " +
771- "task for %v\n " , migr .LogString ())
772-
773- err := cb (migr , m .databaseDrv )
774- if err != nil {
775- return fmt .Errorf ("failed to " +
776- "execute migration " +
777- "task: %w" ,
778- err )
779- }
780-
781- m .logVerbosePrintf ("Migration task " +
782- "finished for %v\n " , migr .LogString ())
881+ err := m .execTask (migr )
882+ if err != nil {
883+ return err
783884 }
784885 }
785886
@@ -808,6 +909,108 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
808909 return nil
809910}
810911
912+ // execTask checks if a migration task exists for the passed migration and
913+ // proceeds to execute if one exists.
914+ func (m * Migrate ) execTask (migr * Migration ) error {
915+ task , ok := m .opts .tasks [migr .Version ]
916+ if ok {
917+ m .logVerbosePrintf ("Running migration task for %v\n " ,
918+ migr .LogString ())
919+
920+ taskVersion := int (migr .Version ) + TaskVersionOffset
921+
922+ // Persist that we are in the migration task phase for this
923+ // version.
924+ if err := m .databaseDrv .SetVersion (taskVersion , true ); err != nil {
925+ return err
926+ }
927+
928+ err := task (migr , m .databaseDrv )
929+ if err != nil {
930+ // Mark the database version as the taskVersion but in a
931+ // clean state, to indicate that the migration task
932+ // errored. We will therefore re-run the task on the
933+ // next migration run.
934+ if setErr := m .databaseDrv .SetVersion (taskVersion , false ); setErr != nil {
935+ // Note that if we error here, the database
936+ // version will remain in a dirty state. As we
937+ // cannot know if the migration task was
938+ // executed or not in that scenario, manual
939+ // intervention is required.
940+ return fmt .Errorf ("WARNING, failed to set " +
941+ "migration version after migration " +
942+ "task errored. Manual intervention " +
943+ "needed! Migration task error: %w, " +
944+ "version setting error : %w" ,
945+ err , setErr )
946+ }
947+
948+ return fmt .Errorf ("failed to execute migration " +
949+ "task: %w" , err )
950+ }
951+
952+ m .logVerbosePrintf ("Migration task finished for %v\n " ,
953+ migr .LogString ())
954+ }
955+
956+ return nil
957+ }
958+
959+ // execTaskAtMigVersion executes only the migration task for the passed SQL
960+ // migration version.
961+ // The function can be used to re-execute the task for a SQL migration version
962+ // where the SQL migration was successfully applied, but where the task failed.
963+ func (m * Migrate ) execTaskAtMigVersion (sqlMigVersion int ) error {
964+ var (
965+ r interface {}
966+ migRet = make (chan interface {}, m .PrefetchMigrations )
967+ err error
968+ )
969+
970+ // Fetch the migration for the specified SQL migration version.
971+ go m .readSingle (uint (sqlMigVersion ), migRet )
972+
973+ select {
974+ case r = <- migRet :
975+ case <- time .After (30 * time .Second ):
976+ return fmt .Errorf ("timeout waiting for single migration " +
977+ "version %v" , sqlMigVersion )
978+ }
979+
980+ if m .stop () {
981+ return nil
982+ }
983+
984+ switch r := r .(type ) {
985+ case * Migration :
986+ // If the migration was found, execute the migration task.
987+ migr := r
988+
989+ err = m .execTask (migr )
990+ if err != nil {
991+ return err
992+ }
993+
994+ m .logVerbosePrintf ("successfully re-executed migration task " +
995+ "for SQL migration version: %v\n " , sqlMigVersion )
996+
997+ // set clean state
998+ if err = m .databaseDrv .SetVersion (migr .TargetVersion , false ); err != nil {
999+ return err
1000+ }
1001+
1002+ return nil
1003+
1004+ case error :
1005+ return fmt .Errorf ("reading SQL migration at version " +
1006+ "%v failed: %w" , sqlMigVersion , r )
1007+
1008+ default :
1009+ return fmt .Errorf ("unknown type: %T when reading " +
1010+ "single migration" , r )
1011+ }
1012+ }
1013+
8111014// versionExists checks the source if either the up or down migration for
8121015// the specified migration version exists.
8131016func (m * Migrate ) versionExists (version uint ) (result error ) {
0 commit comments