Skip to content

Commit 13318b2

Browse files
authored
Goto command support to handle dirty Database state (#36)
* First pass at downmigrate * First pass at downmigrate * downmigrate changes + UTs + cleanups * address comments
1 parent 6caa1d9 commit 13318b2

File tree

4 files changed

+520
-6
lines changed

4 files changed

+520
-6
lines changed

cmd/migrate/config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,8 @@ var (
3535

3636
flagConfigDirectory = pflag.String("config.source", defaultConfigDirectory, "directory of the configuration file")
3737
flagConfigFile = pflag.String("config.file", "", "configuration file name without extension")
38+
39+
// goto command flags
40+
flagDirty = pflag.Bool("force-dirty-handling", false, "force the handling of dirty database state")
41+
flagMountPath = pflag.String("cache-dir", "", "path to the mounted volume which is used to copy the migration files")
3842
)

internal/cli/main.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ const (
2929
Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error.
3030
Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC).
3131
`
32-
gotoUsage = `goto V Migrate to version V`
32+
gotoUsage = `goto V [-force-dirty-handling] [-cache-dir P] Migrate to version V
33+
Use -force-dirty-handling to handle dirty database state
34+
Use -cache-dir to specify the intermediate path P for storing migrations`
3335
upUsage = `up [N] Apply all or N up migrations`
3436
downUsage = `down [N] [-all] Apply all or N down migrations
3537
Use -all to apply all down migrations`
@@ -262,8 +264,19 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU
262264
if err != nil {
263265
log.fatal("error: can't read version argument V")
264266
}
267+
handleDirty := viper.GetBool("force-dirty-handling")
268+
if handleDirty {
269+
destPath := viper.GetString("cache-dir")
270+
if destPath == "" {
271+
log.fatal("error: cache-dir must be specified when force-dirty-handling is set")
272+
}
273+
274+
if err = migrater.WithDirtyStateHandler(sourcePtr, destPath, handleDirty); err != nil {
275+
log.fatalErr(err)
276+
}
277+
}
265278

266-
if err := gotoCmd(migrater, uint(v)); err != nil {
279+
if err = gotoCmd(migrater, uint(v)); err != nil {
267280
log.fatalErr(err)
268281
}
269282

migrate.go

Lines changed: 219 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ package migrate
77
import (
88
"errors"
99
"fmt"
10+
"net/url"
1011
"os"
12+
"path/filepath"
13+
"strconv"
14+
"strings"
1115
"sync"
1216
"time"
1317

@@ -36,6 +40,9 @@ var (
3640
ErrLockTimeout = errors.New("timeout: can't acquire database lock")
3741
)
3842

43+
// Define a constant for the migration file name
44+
const lastSuccessfulMigrationFile = "lastSuccessfulMigration"
45+
3946
// ErrShortLimit is an error returned when not enough migrations
4047
// can be returned by a source for a given limit.
4148
type ErrShortLimit struct {
@@ -80,6 +87,21 @@ type Migrate struct {
8087
// LockTimeout defaults to DefaultLockTimeout,
8188
// but can be set per Migrate instance.
8289
LockTimeout time.Duration
90+
91+
// DirtyStateHandler is used to handle dirty state of the database
92+
dirtyStateConf *dirtyStateHandler
93+
}
94+
95+
type dirtyStateHandler struct {
96+
srcScheme string
97+
srcPath string
98+
destScheme string
99+
destPath string
100+
enable bool
101+
}
102+
103+
func (m *Migrate) IsDirtyHandlingEnabled() bool {
104+
return m.dirtyStateConf != nil && m.dirtyStateConf.enable && m.dirtyStateConf.destPath != ""
83105
}
84106

85107
// New returns a new Migrate instance from a source URL and a database URL.
@@ -114,6 +136,20 @@ func New(sourceURL, databaseURL string) (*Migrate, error) {
114136
return m, nil
115137
}
116138

139+
func (m *Migrate) updateSourceDrv(sourceURL string) error {
140+
sourceName, err := iurl.SchemeFromURL(sourceURL)
141+
if err != nil {
142+
return fmt.Errorf("failed to parse scheme from source URL: %w", err)
143+
}
144+
m.sourceName = sourceName
145+
sourceDrv, err := source.Open(sourceURL)
146+
if err != nil {
147+
return fmt.Errorf("failed to open source, %q: %w", sourceURL, err)
148+
}
149+
m.sourceDrv = sourceDrv
150+
return nil
151+
}
152+
117153
// NewWithDatabaseInstance returns a new Migrate instance from a source URL
118154
// and an existing database instance. The source URL scheme is defined by each driver.
119155
// Use any string that can serve as an identifier during logging as databaseName.
@@ -182,6 +218,42 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa
182218
return m, nil
183219
}
184220

221+
func (m *Migrate) WithDirtyStateHandler(srcPath, destPath string, isDirty bool) error {
222+
parser := func(path string) (string, string, error) {
223+
var scheme, p string
224+
uri, err := url.Parse(path)
225+
if err != nil {
226+
return "", "", err
227+
}
228+
scheme = uri.Scheme
229+
p = uri.Path
230+
// if no scheme is provided, assume it's a file path
231+
if scheme == "" {
232+
scheme = "file://"
233+
}
234+
return scheme, p, nil
235+
}
236+
237+
sScheme, sPath, err := parser(srcPath)
238+
if err != nil {
239+
return err
240+
}
241+
242+
dScheme, dPath, err := parser(destPath)
243+
if err != nil {
244+
return err
245+
}
246+
247+
m.dirtyStateConf = &dirtyStateHandler{
248+
srcScheme: sScheme,
249+
destScheme: dScheme,
250+
srcPath: sPath,
251+
destPath: dPath,
252+
enable: isDirty,
253+
}
254+
return nil
255+
}
256+
185257
func newCommon() *Migrate {
186258
return &Migrate{
187259
GracefulStop: make(chan bool, 1),
@@ -215,20 +287,42 @@ func (m *Migrate) Migrate(version uint) error {
215287
if err := m.lock(); err != nil {
216288
return err
217289
}
218-
219290
curVersion, dirty, err := m.databaseDrv.Version()
220291
if err != nil {
221292
return m.unlockErr(err)
222293
}
223294

295+
// if the dirty flag is passed to the 'goto' command, handle the dirty state
224296
if dirty {
225-
return m.unlockErr(ErrDirty{curVersion})
297+
if m.IsDirtyHandlingEnabled() {
298+
if err = m.handleDirtyState(); err != nil {
299+
return m.unlockErr(err)
300+
}
301+
} else {
302+
// default behavior
303+
return m.unlockErr(ErrDirty{curVersion})
304+
}
305+
}
306+
307+
// Copy migrations to the destination directory,
308+
// if state was dirty when Migrate was called, we should handle the dirty state first before copying the migrations
309+
if err = m.copyFiles(); err != nil {
310+
return m.unlockErr(err)
226311
}
227312

228313
ret := make(chan interface{}, m.PrefetchMigrations)
229314
go m.read(curVersion, int(version), ret)
230315

231-
return m.unlockErr(m.runMigrations(ret))
316+
if err = m.runMigrations(ret); err != nil {
317+
return m.unlockErr(err)
318+
}
319+
// Success: Clean up and confirm
320+
// Files are cleaned up after the migration is successful
321+
if err = m.cleanupFiles(version); err != nil {
322+
return m.unlockErr(err)
323+
}
324+
// unlock the database
325+
return m.unlock()
232326
}
233327

234328
// Steps looks at the currently active migration version.
@@ -723,6 +817,7 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
723817
// to stop execution because it might have received a stop signal on the
724818
// GracefulStop channel.
725819
func (m *Migrate) runMigrations(ret <-chan interface{}) error {
820+
var lastCleanMigrationApplied int
726821
for r := range ret {
727822

728823
if m.stop() {
@@ -744,6 +839,15 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
744839
if migr.Body != nil {
745840
m.logVerbosePrintf("Read and execute %v\n", migr.LogString())
746841
if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
842+
if m.dirtyStateConf != nil && m.dirtyStateConf.enable {
843+
// this condition is required if the first migration fails
844+
if lastCleanMigrationApplied == 0 {
845+
lastCleanMigrationApplied = migr.TargetVersion
846+
}
847+
if e := m.handleMigrationFailure(lastCleanMigrationApplied); e != nil {
848+
return multierror.Append(err, e)
849+
}
850+
}
747851
return err
748852
}
749853
}
@@ -752,7 +856,7 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
752856
if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil {
753857
return err
754858
}
755-
859+
lastCleanMigrationApplied = migr.TargetVersion
756860
endTime := time.Now()
757861
readTime := migr.FinishedReading.Sub(migr.StartedBuffering)
758862
runTime := endTime.Sub(migr.FinishedReading)
@@ -979,3 +1083,114 @@ func (m *Migrate) logErr(err error) {
9791083
m.Log.Printf("error: %v", err)
9801084
}
9811085
}
1086+
1087+
func (m *Migrate) handleDirtyState() error {
1088+
// Perform the following actions when the database state is dirty
1089+
/*
1090+
1. Update the source driver to read the migrations from the mounted volume
1091+
2. Read the last successful migration version from the file
1092+
3. Set the last successful migration version in the schema_migrations table
1093+
4. Delete the last successful migration file
1094+
*/
1095+
// the source driver should read the migrations from the mounted volume
1096+
// as the DB is dirty and last applied migrations to the database are not present in the source path
1097+
if err := m.updateSourceDrv(m.dirtyStateConf.destScheme + m.dirtyStateConf.destPath); err != nil {
1098+
return err
1099+
}
1100+
lastSuccessfulMigrationPath := filepath.Join(m.dirtyStateConf.destPath, lastSuccessfulMigrationFile)
1101+
lastVersionBytes, err := os.ReadFile(lastSuccessfulMigrationPath)
1102+
if err != nil {
1103+
return err
1104+
}
1105+
lastVersionStr := strings.TrimSpace(string(lastVersionBytes))
1106+
lastVersion, err := strconv.ParseInt(lastVersionStr, 10, 64)
1107+
if err != nil {
1108+
return fmt.Errorf("failed to parse last successful migration version: %w", err)
1109+
}
1110+
1111+
// Set the last successful migration version in the schema_migrations table
1112+
if err = m.databaseDrv.SetVersion(int(lastVersion), false); err != nil {
1113+
return fmt.Errorf("failed to apply last successful migration: %w", err)
1114+
}
1115+
1116+
m.logPrintf("Successfully set last successful migration version: %s on the DB", lastVersionStr)
1117+
1118+
if err = os.Remove(lastSuccessfulMigrationPath); err != nil {
1119+
return err
1120+
}
1121+
1122+
m.logPrintf("Successfully deleted file: %s", lastSuccessfulMigrationPath)
1123+
return nil
1124+
}
1125+
1126+
func (m *Migrate) handleMigrationFailure(lastSuccessfulMigration int) error {
1127+
if !m.IsDirtyHandlingEnabled() {
1128+
return nil
1129+
}
1130+
lastSuccessfulMigrationPath := filepath.Join(m.dirtyStateConf.destPath, lastSuccessfulMigrationFile)
1131+
return os.WriteFile(lastSuccessfulMigrationPath, []byte(strconv.Itoa(lastSuccessfulMigration)), 0644)
1132+
}
1133+
1134+
func (m *Migrate) cleanupFiles(targetVersion uint) error {
1135+
if !m.IsDirtyHandlingEnabled() {
1136+
return nil
1137+
}
1138+
1139+
files, err := os.ReadDir(m.dirtyStateConf.destPath)
1140+
if err != nil {
1141+
// If the directory does not exist
1142+
return fmt.Errorf("failed to read directory %s: %w", m.dirtyStateConf.destPath, err)
1143+
}
1144+
1145+
for _, file := range files {
1146+
fileName := file.Name()
1147+
migration, err := source.Parse(fileName)
1148+
if err != nil {
1149+
return err
1150+
}
1151+
// Delete file if version is greater than targetVersion
1152+
if migration.Version > targetVersion {
1153+
if err = os.Remove(filepath.Join(m.dirtyStateConf.destPath, fileName)); err != nil {
1154+
m.logErr(fmt.Errorf("failed to delete file %s: %v", fileName, err))
1155+
continue
1156+
}
1157+
m.logPrintf("Migration file: %s removed during cleanup", fileName)
1158+
}
1159+
}
1160+
1161+
return nil
1162+
}
1163+
1164+
// copyFiles copies all files from source to destination volume.
1165+
func (m *Migrate) copyFiles() error {
1166+
// this is the case when the dirty handling is disabled
1167+
if !m.IsDirtyHandlingEnabled() {
1168+
return nil
1169+
}
1170+
1171+
files, err := os.ReadDir(m.dirtyStateConf.srcPath)
1172+
if err != nil {
1173+
// If the directory does not exist
1174+
return fmt.Errorf("failed to read directory %s: %w", m.dirtyStateConf.srcPath, err)
1175+
}
1176+
m.logPrintf("Copying files from %s to %s", m.dirtyStateConf.srcPath, m.dirtyStateConf.destPath)
1177+
for _, file := range files {
1178+
fileName := file.Name()
1179+
if source.Regex.MatchString(fileName) {
1180+
fileContentBytes, err := os.ReadFile(filepath.Join(m.dirtyStateConf.srcPath, fileName))
1181+
if err != nil {
1182+
return err
1183+
}
1184+
info, err := file.Info()
1185+
if err != nil {
1186+
return err
1187+
}
1188+
if err = os.WriteFile(filepath.Join(m.dirtyStateConf.destPath, fileName), fileContentBytes, info.Mode().Perm()); err != nil {
1189+
return err
1190+
}
1191+
}
1192+
}
1193+
1194+
m.logPrintf("Successfully Copied files from %s to %s", m.dirtyStateConf.srcPath, m.dirtyStateConf.destPath)
1195+
return nil
1196+
}

0 commit comments

Comments
 (0)