Skip to content

Commit bb38de7

Browse files
authored
Merge pull request #707 from lightninglabs/migrations-tests
tapdb: add minimalistic test framework for testing DB migrations
2 parents d763d68 + efeb01c commit bb38de7

File tree

12 files changed

+369
-69
lines changed

12 files changed

+369
-69
lines changed

internal/test/helpers.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package test
33
import (
44
"bytes"
55
"encoding/hex"
6+
"os"
7+
"path/filepath"
68
"strconv"
79
"strings"
810
"sync"
@@ -369,3 +371,13 @@ func ScriptSchnorrSig(t *testing.T, pubKey *btcec.PublicKey) txscript.TapLeaf {
369371
require.NoError(t, err)
370372
return txscript.NewBaseTapLeaf(script2)
371373
}
374+
375+
// ReadTestDataFile reads a file from the testdata directory and returns its
376+
// content as a string.
377+
func ReadTestDataFile(t *testing.T, fileName string) string {
378+
path := filepath.Join("testdata", fileName)
379+
fileBytes, err := os.ReadFile(path)
380+
require.NoError(t, err)
381+
382+
return string(fileBytes)
383+
}

tapdb/asset_minting_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ func newAssetStore(t *testing.T) (*AssetMintingStore, *AssetStore,
3636
// First, Make a new test database.
3737
db := NewTestDB(t)
3838

39+
mintStore, assetStore := newAssetStoreFromDB(db.BaseDB)
40+
return mintStore, assetStore, db
41+
}
42+
43+
// newAssetStoreFromDB makes a new instance of the AssetMintingStore backed by
44+
// the passed database.
45+
func newAssetStoreFromDB(db *BaseDB) (*AssetMintingStore, *AssetStore) {
3946
// TODO(roasbeef): can use another layer of type params since
4047
// duplicated?
4148
txCreator := func(tx *sql.Tx) PendingAssetStore {
@@ -50,7 +57,7 @@ func newAssetStore(t *testing.T) (*AssetMintingStore, *AssetStore,
5057
testClock := clock.NewTestClock(time.Now())
5158

5259
return NewAssetMintingStore(assetMintingDB),
53-
NewAssetStore(assetsDB, testClock), db
60+
NewAssetStore(assetsDB, testClock)
5461
}
5562

5663
func assertBatchState(t *testing.T, batch *tapgarden.MintingBatch,

tapdb/assets_store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,7 @@ func fetchAssetsWithWitness(ctx context.Context, q ActiveAssetsStore,
861861
// First, we'll fetch all the assets we know of on disk.
862862
dbAssets, err := q.QueryAssets(ctx, assetFilter)
863863
if err != nil {
864-
return nil, nil, fmt.Errorf("unable to read db assets: %v", err)
864+
return nil, nil, fmt.Errorf("unable to read db assets: %w", err)
865865
}
866866

867867
assetIDs := fMap(dbAssets, func(a ConfirmedAsset) int64 {

tapdb/migrations.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tapdb
22

33
import (
44
"bytes"
5+
"errors"
56
"io"
67
"io/fs"
78
"net/http"
@@ -12,11 +13,31 @@ import (
1213
"github.com/golang-migrate/migrate/v4/source/httpfs"
1314
)
1415

15-
// applyMigrations executes all database migration files found in the given file
16+
// MigrationTarget is a functional option that can be passed to applyMigrations
17+
// to specify a target version to migrate to.
18+
type MigrationTarget func(mig *migrate.Migrate) error
19+
20+
var (
21+
// TargetLatest is a MigrationTarget that migrates to the latest
22+
// version available.
23+
TargetLatest = func(mig *migrate.Migrate) error {
24+
return mig.Up()
25+
}
26+
27+
// TargetVersion is a MigrationTarget that migrates to the given
28+
// version.
29+
TargetVersion = func(version uint) MigrationTarget {
30+
return func(mig *migrate.Migrate) error {
31+
return mig.Migrate(version)
32+
}
33+
}
34+
)
35+
36+
// applyMigrations executes database migration files found in the given file
1637
// system under the given path, using the passed database driver and database
17-
// name.
18-
func applyMigrations(fs fs.FS, driver database.Driver, path,
19-
dbName string) error {
38+
// name, up to or down to the given target version.
39+
func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string,
40+
targetVersion MigrationTarget) error {
2041

2142
// With the migrate instance open, we'll create a new migration source
2243
// using the embedded file system stored in sqlSchemas. The library
@@ -36,8 +57,10 @@ func applyMigrations(fs fs.FS, driver database.Driver, path,
3657
if err != nil {
3758
return err
3859
}
39-
err = sqlMigrate.Up()
40-
if err != nil && err != migrate.ErrNoChange {
60+
61+
// Execute the migration based on the target given.
62+
err = targetVersion(sqlMigrate)
63+
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
4164
return err
4265
}
4366

tapdb/migrations_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package tapdb
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
// TestMigrationSteps is an example test that illustrates how to test database
11+
// migrations by selectively applying only some migrations, inserting dummy data
12+
// and then applying the remaining migrations.
13+
func TestMigrationSteps(t *testing.T) {
14+
ctx := context.Background()
15+
16+
// As a first step, we create a new database but only migrate to
17+
// version 1, which only contains the macaroon tables.
18+
db := NewTestDBWithVersion(t, 1)
19+
20+
// If we create an assets store now, there should be no tables for the
21+
// assets yet.
22+
_, assetStore := newAssetStoreFromDB(db.BaseDB)
23+
_, err := assetStore.FetchAllAssets(ctx, true, true, nil)
24+
require.True(t, IsSchemaError(MapSQLError(err)))
25+
26+
// We now migrate to a later but not yet latest version.
27+
err = db.ExecuteMigrations(TargetVersion(11))
28+
require.NoError(t, err)
29+
30+
// Now there should be an asset table.
31+
_, err = assetStore.FetchAllAssets(ctx, true, true, nil)
32+
require.NoError(t, err)
33+
34+
// Assuming the next version does some changes to the data within the
35+
// asset table, we now add some dummy data to the assets related tables,
36+
// so we could then test that migration.
37+
InsertTestdata(t, db.BaseDB, "migrations_test_00011_dummy_data.sql")
38+
39+
// Make sure we now have actual assets in the database.
40+
dbAssets, err := assetStore.FetchAllAssets(ctx, true, true, nil)
41+
require.NoError(t, err)
42+
require.Len(t, dbAssets, 4)
43+
44+
// And now that we have test data inserted, we can migrate to the latest
45+
// version.
46+
err = db.ExecuteMigrations(TargetLatest)
47+
require.NoError(t, err)
48+
49+
// Here we would now test that the migration to the latest version did
50+
// what we expected it to do. But this is just an example, illustrating
51+
// the steps that can be taken to test migrations, so we are done for
52+
// this test.
53+
}

tapdb/postgres.go

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ var (
3030
// fully executed yet. So this time needs to be chosen correctly to be
3131
// longer than the longest expected individual test run time.
3232
DefaultPostgresFixtureLifetime = 60 * time.Minute
33+
34+
// postgresSchemaReplacements is a map of schema strings that need to be
35+
// replaced for postgres. This is needed because we write the schemas
36+
// to work with sqlite primarily, and postgres has some differences.
37+
postgresSchemaReplacements = map[string]string{
38+
"BLOB": "BYTEA",
39+
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
40+
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
41+
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
42+
}
3343
)
3444

3545
// PostgresConfig holds the postgres database configuration.
@@ -107,44 +117,41 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
107117
rawDb.SetConnMaxLifetime(connMaxLifetime)
108118
rawDb.SetConnMaxIdleTime(connMaxIdleTime)
109119

110-
if !cfg.SkipMigrations {
111-
// Now that the database is open, populate the database with
112-
// our set of schemas based on our embedded in-memory file
113-
// system.
114-
//
115-
// First, we'll need to open up a new migration instance for
116-
// our current target database: sqlite.
117-
driver, err := postgres_migrate.WithInstance(
118-
rawDb, &postgres_migrate.Config{},
119-
)
120-
if err != nil {
121-
return nil, err
122-
}
123-
124-
postgresFS := newReplacerFS(sqlSchemas, map[string]string{
125-
"BLOB": "BYTEA",
126-
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
127-
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
128-
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
129-
})
130-
131-
err = applyMigrations(
132-
postgresFS, driver, "sqlc/migrations", cfg.DBName,
133-
)
134-
if err != nil {
135-
return nil, err
136-
}
137-
}
138-
139120
queries := sqlc.NewPostgres(rawDb)
140-
141-
return &PostgresStore{
121+
s := &PostgresStore{
142122
cfg: cfg,
143123
BaseDB: &BaseDB{
144124
DB: rawDb,
145125
Queries: queries,
146126
},
147-
}, nil
127+
}
128+
129+
// Now that the database is open, populate the database with our set of
130+
// schemas based on our embedded in-memory file system.
131+
if !cfg.SkipMigrations {
132+
if err := s.ExecuteMigrations(TargetLatest); err != nil {
133+
return nil, fmt.Errorf("error executing migrations: "+
134+
"%w", err)
135+
}
136+
}
137+
138+
return s, nil
139+
}
140+
141+
// ExecuteMigrations runs migrations for the Postgres database, depending on the
142+
// target given, either all migrations or up to a given version.
143+
func (s *PostgresStore) ExecuteMigrations(target MigrationTarget) error {
144+
driver, err := postgres_migrate.WithInstance(
145+
s.DB, &postgres_migrate.Config{},
146+
)
147+
if err != nil {
148+
return fmt.Errorf("error creating postgres migration: %w", err)
149+
}
150+
151+
postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements)
152+
return applyMigrations(
153+
postgresFS, driver, "sqlc/migrations", s.cfg.DBName, target,
154+
)
148155
}
149156

150157
// NewTestPostgresDB is a helper function that creates a Postgres database for
@@ -164,3 +171,27 @@ func NewTestPostgresDB(t *testing.T) *PostgresStore {
164171

165172
return store
166173
}
174+
175+
// NewTestPostgresDBWithVersion is a helper function that creates a Postgres
176+
// database for testing and migrates it to the given version.
177+
func NewTestPostgresDBWithVersion(t *testing.T, version uint) *PostgresStore {
178+
t.Helper()
179+
180+
t.Logf("Creating new Postgres DB for testing, migrating to version %d",
181+
version)
182+
183+
sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true)
184+
storeCfg := sqlFixture.GetConfig()
185+
storeCfg.SkipMigrations = true
186+
store, err := NewPostgresStore(storeCfg)
187+
require.NoError(t, err)
188+
189+
err = store.ExecuteMigrations(TargetVersion(version))
190+
require.NoError(t, err)
191+
192+
t.Cleanup(func() {
193+
sqlFixture.TearDown(t)
194+
})
195+
196+
return store
197+
}

tapdb/sqlerrors.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package tapdb
33
import (
44
"errors"
55
"fmt"
6+
"strings"
67

78
"github.com/jackc/pgconn"
89
"github.com/jackc/pgerrcode"
@@ -52,6 +53,20 @@ func parseSqliteError(sqliteErr *sqlite.Error) error {
5253
DbError: sqliteErr,
5354
}
5455

56+
// Generic error, need to parse the message further.
57+
case sqlite3.SQLITE_ERROR:
58+
errMsg := sqliteErr.Error()
59+
60+
switch {
61+
case strings.Contains(errMsg, "no such table"):
62+
return &ErrSchemaError{
63+
DbError: sqliteErr,
64+
}
65+
66+
default:
67+
return fmt.Errorf("unknown sqlite error: %w", sqliteErr)
68+
}
69+
5570
default:
5671
return fmt.Errorf("unknown sqlite error: %w", sqliteErr)
5772
}
@@ -73,6 +88,12 @@ func parsePostgresError(pqErr *pgconn.PgError) error {
7388
DbError: pqErr,
7489
}
7590

91+
// Handle schema error.
92+
case pgerrcode.UndefinedColumn, pgerrcode.UndefinedTable:
93+
return &ErrSchemaError{
94+
DbError: pqErr,
95+
}
96+
7697
default:
7798
return fmt.Errorf("unknown postgres error: %w", pqErr)
7899
}
@@ -111,3 +132,25 @@ func IsSerializationError(err error) bool {
111132
var serializationError *ErrSerializationError
112133
return errors.As(err, &serializationError)
113134
}
135+
136+
// ErrSchemaError is an error type which represents a database agnostic error
137+
// that the schema of the database is incorrect for the given query.
138+
type ErrSchemaError struct {
139+
DbError error
140+
}
141+
142+
// Unwrap returns the wrapped error.
143+
func (e ErrSchemaError) Unwrap() error {
144+
return e.DbError
145+
}
146+
147+
// Error returns the error message.
148+
func (e ErrSchemaError) Error() string {
149+
return e.DbError.Error()
150+
}
151+
152+
// IsSchemaError returns true if the given error is a schema error.
153+
func IsSchemaError(err error) bool {
154+
var schemaError *ErrSchemaError
155+
return errors.As(err, &schemaError)
156+
}

0 commit comments

Comments
 (0)