Skip to content

Commit 56579c6

Browse files
Add migration utils for SQL Server (dapr#3280)
Signed-off-by: ItalyPaleAle <[email protected]> Co-authored-by: Deepanshu Agarwal <[email protected]>
1 parent e903af1 commit 56579c6

File tree

4 files changed

+508
-12
lines changed

4 files changed

+508
-12
lines changed
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
Copyright 2023 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package sqlservermigrations
15+
16+
import (
17+
"context"
18+
"database/sql"
19+
"fmt"
20+
"time"
21+
22+
commonsql "github.com/dapr/components-contrib/common/component/sql"
23+
"github.com/dapr/kit/logger"
24+
)
25+
26+
// Migrations performs migrations for the database schema
27+
type Migrations struct {
28+
DB *sql.DB
29+
Logger logger.Logger
30+
Schema string
31+
MetadataTableName string
32+
MetadataKey string
33+
34+
tableName string
35+
}
36+
37+
// Perform the required migrations
38+
func (m *Migrations) Perform(ctx context.Context, migrationFns []commonsql.MigrationFn) (err error) {
39+
// Setting a short-hand since it's going to be used a lot
40+
m.tableName = fmt.Sprintf("[%s].[%s]", m.Schema, m.MetadataTableName)
41+
42+
// Ensure the metadata table exists
43+
err = m.ensureMetadataTable(ctx)
44+
if err != nil {
45+
return fmt.Errorf("failed to ensure metadata table exists: %w", err)
46+
}
47+
48+
// In order to acquire a row-level lock, we need to have a row in the metadata table
49+
// So, we're going to write a row in there (not using a transaction, as that causes a table-level lock to be created), ignoring duplicates
50+
const lockKey = "lock"
51+
m.Logger.Debugf("Ensuring lock row '%s' exists in metadata table", lockKey)
52+
queryCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
53+
_, err = m.DB.ExecContext(queryCtx, fmt.Sprintf(`
54+
INSERT INTO %[1]s
55+
([Key], [Value])
56+
SELECT @Key, @Value
57+
WHERE NOT EXISTS (
58+
SELECT 1
59+
FROM %[1]s
60+
WHERE [Key] = @Key
61+
);
62+
`, m.tableName), sql.Named("Key", lockKey), sql.Named("Value", lockKey))
63+
cancel()
64+
if err != nil {
65+
return fmt.Errorf("failed to ensure lock row '%s' exists: %w", lockKey, err)
66+
}
67+
68+
// Now, let's use a transaction on a row in the metadata table as a lock
69+
m.Logger.Debug("Starting transaction pre-migration")
70+
tx, err := m.DB.Begin()
71+
if err != nil {
72+
return fmt.Errorf("failed to begin transaction: %w", err)
73+
}
74+
75+
// Always rollback the transaction at the end to release the lock, since the value doesn't really matter
76+
defer func() {
77+
m.Logger.Debug("Releasing migration lock")
78+
rollbackErr := tx.Rollback()
79+
if rollbackErr != nil {
80+
// Panicking here, as this forcibly closes the session and thus ensures we are not leaving locks hanging around
81+
m.Logger.Fatalf("Failed to roll back transaction: %v", rollbackErr)
82+
}
83+
}()
84+
85+
// Now, perform a SELECT with FOR UPDATE to lock the row used for locking, and only that row
86+
// We use a long timeout here as this query may block
87+
m.Logger.Debug("Acquiring migration lock")
88+
queryCtx, cancel = context.WithTimeout(ctx, time.Minute)
89+
var lock string
90+
//nolint:gosec
91+
q := fmt.Sprintf(`
92+
SELECT [Value]
93+
FROM %s
94+
WITH (XLOCK, ROWLOCK)
95+
WHERE [key] = @Key
96+
`, m.tableName)
97+
err = tx.QueryRowContext(queryCtx, q, sql.Named("Key", lockKey)).Scan(&lock)
98+
cancel()
99+
if err != nil {
100+
return fmt.Errorf("failed to acquire migration lock (row-level lock on key '%s'): %w", lockKey, err)
101+
}
102+
m.Logger.Debug("Migration lock acquired")
103+
104+
// Perform the migrations
105+
// Here we pass the database connection and not the transaction, since the transaction is only used to acquire the lock
106+
err = commonsql.Migrate(ctx, commonsql.AdaptDatabaseSQLConn(m.DB), commonsql.MigrationOptions{
107+
Logger: m.Logger,
108+
// Yes, we are using fmt.Sprintf for adding a value in a query.
109+
// This comes from a constant hardcoded at development-time, and cannot be influenced by users. So, no risk of SQL injections here.
110+
GetVersionQuery: fmt.Sprintf(`SELECT [Value] FROM %s WHERE [Key] = '%s'`, m.tableName, m.MetadataKey),
111+
UpdateVersionQuery: func(version string) (string, any) {
112+
return fmt.Sprintf(`
113+
MERGE
114+
%[1]s WITH (HOLDLOCK) AS t
115+
USING (SELECT '%[2]s' AS [Key]) AS s
116+
ON [t].[Key] = [s].[Key]
117+
WHEN MATCHED THEN
118+
UPDATE SET [Value] = @Value
119+
WHEN NOT MATCHED THEN
120+
INSERT ([Key], [Value]) VALUES ('%[2]s', @Value)
121+
;
122+
`,
123+
m.tableName, m.MetadataKey,
124+
), sql.Named("Value", version)
125+
},
126+
Migrations: migrationFns,
127+
})
128+
if err != nil {
129+
return err
130+
}
131+
132+
return nil
133+
}
134+
135+
func (m Migrations) ensureMetadataTable(ctx context.Context) error {
136+
m.Logger.Infof("Ensuring metadata table '%s' exists", m.tableName)
137+
_, err := m.DB.ExecContext(ctx, fmt.Sprintf(`
138+
IF OBJECT_ID('%[1]s', 'U') IS NULL
139+
CREATE TABLE %[1]s (
140+
[Key] VARCHAR(255) COLLATE Latin1_General_100_BIN2 NOT NULL PRIMARY KEY,
141+
[Value] VARCHAR(max) COLLATE Latin1_General_100_BIN2 NOT NULL
142+
)`,
143+
m.tableName,
144+
))
145+
if err != nil {
146+
return fmt.Errorf("failed to create metadata table: %w", err)
147+
}
148+
return nil
149+
}
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
/*
2+
Copyright 2023 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package sqlservermigrations
15+
16+
import (
17+
"bytes"
18+
"context"
19+
"crypto/rand"
20+
"database/sql"
21+
"encoding/hex"
22+
"fmt"
23+
"io"
24+
"os"
25+
"strconv"
26+
"strings"
27+
"sync/atomic"
28+
"testing"
29+
"time"
30+
31+
// Blank import for the SQL Server driver
32+
_ "github.com/microsoft/go-mssqldb"
33+
34+
"github.com/stretchr/testify/assert"
35+
"github.com/stretchr/testify/require"
36+
37+
commonsql "github.com/dapr/components-contrib/common/component/sql"
38+
commonsqlserver "github.com/dapr/components-contrib/common/component/sqlserver"
39+
"github.com/dapr/kit/logger"
40+
)
41+
42+
// connectionStringEnvKey defines the env key containing the integration test connection string
43+
// To use Docker: `server=localhost;user id=sa;password=Pass@Word1;port=1433;database=dapr_test;`
44+
// To use Azure SQL: `server=<your-db-server-name>.database.windows.net;user id=<your-db-user>;port=1433;password=<your-password>;database=dapr_test;`
45+
const connectionStringEnvKey = "DAPR_TEST_SQL_CONNSTRING"
46+
47+
// Disable gosec in this test as we use string concatenation a lot with queries
48+
func TestMigration(t *testing.T) {
49+
connectionString := os.Getenv(connectionStringEnvKey)
50+
if connectionString == "" {
51+
t.Skipf(`SQLServer migration test skipped. To enable this test, define the connection string using environment variable '%[1]s' (example 'export %[1]s="server=localhost;user id=sa;password=Pass@Word1;port=1433;database=dapr_test;")'`, connectionStringEnvKey)
52+
}
53+
54+
log := logger.NewLogger("migration-test")
55+
log.SetOutputLevel(logger.DebugLevel)
56+
57+
// Connect to the database
58+
db, err := sql.Open("sqlserver", connectionString)
59+
require.NoError(t, err, "Failed to connect to database")
60+
t.Cleanup(func() {
61+
db.Close()
62+
})
63+
64+
// Create a new schema for testing
65+
schema := getUniqueDBSchema(t)
66+
_, err = db.Exec(fmt.Sprintf("CREATE SCHEMA [%s]", schema))
67+
require.NoError(t, err, "Failed to create schema")
68+
t.Cleanup(func() {
69+
err = commonsqlserver.DropSchema(context.Background(), db, schema)
70+
require.NoError(t, err, "Failed to drop schema")
71+
})
72+
73+
t.Run("Metadata table", func(t *testing.T) {
74+
m := &Migrations{
75+
DB: db,
76+
Logger: log,
77+
Schema: schema,
78+
MetadataTableName: "metadata_1",
79+
MetadataKey: "migrations",
80+
}
81+
82+
t.Run("Create new", func(t *testing.T) {
83+
err = m.Perform(context.Background(), []commonsql.MigrationFn{})
84+
require.NoError(t, err)
85+
86+
assertTableExists(t, db, schema, "metadata_1")
87+
})
88+
89+
t.Run("Already exists", func(t *testing.T) {
90+
err = m.Perform(context.Background(), []commonsql.MigrationFn{})
91+
require.NoError(t, err)
92+
93+
assertTableExists(t, db, schema, "metadata_1")
94+
})
95+
})
96+
97+
t.Run("Perform migrations", func(t *testing.T) {
98+
m := &Migrations{
99+
DB: db,
100+
Logger: log,
101+
Schema: schema,
102+
MetadataTableName: "metadata_2",
103+
MetadataKey: "migrations",
104+
}
105+
106+
fn1 := func(ctx context.Context) error {
107+
_, err = m.DB.Exec(fmt.Sprintf("CREATE TABLE [%s].[TestTable] ([Key] INTEGER NOT NULL PRIMARY KEY)", schema))
108+
return err
109+
}
110+
111+
t.Run("First migration", func(t *testing.T) {
112+
err = m.Perform(context.Background(), []commonsql.MigrationFn{fn1})
113+
require.NoError(t, err)
114+
115+
assertTableExists(t, db, schema, "TestTable")
116+
assertMigrationsLevel(t, db, schema, "metadata_2", "migrations", "1")
117+
})
118+
119+
t.Run("Second migration", func(t *testing.T) {
120+
var called bool
121+
fn2 := func(ctx context.Context) error {
122+
// We don't actually have to do anything here, we just care that the migration level has increased
123+
called = true
124+
return nil
125+
}
126+
127+
err = m.Perform(context.Background(), []commonsql.MigrationFn{fn1, fn2})
128+
require.NoError(t, err)
129+
130+
assert.True(t, called)
131+
assertMigrationsLevel(t, db, schema, "metadata_2", "migrations", "2")
132+
})
133+
134+
t.Run("Already has migrated", func(t *testing.T) {
135+
var called bool
136+
fn2 := func(ctx context.Context) error {
137+
// We don't actually have to do anything here, we just care that the migration level has increased
138+
called = true
139+
return nil
140+
}
141+
142+
err = m.Perform(context.Background(), []commonsql.MigrationFn{fn1, fn2})
143+
require.NoError(t, err)
144+
145+
assert.False(t, called)
146+
assertMigrationsLevel(t, db, schema, "metadata_2", "migrations", "2")
147+
})
148+
})
149+
150+
t.Run("Perform migrations concurrently", func(t *testing.T) {
151+
counter := atomic.Uint32{}
152+
fn := func(ctx context.Context) error {
153+
// This migration doesn't actually do anything
154+
counter.Add(1)
155+
return nil
156+
}
157+
158+
const parallel = 5
159+
errs := make(chan error, parallel)
160+
hasLogs := atomic.Uint32{}
161+
for i := 0; i < parallel; i++ {
162+
go func(i int) {
163+
// Collect logs
164+
collectLog := logger.NewLogger("concurrent-" + strconv.Itoa(i))
165+
collectLog.SetOutputLevel(logger.DebugLevel)
166+
buf := &bytes.Buffer{}
167+
collectLog.SetOutput(io.MultiWriter(buf, os.Stdout))
168+
169+
m := &Migrations{
170+
DB: db,
171+
Logger: collectLog,
172+
Schema: schema,
173+
MetadataTableName: "metadata_2",
174+
MetadataKey: "migrations_concurrent",
175+
}
176+
177+
migrateErr := m.Perform(context.Background(), []commonsql.MigrationFn{fn})
178+
if migrateErr != nil {
179+
errs <- fmt.Errorf("migration failed in handler %d: %w", i, migrateErr)
180+
}
181+
182+
// One and only one of the loggers should have any message including "Performing migration"
183+
if strings.Contains(buf.String(), "Performing migration") {
184+
hasLogs.Add(1)
185+
}
186+
187+
errs <- nil
188+
}(i)
189+
}
190+
191+
for i := 0; i < parallel; i++ {
192+
select {
193+
case err := <-errs:
194+
assert.NoError(t, err) //nolint:testifylint
195+
case <-time.After(30 * time.Second):
196+
t.Fatal("timed out waiting for migrations to complete")
197+
}
198+
}
199+
if t.Failed() {
200+
// Short-circuit
201+
t.FailNow()
202+
}
203+
204+
// Handler should have been invoked just once
205+
assert.Equal(t, uint32(1), counter.Load(), "Migrations handler invoked more than once")
206+
assert.Equal(t, uint32(1), hasLogs.Load(), "More than one logger indicated a migration")
207+
})
208+
}
209+
210+
func getUniqueDBSchema(t *testing.T) string {
211+
t.Helper()
212+
213+
b := make([]byte, 4)
214+
_, err := io.ReadFull(rand.Reader, b)
215+
require.NoError(t, err)
216+
return fmt.Sprintf("m%s", hex.EncodeToString(b))
217+
}
218+
219+
func assertTableExists(t *testing.T, db *sql.DB, schema, table string) {
220+
t.Helper()
221+
222+
var found int
223+
err := db.QueryRow(
224+
fmt.Sprintf("SELECT 1 WHERE OBJECT_ID('[%s].[%s]', 'U') IS NOT NULL", schema, table),
225+
).Scan(&found)
226+
require.NoErrorf(t, err, "Table %s not found", table)
227+
require.Equalf(t, 1, found, "Table %s not found", table)
228+
}
229+
230+
func assertMigrationsLevel(t *testing.T, db *sql.DB, schema, table, key, expectLevel string) {
231+
t.Helper()
232+
233+
var foundLevel string
234+
err := db.QueryRow(
235+
fmt.Sprintf("SELECT [Value] FROM [%s].[%s] WHERE [Key] = @Key", schema, table),
236+
sql.Named("Key", key),
237+
).Scan(&foundLevel)
238+
require.NoError(t, err, "Failed to load migrations level")
239+
require.Equal(t, expectLevel, foundLevel, "Migration level does not match")
240+
}

0 commit comments

Comments
 (0)