diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 25062ce6c..5555007e0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -6,6 +6,7 @@ on: jobs: lint: + if: false # disable linting for infobloxopen/migrate name: lint runs-on: ubuntu-latest steps: diff --git a/.golangci.yml b/.golangci.yml index 68a8e953b..55401a171 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -13,14 +13,10 @@ linters: - unparam - nakedret - prealloc - - revive #- gosec linters-settings: misspell: locale: US - revive: - rules: - - name: redundant-build-tag issues: max-same-issues: 0 max-issues-per-linter: 0 diff --git a/Dockerfile b/Dockerfile index 46b4e18fd..a47fbea45 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,9 @@ -FROM golang:1.24-alpine3.21 AS builder +FROM golang:1.24-alpine AS builder ARG VERSION RUN apk add --no-cache git gcc musl-dev make -WORKDIR /go/src/github.com/golang-migrate/migrate +WORKDIR /go/src/github.com/infobloxopen/migrate ENV GO111MODULE=on @@ -15,12 +15,11 @@ COPY . ./ RUN make build-docker -FROM alpine:3.21 +FROM gcr.io/distroless/static:nonroot -RUN apk add --no-cache ca-certificates - -COPY --from=builder /go/src/github.com/golang-migrate/migrate/build/migrate.linux-386 /usr/local/bin/migrate -RUN ln -s /usr/local/bin/migrate /migrate +COPY --from=builder /go/src/github.com/infobloxopen/migrate/cmd/migrate/config /cli/config/ +COPY --from=builder /go/src/github.com/infobloxopen/migrate/build/migrate.linux-386 /migrate +COPY --from=builder /etc/ssl/certs/ /etc/ssl/certs/ ENTRYPOINT ["migrate"] CMD ["--help"] diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 000000000..14dea757c --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,63 @@ + +// This library defines the isPrBuild, prepareBuild and finalizeBuild methods +@Library('jenkins.shared.library') _ + +pipeline { + agent { + label 'ubuntu_docker_label' + } + tools { + go "Go 1.24.2" + } + options { + checkoutToSubdirectory('src/github.com/infobloxopen/migrate') + } + environment { + GOPATH = "$WORKSPACE" + DIRECTORY = "src/github.com/infobloxopen/migrate" + } + + stages { + stage("Setup") { + steps { + // prepareBuild is one of the Secure CICD helper methods + prepareBuild() + } + } + stage("Unit Tests") { + steps { + dir("$DIRECTORY") { + // sh "make test" + } + } + } + stage("Build Image") { + steps { + withDockerRegistry([credentialsId: "${env.JENKINS_DOCKER_CRED_ID}", url: ""]) { + dir("$DIRECTORY") { + sh "make build" + } + } + } + } + } + post { + success { + // finalizeBuild is one of the Secure CICD helper methods + dir("$DIRECTORY") { + finalizeBuild( + sh( + script: 'make list-of-images', + returnStdout: true + ) + ) + } + } + cleanup { + dir("$DIRECTORY") { + sh "make clean || true" + } + cleanWs() + } + } +} diff --git a/Makefile b/Makefile index 8e23a43c7..7a44d98ff 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,8 @@ SOURCE ?= file go_bindata github github_ee bitbucket aws_s3 google_cloud_storage godoc_vfs gitlab DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb yugabytedb clickhouse mongodb sqlserver firebird neo4j pgx pgx5 rqlite DATABASE_TEST ?= $(DATABASE) sqlite sqlite3 sqlcipher -VERSION ?= $(shell git describe --tags 2>/dev/null | cut -c 2-) +BUILD_NUMBER ?= 0 +VERSION ?= $(shell git describe --tags --long --dirty=-unsupported 2>/dev/null | cut -c 2-)-j$(BUILD_NUMBER) TEST_FLAGS ?= REPO_OWNER ?= $(shell cd .. && basename "$$(pwd)") COVERAGE_DIR ?= .coverage @@ -24,6 +25,14 @@ build-cli: clean cd ./cli/build && shasum -a 256 * > sha256sum.txt cat ./cli/build/sha256sum.txt +docker-push: + docker push infoblox/migrate:$(VERSION) + +show-image-version: + echo $(VERSION) + +list-of-images: + @echo "infoblox/migrate:$(VERSION)" clean: -rm -r ./cli/build @@ -117,4 +126,3 @@ endef SHELL = /bin/sh RAND = $(shell echo $$RANDOM) - diff --git a/cmd/migrate/.gitignore b/cmd/migrate/.gitignore new file mode 100644 index 000000000..219f6a587 --- /dev/null +++ b/cmd/migrate/.gitignore @@ -0,0 +1 @@ +migrate diff --git a/cmd/migrate/config.go b/cmd/migrate/config.go new file mode 100644 index 000000000..a03097618 --- /dev/null +++ b/cmd/migrate/config.go @@ -0,0 +1,42 @@ +package main + +import "github.com/spf13/pflag" + +const ( + // configuration defaults support local development (i.e. "go run ...") + defaultDatabaseDSN = "" + defaultDatabaseDriver = "postgres" + defaultDatabaseAddress = "0.0.0.0:5432" + defaultDatabaseName = "" + defaultDatabaseUser = "postgres" + defaultDatabasePassword = "postgres" + defaultDatabaseSSL = "disable" + defaultConfigDirectory = "/cli/config" +) + +var ( + // define flag overrides + flagHelp = pflag.Bool("help", false, "Print usage") + flagVersion = pflag.String("version", Version, "Print version") + flagLoggingVerbose = pflag.Bool("verbose", true, "Print verbose logging") + flagPrefetch = pflag.Uint("prefetch", 10, "Number of migrations to load in advance before executing") + flaglockTimeout = pflag.Uint("lock-timeout", 15, "Allow N seconds to acquire database lock") + + flagDatabaseDSN = pflag.String("database.dsn", defaultDatabaseDSN, "database connection string") + flagDatabaseDriver = pflag.String("database.driver", defaultDatabaseDriver, "database driver") + flagDatabaseAddress = pflag.String("database.address", defaultDatabaseAddress, "address of the database") + flagDatabaseName = pflag.String("database.name", defaultDatabaseName, "name of the database") + flagDatabaseUser = pflag.String("database.user", defaultDatabaseUser, "database username") + flagDatabasePassword = pflag.String("database.password", defaultDatabasePassword, "database password") + flagDatabaseSSL = pflag.String("database.ssl", defaultDatabaseSSL, "database ssl mode") + + flagSource = pflag.String("source", "", "Location of the migrations (driver://url)") + flagPath = pflag.String("path", "", "Shorthand for -source=file://path") + + flagConfigDirectory = pflag.String("config.source", defaultConfigDirectory, "directory of the configuration file") + flagConfigFile = pflag.String("config.file", "", "configuration file name without extension") + + // goto command flags + flagDirty = pflag.Bool("force-dirty-handling", false, "force the handling of dirty database state") + flagMountPath = pflag.String("cache-dir", "", "path to the cache-dir which is used to copy the migration files") +) diff --git a/cmd/migrate/config/defaults.yaml b/cmd/migrate/config/defaults.yaml new file mode 100644 index 000000000..e861bf15c --- /dev/null +++ b/cmd/migrate/config/defaults.yaml @@ -0,0 +1,14 @@ +help: false +version: false +verbose: true +prefetch: 10 +lockTimeout: 15 +path: "/atlas-migrations/migrations" +#source: "file:///atlas-migrations/migrations" +database: + driver: postgres + address: postgres:5432 + name: app_db + user: postgres + password: postgres + ssl: disable \ No newline at end of file diff --git a/cmd/migrate/main.go b/cmd/migrate/main.go index 7cda72e71..b6188d95f 100644 --- a/cmd/migrate/main.go +++ b/cmd/migrate/main.go @@ -1,6 +1,39 @@ package main -import "github.com/golang-migrate/migrate/v4/internal/cli" +import ( + "log" + "strings" + + "github.com/golang-migrate/migrate/v4/internal/cli" + "github.com/infobloxopen/hotload" + _ "github.com/infobloxopen/hotload/fsnotify" + "github.com/jackc/pgx/v4/stdlib" + "github.com/lib/pq" + "github.com/sirupsen/logrus" + "github.com/spf13/pflag" + "github.com/spf13/viper" +) + +func init() { + pflag.Parse() + viper.BindPFlags(pflag.CommandLine) + viper.AutomaticEnv() + viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + viper.AddConfigPath(viper.GetString("config.source")) + if viper.GetString("config.file") != "" { + viper.SetConfigName(viper.GetString("config.file")) + if err := viper.ReadInConfig(); err != nil { + log.Fatalf("cannot load configuration: %v", err) + } + } + // logrus formatter + customFormatter := new(logrus.JSONFormatter) + logrus.SetFormatter(customFormatter) + + hotload.RegisterSQLDriver("pgx", stdlib.GetDefaultDriver()) + hotload.RegisterSQLDriver("postgres", pq.Driver{}) + hotload.RegisterSQLDriver("postgresql", pq.Driver{}) +} func main() { cli.Main(Version) diff --git a/docker-deploy.sh b/docker-deploy.sh index 558ea79be..d967acdb7 100755 --- a/docker-deploy.sh +++ b/docker-deploy.sh @@ -1,5 +1,5 @@ #!/bin/bash echo "$DOCKER_PASSWORD" | docker login -u "$DOCKER_USERNAME" --password-stdin && \ -docker build --build-arg VERSION="$TRAVIS_TAG" . -t migrate/migrate -t migrate/migrate:"$TRAVIS_TAG" && \ +docker build --pull --build-arg VERSION="$TRAVIS_TAG" . -t migrate/migrate -t migrate/migrate:"$TRAVIS_TAG" && \ docker push migrate/migrate:"$TRAVIS_TAG" && docker push migrate/migrate diff --git a/internal/cli/log.go b/internal/cli/log.go index b17754197..91c6474f2 100644 --- a/internal/cli/log.go +++ b/internal/cli/log.go @@ -2,7 +2,7 @@ package cli import ( "fmt" - logpkg "log" + logpkg "github.com/sirupsen/logrus" "os" ) diff --git a/internal/cli/main.go b/internal/cli/main.go index c7a3bd74a..e1d21934d 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -1,8 +1,9 @@ package cli import ( - "flag" + "database/sql" "fmt" + "net/url" "os" "os/signal" "strconv" @@ -10,8 +11,12 @@ import ( "syscall" "time" + flag "github.com/spf13/pflag" + "github.com/spf13/viper" + "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/golang-migrate/migrate/v4/database/postgres" "github.com/golang-migrate/migrate/v4/source" ) @@ -24,7 +29,9 @@ const ( Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error. Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC). ` - gotoUsage = `goto V Migrate to version V` + gotoUsage = `goto V [-force-dirty-handling] [-cache-dir P] Migrate to version V + Use -force-dirty-handling to handle dirty database state + Use -cache-dir to specify the intermediate path P for storing migrations` upUsage = `up [N] Apply all or N up migrations` downUsage = `down [N] [-all] Apply all or N down migrations Use -all to apply all down migrations` @@ -58,16 +65,30 @@ func printUsageAndExit() { os.Exit(2) } +func dbMakeConnectionString(driver, user, password, address, name, ssl string) string { + return fmt.Sprintf("%s://%s:%s@%s/%s?sslmode=%s", + driver, url.QueryEscape(user), url.QueryEscape(password), address, name, ssl, + ) +} + // Main function of a cli application. It is public for backwards compatibility with `cli` package func Main(version string) { - helpPtr := flag.Bool("help", false, "") - versionPtr := flag.Bool("version", false, "") - verbosePtr := flag.Bool("verbose", false, "") - prefetchPtr := flag.Uint("prefetch", 10, "") - lockTimeoutPtr := flag.Uint("lock-timeout", 15, "") - pathPtr := flag.String("path", "", "") - databasePtr := flag.String("database", "", "") - sourcePtr := flag.String("source", "", "") + help := viper.GetBool("help") + version = viper.GetString("version") + verbose := viper.GetBool("verbose") + prefetch := viper.GetInt("prefetch") + lockTimeout := viper.GetInt("lock-timeout") + path := viper.GetString("path") + sourcePtr := viper.GetString("source") + + databasePtr := viper.GetString("database.dsn") + if databasePtr == "" { + databasePtr = dbMakeConnectionString( + viper.GetString("database.driver"), viper.GetString("database.user"), + viper.GetString("database.password"), viper.GetString("database.address"), + viper.GetString("database.name"), viper.GetString("database.ssl"), + ) + } flag.Usage = func() { fmt.Fprintf(os.Stderr, @@ -75,14 +96,25 @@ func Main(version string) { migrate [ -version | -help ] Options: - -source Location of the migrations (driver://url) - -path Shorthand for -source=file://path - -database Run migrations against this database (driver://url) - -prefetch N Number of migrations to load in advance before executing (default 10) - -lock-timeout N Allow N seconds to acquire database lock (default 15) - -verbose Print verbose logging - -version Print version - -help Print usage + --source Location of the migrations (driver://url) + --path Shorthand for -source=file://path + --database Run migrations against this database (driver://url) + --prefetch N Number of migrations to load in advance before executing (default 10) + --lock-timeout N Allow N seconds to acquire database lock (default 15) + --verbose Print verbose logging + --version Print version + --help Print usage + + // Infoblox specific + --config.source directory of the configuration file (default "/cli/config") + --config.file configuration file name (without extension) + --database.dsn database connection string + --database.driver database driver (default postgres) + --database.address address of the database (default "0.0.0.0:5432") + --database.name name of the database + --database.user database username (default "postgres") + --database.password database password (default "postgres") + --database.ssl database ssl mode (default "disable") Commands: %s @@ -97,32 +129,50 @@ Source drivers: `+strings.Join(source.List(), ", ")+` Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoUsage, upUsage, downUsage, dropUsage, forceUsage) } - flag.Parse() - // initialize logger - log.verbose = *verbosePtr + log.verbose = verbose // show cli version - if *versionPtr { + if version == "" { fmt.Fprintln(os.Stderr, version) os.Exit(0) } // show help - if *helpPtr { + if help { flag.Usage() os.Exit(0) } // translate -path into -source if given - if *sourcePtr == "" && *pathPtr != "" { - *sourcePtr = fmt.Sprintf("file://%v", *pathPtr) + if sourcePtr == "" && path != "" { + sourcePtr = fmt.Sprintf("file://%v", path) } // initialize migrate // don't catch migraterErr here and let each command decide // how it wants to handle the error - migrater, migraterErr := migrate.New(*sourcePtr, *databasePtr) + var migrater *migrate.Migrate + var migraterErr error + + if driver := viper.GetString("database.driver"); driver == "hotload" { + db, err := sql.Open(driver, databasePtr) + if err != nil { + log.fatalErr(fmt.Errorf("could not open hotload dsn %s: %s", databasePtr, err)) + } + var dbname, user string + if err := db.QueryRow("SELECT current_database(), user").Scan(&dbname, &user); err != nil { + log.fatalErr(fmt.Errorf("could not get current_database: %s", err.Error())) + } + // dbname is not needed since it gets filled in by the driver but we want to be complete + migrateDriver, err := postgres.WithInstance(db, &postgres.Config{DatabaseName: dbname}) + if err != nil { + log.fatalErr(fmt.Errorf("could not create migrate driver: %s", err)) + } + migrater, migraterErr = migrate.NewWithDatabaseInstance(sourcePtr, dbname, migrateDriver) + } else { + migrater, migraterErr = migrate.New(sourcePtr, databasePtr) + } defer func() { if migraterErr == nil { if _, err := migrater.Close(); err != nil { @@ -132,8 +182,8 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU }() if migraterErr == nil { migrater.Log = log - migrater.PrefetchMigrations = *prefetchPtr - migrater.LockTimeout = time.Duration(int64(*lockTimeoutPtr)) * time.Second + migrater.PrefetchMigrations = uint(prefetch) + migrater.LockTimeout = time.Duration(int64(lockTimeout)) * time.Second // handle Ctrl+c signals := make(chan os.Signal, 1) @@ -214,8 +264,19 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU if err != nil { log.fatal("error: can't read version argument V") } + handleDirty := viper.GetBool("force-dirty-handling") + if handleDirty { + destPath := viper.GetString("cache-dir") + if destPath == "" { + log.fatal("error: cache-dir must be specified when force-dirty-handling is set") + } + + if err = migrater.WithDirtyStateConfig(sourcePtr, destPath, handleDirty); err != nil { + log.fatalErr(err) + } + } - if err := gotoCmd(migrater, uint(v)); err != nil { + if err = gotoCmd(migrater, uint(v)); err != nil { log.fatalErr(err) } diff --git a/migrate.go b/migrate.go index 266cc04eb..32a76145a 100644 --- a/migrate.go +++ b/migrate.go @@ -7,7 +7,11 @@ package migrate import ( "errors" "fmt" + "net/url" "os" + "path/filepath" + "strconv" + "strings" "sync" "time" @@ -36,6 +40,9 @@ var ( ErrLockTimeout = errors.New("timeout: can't acquire database lock") ) +// Define a constant for the migration file name +const lastSuccessfulMigrationFile = "lastSuccessfulMigration" + // ErrShortLimit is an error returned when not enough migrations // can be returned by a source for a given limit. type ErrShortLimit struct { @@ -80,6 +87,21 @@ type Migrate struct { // LockTimeout defaults to DefaultLockTimeout, // but can be set per Migrate instance. LockTimeout time.Duration + + // dirtyStateConfig is used to store the configuration required to handle dirty state of the database + dirtyStateConf *dirtyStateConfig +} + +type dirtyStateConfig struct { + srcScheme string + srcPath string + destScheme string + destPath string + enable bool +} + +func (m *Migrate) IsDirtyHandlingEnabled() bool { + return m.dirtyStateConf != nil && m.dirtyStateConf.enable && m.dirtyStateConf.destPath != "" } // New returns a new Migrate instance from a source URL and a database URL. @@ -107,13 +129,27 @@ func New(sourceURL, databaseURL string) (*Migrate, error) { databaseDrv, err := database.Open(databaseURL) if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) + return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) } m.databaseDrv = databaseDrv return m, nil } +func (m *Migrate) updateSourceDrv(sourceURL string) error { + sourceName, err := iurl.SchemeFromURL(sourceURL) + if err != nil { + return fmt.Errorf("failed to parse scheme from source URL: %w", err) + } + m.sourceName = sourceName + sourceDrv, err := source.Open(sourceURL) + if err != nil { + return fmt.Errorf("failed to open source, %q: %w", sourceURL, err) + } + m.sourceDrv = sourceDrv + return nil +} + // NewWithDatabaseInstance returns a new Migrate instance from a source URL // and an existing database instance. The source URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as databaseName. @@ -157,7 +193,7 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data databaseDrv, err := database.Open(databaseURL) if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) + return nil, fmt.Errorf("failed to open database, %q: %w", databaseURL, err) } m.databaseDrv = databaseDrv @@ -182,6 +218,39 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa return m, nil } +func (m *Migrate) WithDirtyStateConfig(srcPath, destPath string, isDirty bool) error { + parsePath := func(path string) (string, string, error) { + uri, err := url.Parse(path) + if err != nil { + return "", "", err + } + scheme := "file" + if uri.Scheme != "file" && uri.Scheme != "" { + return "", "", fmt.Errorf("unsupported scheme: %s", scheme) + } + return scheme + "://", uri.Path, nil + } + + sScheme, sPath, err := parsePath(srcPath) + if err != nil { + return err + } + + dScheme, dPath, err := parsePath(destPath) + if err != nil { + return err + } + + m.dirtyStateConf = &dirtyStateConfig{ + srcScheme: sScheme, + destScheme: dScheme, + srcPath: sPath, + destPath: dPath, + enable: isDirty, + } + return nil +} + func newCommon() *Migrate { return &Migrate{ GracefulStop: make(chan bool, 1), @@ -215,20 +284,42 @@ func (m *Migrate) Migrate(version uint) error { if err := m.lock(); err != nil { return err } - curVersion, dirty, err := m.databaseDrv.Version() if err != nil { return m.unlockErr(err) } + // if the dirty flag is passed to the 'goto' command, handle the dirty state if dirty { - return m.unlockErr(ErrDirty{curVersion}) + if m.IsDirtyHandlingEnabled() { + if err = m.handleDirtyState(); err != nil { + return m.unlockErr(err) + } + } else { + // default behavior + return m.unlockErr(ErrDirty{curVersion}) + } + } + + // Copy migrations to the destination directory, + // if state was dirty when Migrate was called, we should handle the dirty state first before copying the migrations + if err = m.copyFiles(); err != nil { + return m.unlockErr(err) } ret := make(chan interface{}, m.PrefetchMigrations) go m.read(curVersion, int(version), ret) - return m.unlockErr(m.runMigrations(ret)) + if err = m.runMigrations(ret); err != nil { + return m.unlockErr(err) + } + // Success: Clean up and confirm + // Files are cleaned up after the migration is successful + if err = m.cleanupFiles(version); err != nil { + return m.unlockErr(err) + } + // unlock the database + return m.unlock() } // Steps looks at the currently active migration version. @@ -526,7 +617,7 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) { } } -// readUp reads up migrations from `from` limited by `limit`. +// readUp reads up migrations from `from` limitted by `limit`. // limit can be -1, implying no limit and reading until there are no more migrations. // Each migration is then written to the ret channel. // If an error occurs during reading, that error is written to the ret channel, too. @@ -626,7 +717,7 @@ func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) { } } -// readDown reads down migrations from `from` limited by `limit`. +// readDown reads down migrations from `from` limitted by `limit`. // limit can be -1, implying no limit and reading until there are no more migrations. // Each migration is then written to the ret channel. // If an error occurs during reading, that error is written to the ret channel, too. @@ -723,6 +814,7 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { // to stop execution because it might have received a stop signal on the // GracefulStop channel. func (m *Migrate) runMigrations(ret <-chan interface{}) error { + var lastCleanMigrationApplied int for r := range ret { if m.stop() { @@ -744,6 +836,15 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { if migr.Body != nil { m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { + if m.IsDirtyHandlingEnabled() { + // this condition is required if the first migration fails + if lastCleanMigrationApplied == 0 { + lastCleanMigrationApplied = migr.TargetVersion + } + if e := m.handleMigrationFailure(lastCleanMigrationApplied); e != nil { + return multierror.Append(err, e) + } + } return err } } @@ -752,7 +853,7 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil { return err } - + lastCleanMigrationApplied = migr.TargetVersion endTime := time.Now() readTime := migr.FinishedReading.Sub(migr.StartedBuffering) runTime := endTime.Sub(migr.FinishedReading) @@ -979,3 +1080,114 @@ func (m *Migrate) logErr(err error) { m.Log.Printf("error: %v", err) } } + +func (m *Migrate) handleDirtyState() error { + // Perform the following actions when the database state is dirty + /* + 1. Update the source driver to read the migrations from the destination path + 2. Read the last successful migration version from the file + 3. Set the last successful migration version in the schema_migrations table + 4. Delete the last successful migration file + */ + // the source driver should read the migrations from the destination path + // as the DB is dirty and last applied migrations to the database are not present in the source path + if err := m.updateSourceDrv(m.dirtyStateConf.destScheme + m.dirtyStateConf.destPath); err != nil { + return err + } + lastSuccessfulMigrationPath := filepath.Join(m.dirtyStateConf.destPath, lastSuccessfulMigrationFile) + lastVersionBytes, err := os.ReadFile(lastSuccessfulMigrationPath) + if err != nil { + return err + } + lastVersionStr := strings.TrimSpace(string(lastVersionBytes)) + lastVersion, err := strconv.ParseInt(lastVersionStr, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse last successful migration version: %w", err) + } + + // Set the last successful migration version in the schema_migrations table + if err = m.databaseDrv.SetVersion(int(lastVersion), false); err != nil { + return fmt.Errorf("failed to apply last successful migration: %w", err) + } + + m.logPrintf("Successfully set last successful migration version: %s on the DB", lastVersionStr) + + if err = os.Remove(lastSuccessfulMigrationPath); err != nil { + return err + } + + m.logPrintf("Successfully deleted file: %s", lastSuccessfulMigrationPath) + return nil +} + +func (m *Migrate) handleMigrationFailure(lastSuccessfulMigration int) error { + if !m.IsDirtyHandlingEnabled() { + return nil + } + lastSuccessfulMigrationPath := filepath.Join(m.dirtyStateConf.destPath, lastSuccessfulMigrationFile) + return os.WriteFile(lastSuccessfulMigrationPath, []byte(strconv.Itoa(lastSuccessfulMigration)), 0644) +} + +func (m *Migrate) cleanupFiles(targetVersion uint) error { + if !m.IsDirtyHandlingEnabled() { + return nil + } + + files, err := os.ReadDir(m.dirtyStateConf.destPath) + if err != nil { + // If the directory does not exist + return fmt.Errorf("failed to read directory %s: %w", m.dirtyStateConf.destPath, err) + } + + for _, file := range files { + fileName := file.Name() + migration, err := source.Parse(fileName) + if err != nil { + return err + } + // Delete file if version is greater than targetVersion + if migration.Version > targetVersion { + if err = os.Remove(filepath.Join(m.dirtyStateConf.destPath, fileName)); err != nil { + m.logErr(fmt.Errorf("failed to delete file %s: %v", fileName, err)) + continue + } + m.logPrintf("Migration file: %s removed during cleanup", fileName) + } + } + + return nil +} + +// copyFiles copies all files from source to destination volume. +func (m *Migrate) copyFiles() error { + // this is the case when the dirty handling is disabled + if !m.IsDirtyHandlingEnabled() { + return nil + } + + files, err := os.ReadDir(m.dirtyStateConf.srcPath) + if err != nil { + // If the directory does not exist + return fmt.Errorf("failed to read directory %s: %w", m.dirtyStateConf.srcPath, err) + } + m.logPrintf("Copying files from %s to %s", m.dirtyStateConf.srcPath, m.dirtyStateConf.destPath) + for _, file := range files { + fileName := file.Name() + if source.Regex.MatchString(fileName) { + fileContentBytes, err := os.ReadFile(filepath.Join(m.dirtyStateConf.srcPath, fileName)) + if err != nil { + return err + } + info, err := file.Info() + if err != nil { + return err + } + if err = os.WriteFile(filepath.Join(m.dirtyStateConf.destPath, fileName), fileContentBytes, info.Mode().Perm()); err != nil { + return err + } + } + } + + m.logPrintf("Successfully Copied files from %s to %s", m.dirtyStateConf.srcPath, m.dirtyStateConf.destPath) + return nil +} diff --git a/migrate_dirty_test.go b/migrate_dirty_test.go new file mode 100644 index 000000000..ffb82799b --- /dev/null +++ b/migrate_dirty_test.go @@ -0,0 +1,353 @@ +package migrate + +import ( + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + + dStub "github.com/golang-migrate/migrate/v4/database/stub" + sStub "github.com/golang-migrate/migrate/v4/source/stub" +) + +func setupMigrateInstance(tempDir string) (*Migrate, *dStub.Stub) { + scheme := "stub://" + m, _ := New(scheme, scheme) + m.dirtyStateConf = &dirtyStateConfig{ + destScheme: scheme, + destPath: tempDir, + enable: true, + } + return m, m.databaseDrv.(*dStub.Stub) +} + +func TestHandleDirtyState(t *testing.T) { + tempDir := t.TempDir() + + m, dbDrv := setupMigrateInstance(tempDir) + m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations + + tests := []struct { + lastSuccessfulVersion int + currentVersion int + err error + setupFailure bool + }{ + {lastSuccessfulVersion: 1, currentVersion: 3, err: nil, setupFailure: false}, + {lastSuccessfulVersion: 4, currentVersion: 7, err: nil, setupFailure: false}, + {lastSuccessfulVersion: 3, currentVersion: 4, err: nil, setupFailure: false}, + {lastSuccessfulVersion: -3, currentVersion: 4, err: ErrInvalidVersion, setupFailure: false}, + {lastSuccessfulVersion: 4, currentVersion: 3, err: fmt.Errorf("open %s: no such file or directory", filepath.Join(tempDir, lastSuccessfulMigrationFile)), setupFailure: true}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + var lastSuccessfulMigrationPath string + // setupFailure flag helps with testing scenario where the 'lastSuccessfulMigrationFile' doesn't exist + if !test.setupFailure { + lastSuccessfulMigrationPath = filepath.Join(tempDir, lastSuccessfulMigrationFile) + if err := os.WriteFile(lastSuccessfulMigrationPath, []byte(strconv.Itoa(test.lastSuccessfulVersion)), 0644); err != nil { + t.Fatal(err) + } + } + // Setting the DB version as dirty + if err := dbDrv.SetVersion(test.currentVersion, true); err != nil { + t.Fatal(err) + } + // Quick check to see if set correctly + version, b, err := dbDrv.Version() + if err != nil { + t.Fatal(err) + } + if version != test.currentVersion { + t.Fatalf("expected version %d, got %d", test.currentVersion, version) + } + + if !b { + t.Fatalf("expected DB to be dirty, got false") + } + + // Handle dirty state + if err = m.handleDirtyState(); err != nil { + if strings.Contains(err.Error(), test.err.Error()) { + t.Logf("expected error %v, got %v", test.err, err) + if !test.setupFailure { + if err = os.Remove(lastSuccessfulMigrationPath); err != nil { + t.Fatal(err) + } + } + return + } else { + t.Fatal(err) + } + } + // Check 1: DB should no longer be dirty + if dbDrv.IsDirty { + t.Fatalf("expected dirty to be false, got true") + } + // Check 2: Current version should be the last successful version + if dbDrv.CurrentVersion != test.lastSuccessfulVersion { + t.Fatalf("expected version %d, got %d", test.lastSuccessfulVersion, dbDrv.CurrentVersion) + } + // Check 3: The lastSuccessfulMigration file shouldn't exist + if _, err = os.Stat(lastSuccessfulMigrationPath); !os.IsNotExist(err) { + t.Fatalf("expected file to be deleted, but it still exists") + } + }) + } +} + +func TestHandleMigrationFailure(t *testing.T) { + tempDir := t.TempDir() + + m, _ := setupMigrateInstance(tempDir) + + tests := []struct { + lastSuccessFulVersion int + }{ + {lastSuccessFulVersion: 3}, + {lastSuccessFulVersion: 4}, + {lastSuccessFulVersion: 5}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + if err := m.handleMigrationFailure(test.lastSuccessFulVersion); err != nil { + t.Fatal(err) + } + // Check 1: last successful Migration version should be stored in a file + lastSuccessfulMigrationPath := filepath.Join(tempDir, lastSuccessfulMigrationFile) + if _, err := os.Stat(lastSuccessfulMigrationPath); os.IsNotExist(err) { + t.Fatalf("expected file to be created, but it does not exist") + } + + // Check 2: Check if the content of last successful migration has the correct version + content, err := os.ReadFile(lastSuccessfulMigrationPath) + if err != nil { + t.Fatal(err) + } + + if string(content) != strconv.Itoa(test.lastSuccessFulVersion) { + t.Fatalf("expected %d, got %s", test.lastSuccessFulVersion, string(content)) + } + }) + } +} + +func TestCleanupFiles(t *testing.T) { + tempDir := t.TempDir() + + m, _ := setupMigrateInstance(tempDir) + m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations + + tests := []struct { + migrationFiles []string + targetVersion uint + remainingFiles []string + emptyDestPath bool + }{ + { + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, + targetVersion: 2, + remainingFiles: []string{"1_name.up.sql", "2_name.up.sql"}, + }, + { + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql", "4_name.up.sql", "5_name.up.sql"}, + targetVersion: 3, + remainingFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, + }, + { + migrationFiles: []string{}, + targetVersion: 1, + remainingFiles: []string{}, + emptyDestPath: true, + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + for _, file := range test.migrationFiles { + if err := os.WriteFile(filepath.Join(tempDir, file), []byte(""), 0644); err != nil { + t.Fatal(err) + } + } + + if test.emptyDestPath { + m.dirtyStateConf.destPath = "" + } + + if err := m.cleanupFiles(test.targetVersion); err != nil { + t.Fatal(err) + } + // check 1: only files upto the target version should exist + for _, file := range test.remainingFiles { + if _, err := os.Stat(filepath.Join(tempDir, file)); os.IsNotExist(err) { + t.Fatalf("expected file %s to exist, but it does not", file) + } + } + + // check 2: the files removed are as expected + deletedFiles := diff(test.migrationFiles, test.remainingFiles) + for _, deletedFile := range deletedFiles { + if _, err := os.Stat(filepath.Join(tempDir, deletedFile)); !os.IsNotExist(err) { + t.Fatalf("expected file %s to be deleted, but it still exists", deletedFile) + } + } + }) + } +} + +func TestCopyFiles(t *testing.T) { + srcDir := t.TempDir() + destDir := t.TempDir() + + m, _ := setupMigrateInstance(destDir) + m.dirtyStateConf.srcPath = srcDir + + tests := []struct { + migrationFiles []string + copiedFiles []string + emptyDestPath bool + }{ + { + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, + copiedFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql"}, + }, + { + migrationFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql", "4_name.up.sql", "current.sql"}, + copiedFiles: []string{"1_name.up.sql", "2_name.up.sql", "3_name.up.sql", "4_name.up.sql"}, + }, + { + emptyDestPath: true, // copyFiles should not do anything + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + for _, file := range test.migrationFiles { + if err := os.WriteFile(filepath.Join(srcDir, file), []byte(""), 0644); err != nil { + t.Fatal(err) + } + } + if test.emptyDestPath { + m.dirtyStateConf.destPath = "" + } + + if err := m.copyFiles(); err != nil { + t.Fatal(err) + } + + for _, file := range test.copiedFiles { + if _, err := os.Stat(filepath.Join(destDir, file)); os.IsNotExist(err) { + t.Fatalf("expected file %s to be copied, but it does not exist", file) + } + } + }) + } +} + +func TestWithDirtyStateConfig(t *testing.T) { + tests := []struct { + name string + srcPath string + destPath string + isDirty bool + wantErr bool + wantConf *dirtyStateConfig + }{ + { + name: "Valid file paths", + srcPath: "file:///src/path", + destPath: "file:///dest/path", + isDirty: true, + wantErr: false, + wantConf: &dirtyStateConfig{ + srcScheme: "file://", + destScheme: "file://", + srcPath: "/src/path", + destPath: "/dest/path", + enable: true, + }, + }, + { + name: "Invalid source scheme", + srcPath: "s3:///src/path", + destPath: "file:///dest/path", + isDirty: true, + wantErr: true, + }, + { + name: "Invalid destination scheme", + srcPath: "file:///src/path", + destPath: "s3:///dest/path", + isDirty: true, + wantErr: true, + }, + { + name: "Empty source scheme", + srcPath: "/src/path", + destPath: "file:///dest/path", + isDirty: true, + wantErr: false, + wantConf: &dirtyStateConfig{ + srcScheme: "file://", + destScheme: "file://", + srcPath: "/src/path", + destPath: "/dest/path", + enable: true, + }, + }, + { + name: "Empty destination scheme", + srcPath: "file:///src/path", + destPath: "/dest/path", + isDirty: true, + wantErr: false, + wantConf: &dirtyStateConfig{ + srcScheme: "file://", + destScheme: "file://", + srcPath: "/src/path", + destPath: "/dest/path", + enable: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &Migrate{} + err := m.WithDirtyStateConfig(tt.srcPath, tt.destPath, tt.isDirty) + if (err != nil) != tt.wantErr { + t.Errorf("error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && m.dirtyStateConf == tt.wantConf { + t.Errorf("dirtyStateConf = %v, want %v", m.dirtyStateConf, tt.wantConf) + } + }) + } +} + +/* + diff returns an array containing the elements in Array A and not in B +*/ + +func diff(a, b []string) []string { + temp := map[string]int{} + for _, s := range a { + temp[s]++ + } + for _, s := range b { + temp[s]-- + } + + var result []string + for s, v := range temp { + if v != 0 { + result = append(result, s) + } + } + return result +}