@@ -3,6 +3,7 @@ package go_test_pg
33import (
44 "context"
55 "crypto/md5"
6+ "database/sql"
67 "encoding/hex"
78 "fmt"
89 "io/ioutil"
@@ -15,6 +16,7 @@ import (
1516
1617 "github.com/jackc/pgx/v4"
1718 "github.com/jackc/pgx/v4/pgxpool"
19+ "github.com/jackc/pgx/v4/stdlib"
1820 "github.com/pkg/errors"
1921)
2022
@@ -29,7 +31,7 @@ type Pgpool struct {
2931 // BaseName is the prefix of template and temporary databases.
3032 // Default is dbtestpg.
3133 BaseName string
32- // Name of schema file. Required. Tests would fail if not set .
34+ // Name of schema file. If empty, create empty database .
3335 SchemaFile string // schema file name
3436 // If true, skip all database tests.
3537 Skip bool
@@ -42,41 +44,66 @@ type Pgpool struct {
4244
4345// WithFixtures creates database from template database, and initializes it
4446// with fixtures from `fixtures` array
45- func (p * Pgpool ) WithFixtures (
46- t testing.TB ,
47- fixtures []Fixture ,
48- ) (* pgxpool.Pool , func ()) {
49- pool , clean := p .WithEmpty (t )
47+ func (p * Pgpool ) WithFixtures (t testing.TB , fixtures []Fixture ) * pgxpool.Pool {
48+ pool := p .WithEmpty (t )
5049 ctx , cancel := context .WithTimeout (context .Background (), defaultTimeout )
5150 defer cancel ()
5251 for i , f := range fixtures {
5352 if _ , err := pool .Exec (ctx , f .Query , f .Params ... ); err != nil {
54- clean ()
5553 t .Fatalf (
5654 "can't load fixture at idx %v: %+v" ,
5755 i , errors .WithStack (err ),
5856 )
5957 }
6058 }
61- return pool , clean
59+ return pool
60+ }
61+
62+ // WithStdFixtures creates database from template database, and initializes it
63+ // with fixtures from `fixtures` array
64+ func (p * Pgpool ) WithStdFixtures (t testing.TB , fixtures []Fixture ) * sql.DB {
65+ db := p .WithStdEmpty (t )
66+ ctx , cancel := context .WithTimeout (context .Background (), defaultTimeout )
67+ defer cancel ()
68+ for i , f := range fixtures {
69+ if _ , err := db .ExecContext (ctx , f .Query , f .Params ... ); err != nil {
70+ t .Fatalf ("can't load fixture at idx %v: %+v" ,
71+ i , errors .WithStack (err ))
72+ }
73+ }
74+ return db
6275}
6376
6477// WithSQLs creates database from template database, and initializes it
6578// with fixtures from `sqls` array
66- func (p * Pgpool ) WithSQLs (t testing.TB , sqls []string ) ( * pgxpool.Pool , func ()) {
67- pool , clean := p .WithEmpty (t )
79+ func (p * Pgpool ) WithSQLs (t testing.TB , sqls []string ) * pgxpool.Pool {
80+ pool := p .WithEmpty (t )
6881 ctx , cancel := context .WithTimeout (context .Background (), defaultTimeout )
6982 defer cancel ()
7083 for i , s := range sqls {
7184 if _ , err := pool .Exec (ctx , s ); err != nil {
72- clean ()
7385 t .Fatalf (
7486 "can't load fixture at idx %v: %+v" ,
7587 i , errors .WithStack (err ),
7688 )
7789 }
7890 }
79- return pool , clean
91+ return pool
92+ }
93+
94+ // WithStdSQLs creates database from template database, and initializes it
95+ // with fixtures from `sqls` array
96+ func (p * Pgpool ) WithStdSQLs (t testing.TB , sqls []string ) * sql.DB {
97+ db := p .WithStdEmpty (t )
98+ ctx , cancel := context .WithTimeout (context .Background (), defaultTimeout )
99+ defer cancel ()
100+ for i , s := range sqls {
101+ if _ , err := db .ExecContext (ctx , s ); err != nil {
102+ t .Fatalf ("can't load fixture at idx %v: %+v" ,
103+ i , errors .WithStack (err ))
104+ }
105+ }
106+ return db
80107}
81108
82109func (p * Pgpool ) getTmpl (t testing.TB ) string {
@@ -111,11 +138,27 @@ func (p *Pgpool) getTmpl(t testing.TB) string {
111138 return p .tmpl
112139}
113140
114- func (p * Pgpool ) createRndDB (t testing.TB ) (* pgxpool.Pool , string ) {
141+ // Register pgx.ConnConfig with std driver.
142+ // Return connection string for database/sql and error.
143+ func (p * Pgpool ) registerStdConfig (t testing.TB , dbName string ) (string , error ) {
144+ connConfig , err := pgx .ParseConfig ("" )
145+ if err != nil {
146+ return "" , errors .WithStack (err )
147+ }
148+ connConfig .Logger = newLogger (t )
149+ connConfig .Database = dbName
150+ return stdlib .RegisterConnConfig (connConfig ), nil
151+ }
152+
153+ func (p * Pgpool ) createRndDB (t testing.TB ) (string , error ) {
115154 tmpl := p .getTmpl (t )
116155 dbName := fmt .Sprintf ("%v_%v" , tmpl , p .rnd .Int31 ())
117156
118- err := p .createDB (dbName , tmpl )
157+ return dbName , p .createDB (dbName , tmpl )
158+ }
159+
160+ func (p * Pgpool ) createRndDBPool (t testing.TB ) (* pgxpool.Pool , string ) {
161+ dbName , err := p .createRndDB (t )
119162 if err != nil {
120163 t .Fatal (err )
121164 }
@@ -191,9 +234,9 @@ func dropDB(dbName string) error {
191234
192235// WithEmpty creates empty database from template database, that was
193236// created from `schema` file.
194- func (p * Pgpool ) WithEmpty (t testing.TB ) ( * pgxpool.Pool , func ()) {
195- pool , dbName := p .createRndDB (t )
196- return pool , func () {
237+ func (p * Pgpool ) WithEmpty (t testing.TB ) * pgxpool.Pool {
238+ pool , dbName := p .createRndDBPool (t )
239+ t . Cleanup ( func () {
197240 acquiredConns := pool .Stat ().AcquiredConns ()
198241 if acquiredConns > 0 {
199242 t .Fatalf (
@@ -206,7 +249,63 @@ func (p *Pgpool) WithEmpty(t testing.TB) (*pgxpool.Pool, func()) {
206249 if err != nil {
207250 t .Errorf ("Can't drop DB %v: %v" , dbName , err )
208251 }
252+ })
253+ return pool
254+ }
255+
256+ // WithStdEmpty creates empty database from template database, that was
257+ // created from `schema` file.
258+ func (p * Pgpool ) WithStdEmpty (t testing.TB ) * sql.DB {
259+ db , cleanupFn := p .newStdDBWithCleanup (t )
260+ if cleanupFn != nil {
261+ t .Cleanup (func () {
262+ if err := cleanupFn (); err != nil {
263+ t .Error (err )
264+ }
265+ })
266+ }
267+ return db
268+ }
269+
270+ func (p * Pgpool ) newStdDBWithCleanup (t testing.TB ) (* sql.DB , func () error ) {
271+ dbName , err := p .createRndDB (t )
272+ if err != nil {
273+ t .Fatal (err )
274+ return nil , nil
275+ }
276+
277+ connString , err := p .registerStdConfig (t , dbName )
278+ if err != nil {
279+ _ = dropDB (dbName )
280+ t .Fatal (err )
281+ return nil , nil
282+ }
283+
284+ db , err := sql .Open ("pgx" , connString )
285+ if err != nil {
286+ _ = dropDB (dbName )
287+ t .Fatal (err )
288+ return nil , nil
289+ }
290+
291+ cleanupFn := func () error {
292+ stats := db .Stats ()
293+ if stats .InUse > 0 {
294+ return errors .Errorf (
295+ "unreleased connections exists: %v, can't drop database %v" ,
296+ stats .InUse , dbName )
297+ }
298+ err := db .Close ()
299+ if err != nil {
300+ return errors .Errorf ("Can't close DB %v: %v" , dbName , err )
301+ }
302+ err = dropDB (dbName )
303+ if err != nil {
304+ return errors .Errorf ("Can't drop DB %v: %v" , dbName , err )
305+ }
306+ return nil
209307 }
308+ return db , cleanupFn
210309}
211310
212311func (p * Pgpool ) createDB (name , tmplName string ) error {
@@ -224,9 +323,11 @@ func (p *Pgpool) createDB(name, tmplName string) error {
224323 )
225324}
226325
326+ // Creates template db, populates with SQLs from schema file and return name
327+ // of the new database. If database is exists, just return its name.
227328func (p * Pgpool ) createTemplateDB () (string , error ) {
228329 if p .SchemaFile == "" {
229- return "" , errors . New ( "SchemaFile is empty" )
330+ return "template1 " , nil
230331 }
231332 schemaSql , err := ioutil .ReadFile (p .SchemaFile )
232333 if err != nil {
@@ -238,7 +339,7 @@ func (p *Pgpool) createTemplateDB() (string, error) {
238339 if p .BaseName != "" {
239340 baseName = p .BaseName
240341 }
241- tmpl := fmt .Sprintf ("%v_%v" , baseName , schemaHex )
342+ tmplDbName := fmt .Sprintf ("%v_%v" , baseName , schemaHex )
242343
243344 var dbExists bool
244345 err = withNewConnection (
@@ -247,14 +348,14 @@ func (p *Pgpool) createTemplateDB() (string, error) {
247348 query := `
248349SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)
249350`
250- err := conn .QueryRow (ctx , query , tmpl ).Scan (& dbExists )
351+ err := conn .QueryRow (ctx , query , tmplDbName ).Scan (& dbExists )
251352 if err != nil {
252353 return errors .WithStack (err )
253354 }
254355 if dbExists {
255356 return nil
256357 }
257- _ , err = conn .Exec (ctx , `CREATE DATABASE ` + quote (tmpl ))
358+ _ , err = conn .Exec (ctx , `CREATE DATABASE ` + quote (tmplDbName ))
258359 return errors .WithStack (err )
259360 },
260361 )
@@ -263,23 +364,23 @@ SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = $1)
263364 }
264365
265366 if dbExists {
266- return tmpl , nil
367+ return tmplDbName , nil
267368 }
268369
269370 err = withNewConnection (
270- tmpl ,
371+ tmplDbName ,
271372 func (ctx context.Context , conn * pgx.Conn ) error {
272373 _ , err = conn .Exec (ctx , string (schemaSql ))
273374 return errors .WithStack (err )
274375 },
275376 )
276377
277378 if err != nil {
278- _ = dropDB (tmpl )
379+ _ = dropDB (tmplDbName )
279380 return "" , err
280381 }
281382
282- return tmpl , nil
383+ return tmplDbName , nil
283384}
284385
285386func quote (name string ) string {
0 commit comments