diff --git a/.travis.yml b/.travis.yml index fdaea8cbd..0a333969c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -65,6 +65,7 @@ deploy: - cli/build/migrate.linux-armv7.tar.gz - cli/build/migrate.linux-arm64.tar.gz - cli/build/migrate.darwin-amd64.tar.gz + - cli/build/migrate.darwin-arm64.tar.gz - cli/build/migrate.windows-amd64.exe.tar.gz - cli/build/migrate.windows-386.exe.tar.gz - cli/build/sha256sum.txt diff --git a/Makefile b/Makefile index 8e23a43c7..32abce32a 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,7 @@ build-cli: clean cd ./cmd/migrate && CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=7 go build -a -o ../../cli/build/migrate.linux-armv7 -ldflags='-X main.Version=$(VERSION) -extldflags "-static"' -tags '$(DATABASE) $(SOURCE)' . cd ./cmd/migrate && CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -a -o ../../cli/build/migrate.linux-arm64 -ldflags='-X main.Version=$(VERSION) -extldflags "-static"' -tags '$(DATABASE) $(SOURCE)' . cd ./cmd/migrate && CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -a -o ../../cli/build/migrate.darwin-amd64 -ldflags='-X main.Version=$(VERSION) -extldflags "-static"' -tags '$(DATABASE) $(SOURCE)' . + cd ./cmd/migrate && CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -a -o ../../cli/build/migrate.darwin-arm64 -ldflags='-X main.Version=$(VERSION) -extldflags "-static"' -tags '$(DATABASE) $(SOURCE)' . cd ./cmd/migrate && CGO_ENABLED=0 GOOS=windows GOARCH=386 go build -a -o ../../cli/build/migrate.windows-386.exe -ldflags='-X main.Version=$(VERSION) -extldflags "-static"' -tags '$(DATABASE) $(SOURCE)' . cd ./cmd/migrate && CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -a -o ../../cli/build/migrate.windows-amd64.exe -ldflags='-X main.Version=$(VERSION) -extldflags "-static"' -tags '$(DATABASE) $(SOURCE)' . cd ./cli/build && find . -name 'migrate*' | xargs -I{} tar czf {}.tar.gz {} diff --git a/database/sqlserver/README.md b/database/sqlserver/README.md index 8ecd87723..8f7e77ea9 100644 --- a/database/sqlserver/README.md +++ b/database/sqlserver/README.md @@ -5,7 +5,6 @@ | URL Query | WithInstance Config | Description | |------------|---------------------|-------------| -| `x-migrations-table` | `MigrationsTable` | Name of the migrations table | | `username` | | enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used. | | `password` | | The user's password. | | `host` | | The host to connect to. | @@ -17,6 +16,8 @@ | `encrypt` | | `disable` - Data send between client and server is not encrypted. `false` - Data sent between client and server is not encrypted beyond the login packet (Default). `true` - Data sent between client and server is encrypted. | | `app+name` || The application name (default is go-mssqldb). | | `useMsi` | | `true` - Use Azure MSI Authentication for connecting to Sql Server. Must be running from an Azure VM/an instance with MSI enabled. `false` - Use password authentication (Default). See [here for Azure MSI Auth details](https://docs.microsoft.com/en-us/azure/app-service/app-service-web-tutorial-connect-msi). NOTE: Since this cannot be tested locally, this is not officially supported. +| `x-migrations-table` | `MigrationsTable` | Name of the migrations table | +| `x-batch-enabled` | | Process batch statements using the go-mssqldb "batch" processor to support the SSMS `GO` statement. | See https://github.com/microsoft/go-mssqldb for full parameter list. diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 3cfa48bf9..e9efe4e07 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -16,6 +16,7 @@ import ( "github.com/golang-migrate/migrate/v4/database" "github.com/hashicorp/go-multierror" mssql "github.com/microsoft/go-mssqldb" // mssql support + "github.com/microsoft/go-mssqldb/batch" // batch support ) func init() { @@ -45,6 +46,7 @@ type Config struct { MigrationsTable string DatabaseName string SchemaName string + BatchEnabled bool } // SQL Server connection @@ -103,7 +105,6 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } conn, err := instance.Conn(context.Background()) - if err != nil { return nil, err } @@ -168,11 +169,19 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) { migrationsTable := purl.Query().Get("x-migrations-table") + batchEnabled := false + if bes := purl.Query().Get("x-batch-enabled"); len(bes) > 0 { + batchEnabled, err = strconv.ParseBool(bes) + if err != nil { + return nil, fmt.Errorf("Unacceptable value for option x-batch-enabled, unable to parse option : %w", err) + } + } + px, err := WithInstance(db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, + BatchEnabled: batchEnabled, }) - if err != nil { return nil, err } @@ -241,15 +250,24 @@ func (ss *SQLServer) Run(migration io.Reader) error { // run migration query := string(migr[:]) - if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { - if msErr, ok := err.(mssql.Error); ok { - message := fmt.Sprintf("migration failed: %s", msErr.Message) - if msErr.ProcName != "" { - message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) + + scripts := []string{query} + + if ss.config.BatchEnabled { + scripts = batch.Split(query, "GO") + } + + for _, script := range scripts { + if _, err := ss.conn.ExecContext(context.Background(), script); err != nil { + if msErr, ok := err.(mssql.Error); ok { + message := fmt.Sprintf("migration failed: %s", msErr.Message) + if msErr.ProcName != "" { + message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) + } + return database.Error{OrigErr: err, Err: message, Query: []byte(script), Line: uint(msErr.LineNo)} } - return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)} + return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(script)} } - return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } return nil @@ -257,13 +275,12 @@ func (ss *SQLServer) Run(migration io.Reader) error { // SetVersion for the current database func (ss *SQLServer) SetVersion(version int, dirty bool) error { - tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} } - query := `TRUNCATE TABLE ` + ss.getMigrationTable() + query := `DELETE FROM ` + ss.getMigrationTable() if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { err = multierror.Append(err, errRollback) @@ -314,7 +331,6 @@ func (ss *SQLServer) Version() (version int, dirty bool, err error) { // Drop all tables from the database. func (ss *SQLServer) Drop() error { - // drop all referential integrity constraints query := ` DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR diff --git a/database/sqlserver/sqlserver_test.go b/database/sqlserver/sqlserver_test.go index 402f4480f..1810c046c 100644 --- a/database/sqlserver/sqlserver_test.go +++ b/database/sqlserver/sqlserver_test.go @@ -20,8 +20,10 @@ import ( _ "github.com/golang-migrate/migrate/v4/source/file" ) -const defaultPort = 1433 -const saPassword = "Root1234" +const ( + defaultPort = 1433 + saPassword = "Root1234" +) var ( sqlServerOpts = dktest.Options{ @@ -49,6 +51,10 @@ func msConnectionStringMsi(host, port string, useMsi bool) string { return fmt.Sprintf("sqlserver://sa@%v:%v?database=master&useMsi=%t", host, port, useMsi) } +func msConnectionStringWithOptions(host, port string, options ...string) string { + return fmt.Sprintf("sqlserver://sa:%v@%v:%v?%s", saPassword, host, port, strings.Join(options, "&")) +} + func isReady(ctx context.Context, c dktest.ContainerInfo) bool { ip, port, err := c.Port(defaultPort) if err != nil { @@ -327,3 +333,43 @@ func testMsiFalse(t *testing.T) { } }) } + +func TestBatchedStatement(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.Port(defaultPort) + if err != nil { + t.Fatal(err) + } + + addr := msConnectionStringWithOptions(ip, port, "x-batch-enabled=true") + ms := &SQLServer{} + d, err := ms.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + batchedSQL := `CREATE VIEW v AS SELECT 1; +GO +CREATE VIEW v2 AS SELECT 2; +GO +CREATE VIEW v3 AS SELECT 3;` + + if err := d.Run(strings.NewReader(batchedSQL)); err != nil { + t.Fatalf("expected err to be nil, got %v", err) + } + + // make sure second proc exists + var exists int + if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "Select COUNT(1) from sysobjects where type = 'V' and [NAME] = 'v2'").Scan(&exists); err != nil { + t.Fatal(err) + } + if exists != 1 { + t.Fatalf("expected proc uspB to exist") + } + }) +}