Skip to content

Commit 0449c0e

Browse files
committed
V2: new methods for *sql.DB connections, auto-cleanup database
1 parent a6ee6dc commit 0449c0e

File tree

5 files changed

+247
-33
lines changed

5 files changed

+247
-33
lines changed

README.md

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,22 @@ is complete.
99
As a side effect tool checks that all resources are released when test exits.
1010
If any Rows is not closed or Conn is not released to pool, test fails.
1111

12-
`go-test-pg` requires schema file to initialize database with. It creates
12+
`go-test-pg` uses schema file to initialize database with. It creates
1313
template database with this schema. Then each temporary database for every test
1414
creates from this template database. If the template database for this
1515
schema is exists, it will be reused. The name of the template database
16-
is composed of `baseName` and md5 hashsum of schema file content.
16+
is composed of `baseName` and md5 hashsum of schema file content. If schema file
17+
is empty, then use default PostgreSQL empty database `template1`.
1718

1819
On complete, temporary databases would be dropped, template database will not
1920
be dropped and would remain for future reuse.
2021

2122
Template database would be created only on first use. If you call `NewPool`
2223
and do not call `With<something>` on it, real database would not be touched.
2324

25+
Each method was `Std` version that returns `*sql.DB`. For example,
26+
default method `WithFixtures` returns `*pgxpool.Pool` and `WithStdFixtures`
27+
returns `*sql.DB`.
2428

2529
## Example usage
2630

@@ -37,11 +41,10 @@ import (
3741
var dbpool = &ptg.Pgpool{SchemaFile: "../schema.sql"}
3842

3943
func TestX(t *testing.T) {
40-
dbPool, dbClear := dbpool.WithEmpty(t)
41-
defer dbClear()
44+
dbPool := dbpool.WithEmpty(t)
4245
var dbName string
4346
err := dbPool.
44-
QueryRow(context.Background(), "select current_database()").
47+
QueryRow(context.Background(), "SELECT current_database()").
4548
Scan(&dbName)
4649
if err != nil {
4750
t.Fatal(err)
@@ -55,8 +58,8 @@ Connection to database configured using standard PostgreSQL environment
5558
variable https://www.postgresql.org/docs/11/libpq-envars.html. User needs
5659
permissions to create databases.
5760

58-
If you want to skip all tests, you need to set Skip field in Pgpool struct
59-
to false.
61+
If you want to skip all database tests, you need to set `Skip` field in Pgpool
62+
struct to `true`.
6063

6164
```go
6265
var dbpool = &ptg.Pgpool{Skip: true}

database.go

Lines changed: 126 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package go_test_pg
33
import (
44
"context"
55
"crypto/md5"
6+
"database/sql"
67
"encoding/hex"
78
"fmt"
89
"io/ioutil"
@@ -15,6 +16,7 @@ import (
1516

1617
"github.com/jackc/pgx/v4"
1718
"github.com/jackc/pgx/v4/pgxpool"
19+
"github.com/jackc/pgx/v4/stdlib"
1820
"github.com/pkg/errors"
1921
)
2022

@@ -29,7 +31,7 @@ type Pgpool struct {
2931
// BaseName is the prefix of template and temporary databases.
3032
// Default is dbtestpg.
3133
BaseName string
32-
// Name of schema file. Required. Tests would fail if not set.
34+
// Name of schema file. If empty, create empty database.
3335
SchemaFile string // schema file name
3436
// If true, skip all database tests.
3537
Skip bool
@@ -42,41 +44,66 @@ type Pgpool struct {
4244

4345
// WithFixtures creates database from template database, and initializes it
4446
// with fixtures from `fixtures` array
45-
func (p *Pgpool) WithFixtures(
46-
t testing.TB,
47-
fixtures []Fixture,
48-
) (*pgxpool.Pool, func()) {
49-
pool, clean := p.WithEmpty(t)
47+
func (p *Pgpool) WithFixtures(t testing.TB, fixtures []Fixture) *pgxpool.Pool {
48+
pool := p.WithEmpty(t)
5049
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
5150
defer cancel()
5251
for i, f := range fixtures {
5352
if _, err := pool.Exec(ctx, f.Query, f.Params...); err != nil {
54-
clean()
5553
t.Fatalf(
5654
"can't load fixture at idx %v: %+v",
5755
i, errors.WithStack(err),
5856
)
5957
}
6058
}
61-
return pool, clean
59+
return pool
60+
}
61+
62+
// WithStdFixtures creates database from template database, and initializes it
63+
// with fixtures from `fixtures` array
64+
func (p *Pgpool) WithStdFixtures(t testing.TB, fixtures []Fixture) *sql.DB {
65+
db := p.WithStdEmpty(t)
66+
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
67+
defer cancel()
68+
for i, f := range fixtures {
69+
if _, err := db.ExecContext(ctx, f.Query, f.Params...); err != nil {
70+
t.Fatalf("can't load fixture at idx %v: %+v",
71+
i, errors.WithStack(err))
72+
}
73+
}
74+
return db
6275
}
6376

6477
// WithSQLs creates database from template database, and initializes it
6578
// with fixtures from `sqls` array
66-
func (p *Pgpool) WithSQLs(t testing.TB, sqls []string) (*pgxpool.Pool, func()) {
67-
pool, clean := p.WithEmpty(t)
79+
func (p *Pgpool) WithSQLs(t testing.TB, sqls []string) *pgxpool.Pool {
80+
pool := p.WithEmpty(t)
6881
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
6982
defer cancel()
7083
for i, s := range sqls {
7184
if _, err := pool.Exec(ctx, s); err != nil {
72-
clean()
7385
t.Fatalf(
7486
"can't load fixture at idx %v: %+v",
7587
i, errors.WithStack(err),
7688
)
7789
}
7890
}
79-
return pool, clean
91+
return pool
92+
}
93+
94+
// WithStdSQLs creates database from template database, and initializes it
95+
// with fixtures from `sqls` array
96+
func (p *Pgpool) WithStdSQLs(t testing.TB, sqls []string) *sql.DB {
97+
db := p.WithStdEmpty(t)
98+
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
99+
defer cancel()
100+
for i, s := range sqls {
101+
if _, err := db.ExecContext(ctx, s); err != nil {
102+
t.Fatalf("can't load fixture at idx %v: %+v",
103+
i, errors.WithStack(err))
104+
}
105+
}
106+
return db
80107
}
81108

82109
func (p *Pgpool) getTmpl(t testing.TB) string {
@@ -111,11 +138,27 @@ func (p *Pgpool) getTmpl(t testing.TB) string {
111138
return p.tmpl
112139
}
113140

114-
func (p *Pgpool) createRndDB(t testing.TB) (*pgxpool.Pool, string) {
141+
// Register pgx.ConnConfig with std driver.
142+
// Return connection string for database/sql and error.
143+
func (p *Pgpool) registerStdConfig(t testing.TB, dbName string) (string, error) {
144+
connConfig, err := pgx.ParseConfig("")
145+
if err != nil {
146+
return "", errors.WithStack(err)
147+
}
148+
connConfig.Logger = newLogger(t)
149+
connConfig.Database = dbName
150+
return stdlib.RegisterConnConfig(connConfig), nil
151+
}
152+
153+
func (p *Pgpool) createRndDB(t testing.TB) (string, error) {
115154
tmpl := p.getTmpl(t)
116155
dbName := fmt.Sprintf("%v_%v", tmpl, p.rnd.Int31())
117156

118-
err := p.createDB(dbName, tmpl)
157+
return dbName, p.createDB(dbName, tmpl)
158+
}
159+
160+
func (p *Pgpool) createRndDBPool(t testing.TB) (*pgxpool.Pool, string) {
161+
dbName, err := p.createRndDB(t)
119162
if err != nil {
120163
t.Fatal(err)
121164
}
@@ -191,9 +234,9 @@ func dropDB(dbName string) error {
191234

192235
// WithEmpty creates empty database from template database, that was
193236
// created from `schema` file.
194-
func (p *Pgpool) WithEmpty(t testing.TB) (*pgxpool.Pool, func()) {
195-
pool, dbName := p.createRndDB(t)
196-
return pool, func() {
237+
func (p *Pgpool) WithEmpty(t testing.TB) *pgxpool.Pool {
238+
pool, dbName := p.createRndDBPool(t)
239+
t.Cleanup(func() {
197240
acquiredConns := pool.Stat().AcquiredConns()
198241
if acquiredConns > 0 {
199242
t.Fatalf(
@@ -206,7 +249,63 @@ func (p *Pgpool) WithEmpty(t testing.TB) (*pgxpool.Pool, func()) {
206249
if err != nil {
207250
t.Errorf("Can't drop DB %v: %v", dbName, err)
208251
}
252+
})
253+
return pool
254+
}
255+
256+
// WithStdEmpty creates empty database from template database, that was
257+
// created from `schema` file.
258+
func (p *Pgpool) WithStdEmpty(t testing.TB) *sql.DB {
259+
db, cleanupFn := p.newStdDBWithCleanup(t)
260+
if cleanupFn != nil {
261+
t.Cleanup(func() {
262+
if err := cleanupFn(); err != nil {
263+
t.Error(err)
264+
}
265+
})
266+
}
267+
return db
268+
}
269+
270+
func (p *Pgpool) newStdDBWithCleanup(t testing.TB) (*sql.DB, func() error) {
271+
dbName, err := p.createRndDB(t)
272+
if err != nil {
273+
t.Fatal(err)
274+
return nil, nil
275+
}
276+
277+
connString, err := p.registerStdConfig(t, dbName)
278+
if err != nil {
279+
_ = dropDB(dbName)
280+
t.Fatal(err)
281+
return nil, nil
282+
}
283+
284+
db, err := sql.Open("pgx", connString)
285+
if err != nil {
286+
_ = dropDB(dbName)
287+
t.Fatal(err)
288+
return nil, nil
289+
}
290+
291+
cleanupFn := func() error {
292+
stats := db.Stats()
293+
if stats.InUse > 0 {
294+
return errors.Errorf(
295+
"unreleased connections exists: %v, can't drop database %v",
296+
stats.InUse, dbName)
297+
}
298+
err := db.Close()
299+
if err != nil {
300+
return errors.Errorf("Can't close DB %v: %v", dbName, err)
301+
}
302+
err = dropDB(dbName)
303+
if err != nil {
304+
return errors.Errorf("Can't drop DB %v: %v", dbName, err)
305+
}
306+
return nil
209307
}
308+
return db, cleanupFn
210309
}
211310

212311
func (p *Pgpool) createDB(name, tmplName string) error {
@@ -224,9 +323,11 @@ func (p *Pgpool) createDB(name, tmplName string) error {
224323
)
225324
}
226325

326+
// Creates template db, populates with SQLs from schema file and return name
327+
// of the new database. If database is exists, just return its name.
227328
func (p *Pgpool) createTemplateDB() (string, error) {
228329
if p.SchemaFile == "" {
229-
return "", errors.New("SchemaFile is empty")
330+
return "template1", nil
230331
}
231332
schemaSql, err := ioutil.ReadFile(p.SchemaFile)
232333
if err != nil {
@@ -238,7 +339,7 @@ func (p *Pgpool) createTemplateDB() (string, error) {
238339
if p.BaseName != "" {
239340
baseName = p.BaseName
240341
}
241-
tmpl := fmt.Sprintf("%v_%v", baseName, schemaHex)
342+
tmplDbName := fmt.Sprintf("%v_%v", baseName, schemaHex)
242343

243344
var dbExists bool
244345
err = withNewConnection(
@@ -247,14 +348,14 @@ func (p *Pgpool) createTemplateDB() (string, error) {
247348
query := `
248349
SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)
249350
`
250-
err := conn.QueryRow(ctx, query, tmpl).Scan(&dbExists)
351+
err := conn.QueryRow(ctx, query, tmplDbName).Scan(&dbExists)
251352
if err != nil {
252353
return errors.WithStack(err)
253354
}
254355
if dbExists {
255356
return nil
256357
}
257-
_, err = conn.Exec(ctx, `CREATE DATABASE `+quote(tmpl))
358+
_, err = conn.Exec(ctx, `CREATE DATABASE `+quote(tmplDbName))
258359
return errors.WithStack(err)
259360
},
260361
)
@@ -263,23 +364,23 @@ SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)
263364
}
264365

265366
if dbExists {
266-
return tmpl, nil
367+
return tmplDbName, nil
267368
}
268369

269370
err = withNewConnection(
270-
tmpl,
371+
tmplDbName,
271372
func(ctx context.Context, conn *pgx.Conn) error {
272373
_, err = conn.Exec(ctx, string(schemaSql))
273374
return errors.WithStack(err)
274375
},
275376
)
276377

277378
if err != nil {
278-
_ = dropDB(tmpl)
379+
_ = dropDB(tmplDbName)
279380
return "", err
280381
}
281382

282-
return tmpl, nil
383+
return tmplDbName, nil
283384
}
284385

285386
func quote(name string) string {

0 commit comments

Comments
 (0)