@@ -7,7 +7,11 @@ package migrate
77import (
88 "errors"
99 "fmt"
10+ "net/url"
1011 "os"
12+ "path/filepath"
13+ "strconv"
14+ "strings"
1115 "sync"
1216 "time"
1317
3640 ErrLockTimeout = errors .New ("timeout: can't acquire database lock" )
3741)
3842
43+ // Define a constant for the migration file name
44+ const lastSuccessfulMigrationFile = "lastSuccessfulMigration"
45+
3946// ErrShortLimit is an error returned when not enough migrations
4047// can be returned by a source for a given limit.
4148type ErrShortLimit struct {
@@ -80,6 +87,21 @@ type Migrate struct {
8087 // LockTimeout defaults to DefaultLockTimeout,
8188 // but can be set per Migrate instance.
8289 LockTimeout time.Duration
90+
91+ // DirtyStateHandler is used to handle dirty state of the database
92+ dirtyStateConf * dirtyStateHandler
93+ }
94+
95+ type dirtyStateHandler struct {
96+ srcScheme string
97+ srcPath string
98+ destScheme string
99+ destPath string
100+ enable bool
101+ }
102+
103+ func (m * Migrate ) IsDirtyHandlingEnabled () bool {
104+ return m .dirtyStateConf != nil && m .dirtyStateConf .enable && m .dirtyStateConf .destPath != ""
83105}
84106
85107// New returns a new Migrate instance from a source URL and a database URL.
@@ -114,6 +136,20 @@ func New(sourceURL, databaseURL string) (*Migrate, error) {
114136 return m , nil
115137}
116138
139+ func (m * Migrate ) updateSourceDrv (sourceURL string ) error {
140+ sourceName , err := iurl .SchemeFromURL (sourceURL )
141+ if err != nil {
142+ return fmt .Errorf ("failed to parse scheme from source URL: %w" , err )
143+ }
144+ m .sourceName = sourceName
145+ sourceDrv , err := source .Open (sourceURL )
146+ if err != nil {
147+ return fmt .Errorf ("failed to open source, %q: %w" , sourceURL , err )
148+ }
149+ m .sourceDrv = sourceDrv
150+ return nil
151+ }
152+
117153// NewWithDatabaseInstance returns a new Migrate instance from a source URL
118154// and an existing database instance. The source URL scheme is defined by each driver.
119155// Use any string that can serve as an identifier during logging as databaseName.
@@ -182,6 +218,42 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa
182218 return m , nil
183219}
184220
221+ func (m * Migrate ) WithDirtyStateHandler (srcPath , destPath string , isDirty bool ) error {
222+ parser := func (path string ) (string , string , error ) {
223+ var scheme , p string
224+ uri , err := url .Parse (path )
225+ if err != nil {
226+ return "" , "" , err
227+ }
228+ scheme = uri .Scheme
229+ p = uri .Path
230+ // if no scheme is provided, assume it's a file path
231+ if scheme == "" {
232+ scheme = "file://"
233+ }
234+ return scheme , p , nil
235+ }
236+
237+ sScheme , sPath , err := parser (srcPath )
238+ if err != nil {
239+ return err
240+ }
241+
242+ dScheme , dPath , err := parser (destPath )
243+ if err != nil {
244+ return err
245+ }
246+
247+ m .dirtyStateConf = & dirtyStateHandler {
248+ srcScheme : sScheme ,
249+ destScheme : dScheme ,
250+ srcPath : sPath ,
251+ destPath : dPath ,
252+ enable : isDirty ,
253+ }
254+ return nil
255+ }
256+
185257func newCommon () * Migrate {
186258 return & Migrate {
187259 GracefulStop : make (chan bool , 1 ),
@@ -215,20 +287,42 @@ func (m *Migrate) Migrate(version uint) error {
215287 if err := m .lock (); err != nil {
216288 return err
217289 }
218-
219290 curVersion , dirty , err := m .databaseDrv .Version ()
220291 if err != nil {
221292 return m .unlockErr (err )
222293 }
223294
295+ // if the dirty flag is passed to the 'goto' command, handle the dirty state
224296 if dirty {
225- return m .unlockErr (ErrDirty {curVersion })
297+ if m .IsDirtyHandlingEnabled () {
298+ if err = m .handleDirtyState (); err != nil {
299+ return m .unlockErr (err )
300+ }
301+ } else {
302+ // default behavior
303+ return m .unlockErr (ErrDirty {curVersion })
304+ }
305+ }
306+
307+ // Copy migrations to the destination directory,
308+ // if state was dirty when Migrate was called, we should handle the dirty state first before copying the migrations
309+ if err = m .copyFiles (); err != nil {
310+ return m .unlockErr (err )
226311 }
227312
228313 ret := make (chan interface {}, m .PrefetchMigrations )
229314 go m .read (curVersion , int (version ), ret )
230315
231- return m .unlockErr (m .runMigrations (ret ))
316+ if err = m .runMigrations (ret ); err != nil {
317+ return m .unlockErr (err )
318+ }
319+ // Success: Clean up and confirm
320+ // Files are cleaned up after the migration is successful
321+ if err = m .cleanupFiles (version ); err != nil {
322+ return m .unlockErr (err )
323+ }
324+ // unlock the database
325+ return m .unlock ()
232326}
233327
234328// Steps looks at the currently active migration version.
@@ -723,6 +817,7 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
723817// to stop execution because it might have received a stop signal on the
724818// GracefulStop channel.
725819func (m * Migrate ) runMigrations (ret <- chan interface {}) error {
820+ var lastCleanMigrationApplied int
726821 for r := range ret {
727822
728823 if m .stop () {
@@ -744,6 +839,15 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
744839 if migr .Body != nil {
745840 m .logVerbosePrintf ("Read and execute %v\n " , migr .LogString ())
746841 if err := m .databaseDrv .Run (migr .BufferedBody ); err != nil {
842+ if m .dirtyStateConf != nil && m .dirtyStateConf .enable {
843+ // this condition is required if the first migration fails
844+ if lastCleanMigrationApplied == 0 {
845+ lastCleanMigrationApplied = migr .TargetVersion
846+ }
847+ if e := m .handleMigrationFailure (lastCleanMigrationApplied ); e != nil {
848+ return multierror .Append (err , e )
849+ }
850+ }
747851 return err
748852 }
749853 }
@@ -752,7 +856,7 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
752856 if err := m .databaseDrv .SetVersion (migr .TargetVersion , false ); err != nil {
753857 return err
754858 }
755-
859+ lastCleanMigrationApplied = migr . TargetVersion
756860 endTime := time .Now ()
757861 readTime := migr .FinishedReading .Sub (migr .StartedBuffering )
758862 runTime := endTime .Sub (migr .FinishedReading )
@@ -979,3 +1083,114 @@ func (m *Migrate) logErr(err error) {
9791083 m .Log .Printf ("error: %v" , err )
9801084 }
9811085}
1086+
1087+ func (m * Migrate ) handleDirtyState () error {
1088+ // Perform the following actions when the database state is dirty
1089+ /*
1090+ 1. Update the source driver to read the migrations from the mounted volume
1091+ 2. Read the last successful migration version from the file
1092+ 3. Set the last successful migration version in the schema_migrations table
1093+ 4. Delete the last successful migration file
1094+ */
1095+ // the source driver should read the migrations from the mounted volume
1096+ // as the DB is dirty and last applied migrations to the database are not present in the source path
1097+ if err := m .updateSourceDrv (m .dirtyStateConf .destScheme + m .dirtyStateConf .destPath ); err != nil {
1098+ return err
1099+ }
1100+ lastSuccessfulMigrationPath := filepath .Join (m .dirtyStateConf .destPath , lastSuccessfulMigrationFile )
1101+ lastVersionBytes , err := os .ReadFile (lastSuccessfulMigrationPath )
1102+ if err != nil {
1103+ return err
1104+ }
1105+ lastVersionStr := strings .TrimSpace (string (lastVersionBytes ))
1106+ lastVersion , err := strconv .ParseInt (lastVersionStr , 10 , 64 )
1107+ if err != nil {
1108+ return fmt .Errorf ("failed to parse last successful migration version: %w" , err )
1109+ }
1110+
1111+ // Set the last successful migration version in the schema_migrations table
1112+ if err = m .databaseDrv .SetVersion (int (lastVersion ), false ); err != nil {
1113+ return fmt .Errorf ("failed to apply last successful migration: %w" , err )
1114+ }
1115+
1116+ m .logPrintf ("Successfully set last successful migration version: %s on the DB" , lastVersionStr )
1117+
1118+ if err = os .Remove (lastSuccessfulMigrationPath ); err != nil {
1119+ return err
1120+ }
1121+
1122+ m .logPrintf ("Successfully deleted file: %s" , lastSuccessfulMigrationPath )
1123+ return nil
1124+ }
1125+
1126+ func (m * Migrate ) handleMigrationFailure (lastSuccessfulMigration int ) error {
1127+ if ! m .IsDirtyHandlingEnabled () {
1128+ return nil
1129+ }
1130+ lastSuccessfulMigrationPath := filepath .Join (m .dirtyStateConf .destPath , lastSuccessfulMigrationFile )
1131+ return os .WriteFile (lastSuccessfulMigrationPath , []byte (strconv .Itoa (lastSuccessfulMigration )), 0644 )
1132+ }
1133+
1134+ func (m * Migrate ) cleanupFiles (targetVersion uint ) error {
1135+ if ! m .IsDirtyHandlingEnabled () {
1136+ return nil
1137+ }
1138+
1139+ files , err := os .ReadDir (m .dirtyStateConf .destPath )
1140+ if err != nil {
1141+ // If the directory does not exist
1142+ return fmt .Errorf ("failed to read directory %s: %w" , m .dirtyStateConf .destPath , err )
1143+ }
1144+
1145+ for _ , file := range files {
1146+ fileName := file .Name ()
1147+ migration , err := source .Parse (fileName )
1148+ if err != nil {
1149+ return err
1150+ }
1151+ // Delete file if version is greater than targetVersion
1152+ if migration .Version > targetVersion {
1153+ if err = os .Remove (filepath .Join (m .dirtyStateConf .destPath , fileName )); err != nil {
1154+ m .logErr (fmt .Errorf ("failed to delete file %s: %v" , fileName , err ))
1155+ continue
1156+ }
1157+ m .logPrintf ("Migration file: %s removed during cleanup" , fileName )
1158+ }
1159+ }
1160+
1161+ return nil
1162+ }
1163+
1164+ // copyFiles copies all files from source to destination volume.
1165+ func (m * Migrate ) copyFiles () error {
1166+ // this is the case when the dirty handling is disabled
1167+ if ! m .IsDirtyHandlingEnabled () {
1168+ return nil
1169+ }
1170+
1171+ files , err := os .ReadDir (m .dirtyStateConf .srcPath )
1172+ if err != nil {
1173+ // If the directory does not exist
1174+ return fmt .Errorf ("failed to read directory %s: %w" , m .dirtyStateConf .srcPath , err )
1175+ }
1176+ m .logPrintf ("Copying files from %s to %s" , m .dirtyStateConf .srcPath , m .dirtyStateConf .destPath )
1177+ for _ , file := range files {
1178+ fileName := file .Name ()
1179+ if source .Regex .MatchString (fileName ) {
1180+ fileContentBytes , err := os .ReadFile (filepath .Join (m .dirtyStateConf .srcPath , fileName ))
1181+ if err != nil {
1182+ return err
1183+ }
1184+ info , err := file .Info ()
1185+ if err != nil {
1186+ return err
1187+ }
1188+ if err = os .WriteFile (filepath .Join (m .dirtyStateConf .destPath , fileName ), fileContentBytes , info .Mode ().Perm ()); err != nil {
1189+ return err
1190+ }
1191+ }
1192+ }
1193+
1194+ m .logPrintf ("Successfully Copied files from %s to %s" , m .dirtyStateConf .srcPath , m .dirtyStateConf .destPath )
1195+ return nil
1196+ }
0 commit comments