@@ -22,8 +22,8 @@ import (
2222 "github.com/hashicorp/go-multierror"
2323 "github.com/jackc/pgconn"
2424 "github.com/jackc/pgerrcode"
25+ "github.com/jackc/pgx/v4"
2526 _ "github.com/jackc/pgx/v4/stdlib"
26- "github.com/lib/pq"
2727)
2828
2929const (
@@ -69,27 +69,26 @@ type Config struct {
6969
7070type Postgres struct {
7171 // Locking and unlocking need to use the same connection
72- conn * sql.Conn
73- db * sql.DB
72+ conn * pgx.Conn
7473 isLocked atomic.Bool
7574
7675 // Open and WithInstance need to guarantee that config is never nil
7776 config * Config
7877}
7978
80- func WithInstance (instance * sql. DB , config * Config ) (database.Driver , error ) {
79+ func WithInstance (instance * pgx. Conn , config * Config ) (database.Driver , error ) {
8180 if config == nil {
8281 return nil , ErrNilConfig
8382 }
8483
85- if err := instance .Ping (); err != nil {
84+ if err := instance .Ping (context . Background () ); err != nil {
8685 return nil , err
8786 }
8887
8988 if config .DatabaseName == "" {
9089 query := `SELECT CURRENT_DATABASE()`
9190 var databaseName string
92- if err := instance .QueryRow (query ).Scan (& databaseName ); err != nil {
91+ if err := instance .QueryRow (context . Background (), query ).Scan (& databaseName ); err != nil {
9392 return nil , & database.Error {OrigErr : err , Query : []byte (query )}
9493 }
9594
@@ -103,7 +102,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
103102 if config .SchemaName == "" {
104103 query := `SELECT CURRENT_SCHEMA()`
105104 var schemaName string
106- if err := instance .QueryRow (query ).Scan (& schemaName ); err != nil {
105+ if err := instance .QueryRow (context . Background (), query ).Scan (& schemaName ); err != nil {
107106 return nil , & database.Error {OrigErr : err , Query : []byte (query )}
108107 }
109108
@@ -139,15 +138,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
139138 }
140139 }
141140
142- conn , err := instance .Conn (context .Background ())
143-
144- if err != nil {
145- return nil , err
146- }
147-
148141 px := & Postgres {
149- conn : conn ,
150- db : instance ,
142+ conn : instance ,
151143 config : config ,
152144 }
153145
@@ -173,7 +165,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
173165 // i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
174166 purl .Scheme = "postgres"
175167
176- db , err := sql . Open ( "pgx/v4" , migrate .FilterCustomQuery (purl ).String ())
168+ db , err := pgx . Connect ( context . Background () , migrate .FilterCustomQuery (purl ).String ())
177169 if err != nil {
178170 return nil , err
179171 }
@@ -240,10 +232,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
240232}
241233
242234func (p * Postgres ) Close () error {
243- connErr := p .conn .Close ()
244- dbErr := p .db .Close ()
245- if connErr != nil || dbErr != nil {
246- return fmt .Errorf ("conn: %v, db: %v" , connErr , dbErr )
235+ connErr := p .conn .Close (context .Background ())
236+ if connErr != nil {
237+ return fmt .Errorf ("conn: %w" , connErr )
247238 }
248239 return nil
249240}
@@ -283,19 +274,19 @@ func (p *Postgres) applyAdvisoryLock() error {
283274
284275 // This will wait indefinitely until the lock can be acquired.
285276 query := `SELECT pg_advisory_lock($1)`
286- if _ , err := p .conn .ExecContext (context .Background (), query , aid ); err != nil {
277+ if _ , err := p .conn .Exec (context .Background (), query , aid ); err != nil {
287278 return & database.Error {OrigErr : err , Err : "try lock failed" , Query : []byte (query )}
288279 }
289280 return nil
290281}
291282
292283func (p * Postgres ) applyTableLock () error {
293- tx , err := p .conn .BeginTx (context .Background (), & sql .TxOptions {})
284+ tx , err := p .conn .BeginTx (context .Background (), pgx .TxOptions {})
294285 if err != nil {
295286 return & database.Error {OrigErr : err , Err : "transaction start failed" }
296287 }
297288 defer func () {
298- errRollback := tx .Rollback ()
289+ errRollback := tx .Rollback (context . Background () )
299290 if errRollback != nil {
300291 err = multierror .Append (err , errRollback )
301292 }
@@ -306,30 +297,25 @@ func (p *Postgres) applyTableLock() error {
306297 return err
307298 }
308299
309- query := "SELECT * FROM " + pq . QuoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
310- rows , err := tx .Query (query , aid )
300+ query := "SELECT * FROM " + quoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
301+ rows , err := tx .Query (context . Background (), query , aid )
311302 if err != nil {
312303 return database.Error {OrigErr : err , Err : "failed to fetch migration lock" , Query : []byte (query )}
313304 }
314-
315- defer func () {
316- if errClose := rows .Close (); errClose != nil {
317- err = multierror .Append (err , errClose )
318- }
319- }()
305+ defer rows .Close ()
320306
321307 // If row exists at all, lock is present
322308 locked := rows .Next ()
323309 if locked {
324310 return database .ErrLocked
325311 }
326312
327- query = "INSERT INTO " + pq . QuoteIdentifier (p .config .LockTable ) + " (lock_id) VALUES ($1)"
328- if _ , err := tx .Exec (query , aid ); err != nil {
313+ query = "INSERT INTO " + quoteIdentifier (p .config .LockTable ) + " (lock_id) VALUES ($1)"
314+ if _ , err := tx .Exec (context . Background (), query , aid ); err != nil {
329315 return database.Error {OrigErr : err , Err : "failed to set migration lock" , Query : []byte (query )}
330316 }
331317
332- return tx .Commit ()
318+ return tx .Commit (context . Background () )
333319}
334320
335321func (p * Postgres ) releaseAdvisoryLock () error {
@@ -339,7 +325,7 @@ func (p *Postgres) releaseAdvisoryLock() error {
339325 }
340326
341327 query := `SELECT pg_advisory_unlock($1)`
342- if _ , err := p .conn .ExecContext (context .Background (), query , aid ); err != nil {
328+ if _ , err := p .conn .Exec (context .Background (), query , aid ); err != nil {
343329 return & database.Error {OrigErr : err , Query : []byte (query )}
344330 }
345331
@@ -352,8 +338,8 @@ func (p *Postgres) releaseTableLock() error {
352338 return err
353339 }
354340
355- query := "DELETE FROM " + pq . QuoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
356- if _ , err := p .db .Exec (query , aid ); err != nil {
341+ query := "DELETE FROM " + quoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
342+ if _ , err := p .conn .Exec (context . Background (), query , aid ); err != nil {
357343 return database.Error {OrigErr : err , Err : "failed to release migration lock" , Query : []byte (query )}
358344 }
359345
@@ -391,7 +377,7 @@ func (p *Postgres) runStatement(statement []byte) error {
391377 if strings .TrimSpace (query ) == "" {
392378 return nil
393379 }
394- if _ , err := p .conn .ExecContext (ctx , query ); err != nil {
380+ if _ , err := p .conn .Exec (ctx , query ); err != nil {
395381
396382 if pgErr , ok := err .(* pgconn.PgError ); ok {
397383 var line uint
@@ -448,14 +434,14 @@ func runesLastIndex(input []rune, target rune) int {
448434}
449435
450436func (p * Postgres ) SetVersion (version int , dirty bool ) error {
451- tx , err := p .conn .BeginTx (context .Background (), & sql .TxOptions {})
437+ tx , err := p .conn .BeginTx (context .Background (), pgx .TxOptions {})
452438 if err != nil {
453439 return & database.Error {OrigErr : err , Err : "transaction start failed" }
454440 }
455441
456442 query := `TRUNCATE ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName )
457- if _ , err := tx .Exec (query ); err != nil {
458- if errRollback := tx .Rollback (); errRollback != nil {
443+ if _ , err := tx .Exec (context . Background (), query ); err != nil {
444+ if errRollback := tx .Rollback (context . Background () ); errRollback != nil {
459445 err = multierror .Append (err , errRollback )
460446 }
461447 return & database.Error {OrigErr : err , Query : []byte (query )}
@@ -466,15 +452,15 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
466452 // See: https://github.com/golang-migrate/migrate/issues/330
467453 if version >= 0 || (version == database .NilVersion && dirty ) {
468454 query = `INSERT INTO ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` (version, dirty) VALUES ($1, $2)`
469- if _ , err := tx .Exec (query , version , dirty ); err != nil {
470- if errRollback := tx .Rollback (); errRollback != nil {
455+ if _ , err := tx .Exec (context . Background (), query , version , dirty ); err != nil {
456+ if errRollback := tx .Rollback (context . Background () ); errRollback != nil {
471457 err = multierror .Append (err , errRollback )
472458 }
473459 return & database.Error {OrigErr : err , Query : []byte (query )}
474460 }
475461 }
476462
477- if err := tx .Commit (); err != nil {
463+ if err := tx .Commit (context . Background () ); err != nil {
478464 return & database.Error {OrigErr : err , Err : "transaction commit failed" }
479465 }
480466
@@ -483,7 +469,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
483469
484470func (p * Postgres ) Version () (version int , dirty bool , err error ) {
485471 query := `SELECT version, dirty FROM ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` LIMIT 1`
486- err = p .conn .QueryRowContext (context .Background (), query ).Scan (& version , & dirty )
472+ err = p .conn .QueryRow (context .Background (), query ).Scan (& version , & dirty )
487473 switch {
488474 case err == sql .ErrNoRows :
489475 return database .NilVersion , false , nil
@@ -504,15 +490,11 @@ func (p *Postgres) Version() (version int, dirty bool, err error) {
504490func (p * Postgres ) Drop () (err error ) {
505491 // select all tables in current schema
506492 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
507- tables , err := p .conn .QueryContext (context .Background (), query )
493+ tables , err := p .conn .Query (context .Background (), query )
508494 if err != nil {
509495 return & database.Error {OrigErr : err , Query : []byte (query )}
510496 }
511- defer func () {
512- if errClose := tables .Close (); errClose != nil {
513- err = multierror .Append (err , errClose )
514- }
515- }()
497+ defer tables .Close ()
516498
517499 // delete one table after another
518500 tableNames := make ([]string , 0 )
@@ -539,7 +521,7 @@ func (p *Postgres) Drop() (err error) {
539521 // delete one by one ...
540522 for _ , t := range tableNames {
541523 query = `DROP TABLE IF EXISTS ` + quoteIdentifier (t ) + ` CASCADE`
542- if _ , err := p .conn .ExecContext (context .Background (), query ); err != nil {
524+ if _ , err := p .conn .Exec (context .Background (), query ); err != nil {
543525 return & database.Error {OrigErr : err , Query : []byte (query )}
544526 }
545527 }
@@ -571,7 +553,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
571553 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
572554 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
573555 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
574- row := p .conn .QueryRowContext (context .Background (), query , p .config .migrationsSchemaName , p .config .migrationsTableName )
556+ row := p .conn .QueryRow (context .Background (), query , p .config .migrationsSchemaName , p .config .migrationsTableName )
575557
576558 var count int
577559 err = row .Scan (& count )
@@ -584,7 +566,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
584566 }
585567
586568 query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` (version bigint not null primary key, dirty boolean not null)`
587- if _ , err = p .conn .ExecContext (context .Background (), query ); err != nil {
569+ if _ , err = p .conn .Exec (context .Background (), query ); err != nil {
588570 return & database.Error {OrigErr : err , Query : []byte (query )}
589571 }
590572
@@ -598,15 +580,15 @@ func (p *Postgres) ensureLockTable() error {
598580
599581 var count int
600582 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
601- if err := p .db .QueryRow (query , p .config .LockTable ).Scan (& count ); err != nil {
583+ if err := p .conn .QueryRow (context . Background (), query , p .config .LockTable ).Scan (& count ); err != nil {
602584 return & database.Error {OrigErr : err , Query : []byte (query )}
603585 }
604586 if count == 1 {
605587 return nil
606588 }
607589
608- query = `CREATE TABLE ` + pq . QuoteIdentifier (p .config .LockTable ) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
609- if _ , err := p .db .Exec (query ); err != nil {
590+ query = `CREATE TABLE ` + quoteIdentifier (p .config .LockTable ) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
591+ if _ , err := p .conn .Exec (context . Background (), query ); err != nil {
610592 return & database.Error {OrigErr : err , Query : []byte (query )}
611593 }
612594
@@ -615,9 +597,5 @@ func (p *Postgres) ensureLockTable() error {
615597
616598// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
617599func quoteIdentifier (name string ) string {
618- end := strings .IndexRune (name , 0 )
619- if end > - 1 {
620- name = name [:end ]
621- }
622- return `"` + strings .Replace (name , `"` , `""` , - 1 ) + `"`
600+ return pgx .Identifier ([]string {name }).Sanitize ()
623601}
0 commit comments