55package migrate
66
77import (
8+ "bytes"
89 "errors"
910 "fmt"
11+ "io"
1012 "os"
13+ "regexp"
14+ "strings"
1115 "sync"
1216 "time"
1317
@@ -732,6 +736,34 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
732736 }
733737}
734738
739+ // hasSQLMigration checks if the passed data contains executable statements,
740+ // meaning that the data doesn't only contain comments/whitespace or semicolons.
741+ func (m * Migrate ) hasSQLMigration (data []byte ) (bool , error ) {
742+ s := string (data )
743+
744+ // Remove Byte Order Mark (BOM) if present in the migration file.
745+ s = strings .TrimPrefix (s , "\uFEFF " )
746+
747+ // Strip block comments /* ... */ (non-greedy, across lines).
748+ reBlock := regexp .MustCompile (`(?s)/\*.*?\*/` )
749+ s = reBlock .ReplaceAllString (s , "" )
750+
751+ // Strip line comments -- ... (to end of line).
752+ reLine := regexp .MustCompile (`(?m)--[^\n\r]*` )
753+ s = reLine .ReplaceAllString (s , "" )
754+
755+ // Trim whitespaces.
756+ s = strings .TrimSpace (s )
757+
758+ // Remove any semicolons, newlines, tabs, or spaces from the beginning
759+ // and end of the string.
760+ s = strings .Trim (s , ";\r \n \t " )
761+
762+ // If the string still contains any characters, the data likely
763+ // contains executable statements.
764+ return len (s ) > 0 , nil
765+ }
766+
735767// runMigrations reads *Migration and error from a channel. Any other type
736768// sent on this channel will result in a panic. Each migration is then
737769// proxied to the database driver and run against the database.
@@ -752,34 +784,58 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
752784 case * Migration :
753785 migr := r
754786
755- // set version with dirty state
756- if err := m .databaseDrv .SetVersion (migr .TargetVersion , true ); err != nil {
757- return err
758- }
759-
760787 if migr .Body != nil {
761- m .logVerbosePrintf ("Read and execute %v\n " , migr .LogString ())
762- if err := m .databaseDrv .Run (migr .BufferedBody ); err != nil {
788+ // Read the body so we can inspect and (re)use it.
789+ data , err := io .ReadAll (migr .BufferedBody )
790+ if err != nil {
791+ return fmt .Errorf ("read migration body: %w" , err )
792+ }
793+
794+ // Reset the reader so the driver can read it
795+ migr .BufferedBody = bytes .NewReader (data )
796+
797+ // Check if the migration contains an SQL
798+ // migration.
799+ hasSqlMig , err := m .hasSQLMigration (data )
800+ if err != nil {
763801 return err
764802 }
765803
766- // If there is a task function for this
767- // migration, run it now.
768- cb , ok := m .opts .tasks [migr .Version ]
769- if ok {
770- m .logVerbosePrintf ("Running migration " +
771- "task for %v\n " , migr .LogString ())
804+ // Check if the migration contains a migration
805+ // task.
806+ _ , hasMigTask := m .opts .tasks [migr .Version ]
807+
808+ // Execute the SQL migration or the migration
809+ // task.
810+ switch {
811+ case hasSqlMig && hasMigTask :
812+ return fmt .Errorf ("migration has both " +
813+ "a SQL migration and a " +
814+ "migration task set" )
815+
816+ case hasSqlMig :
817+ if err = m .databaseDrv .SetVersion (migr .TargetVersion , true ); err != nil {
818+ return err
819+ }
772820
773- err := cb (migr , m .databaseDrv )
821+ m .logVerbosePrintf ("Read and execute %v\n " , migr .LogString ())
822+ if err = m .databaseDrv .Run (migr .BufferedBody ); err != nil {
823+ return err
824+ }
825+
826+ case hasMigTask :
827+ err = m .execTask (migr )
774828 if err != nil {
775- return fmt .Errorf ("failed to " +
776- "execute migration " +
777- "task: %w" ,
778- err )
829+ return fmt .Errorf ("migration " +
830+ "task execution " +
831+ "failed: %w" , err )
779832 }
780833
781- m .logVerbosePrintf ("Migration task " +
782- "finished for %v\n " , migr .LogString ())
834+ default :
835+ // When the migration contains no SQL
836+ // migration or migration task, we
837+ // continue and set the version to the
838+ // migr.TargetVersion.
783839 }
784840 }
785841
@@ -808,6 +864,59 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
808864 return nil
809865}
810866
867+ // execTask checks if a migration task exists for the passed migration and
868+ // proceeds to execute if one exists. If the migration task fails, the function
869+ // will reset the database version to the version it was set to before
870+ // attempting to execute the migration task.
871+ func (m * Migrate ) execTask (migr * Migration ) error {
872+ m .logVerbosePrintf ("Running migration task for %v\n " , migr .LogString ())
873+
874+ task , ok := m .opts .tasks [migr .Version ]
875+ if ! ok {
876+ return fmt .Errorf ("no migration task set for %v" ,
877+ migr .LogString ())
878+ }
879+
880+ // Get the current database version before executing the migration task.
881+ curVersion , dirty , err := m .databaseDrv .Version ()
882+ if err != nil {
883+ return fmt .Errorf ("unable to get current version: %w" , err )
884+ }
885+
886+ if dirty {
887+ return ErrDirty {curVersion }
888+ }
889+
890+ // Persist that we are at the migration version of the migration task.
891+ if err = m .databaseDrv .SetVersion (int (migr .Version ), true ); err != nil {
892+ return err
893+ }
894+
895+ err = task (migr , m .databaseDrv )
896+ if err != nil {
897+ // Reset the version to the version set before executing the
898+ // migration task. Therefore, the migration task will be
899+ // re-executed on nnext startup until it succeeds.
900+ setErr := m .databaseDrv .SetVersion (curVersion , false )
901+ if setErr != nil {
902+ // Note that if we error here, the database version will
903+ // remain in a dirty state. As we cannot know if the
904+ // migration task was executed or not in that scenario,
905+ // manual intervention is required.
906+ return fmt .Errorf ("WARNING, failed to set migration " +
907+ "version after migration task errored. Manual " +
908+ "intervention needed! Migration task error: " +
909+ "%w, version setting error : %w" , err , setErr )
910+ }
911+
912+ return fmt .Errorf ("failed to execute migration task: %w" , err )
913+ }
914+
915+ m .logVerbosePrintf ("Migration task finished for %v\n " , migr .LogString ())
916+
917+ return nil
918+ }
919+
811920// versionExists checks the source if either the up or down migration for
812921// the specified migration version exists.
813922func (m * Migrate ) versionExists (version uint ) (result error ) {
0 commit comments