@@ -5,7 +5,6 @@ package pgx
55
66import (
77 "context"
8- "database/sql"
98 "fmt"
109 "io"
1110 nurl "net/url"
@@ -22,8 +21,7 @@ import (
2221 "github.com/hashicorp/go-multierror"
2322 "github.com/jackc/pgconn"
2423 "github.com/jackc/pgerrcode"
25- _ "github.com/jackc/pgx/v4/stdlib"
26- "github.com/lib/pq"
24+ "github.com/jackc/pgx/v4"
2725)
2826
2927const (
@@ -69,27 +67,26 @@ type Config struct {
6967
7068type Postgres struct {
7169 // Locking and unlocking need to use the same connection
72- conn * sql.Conn
73- db * sql.DB
70+ conn * pgx.Conn
7471 isLocked atomic.Bool
7572
7673 // Open and WithInstance need to guarantee that config is never nil
7774 config * Config
7875}
7976
80- func WithInstance (instance * sql. DB , config * Config ) (database.Driver , error ) {
77+ func WithInstance (instance * pgx. Conn , config * Config ) (database.Driver , error ) {
8178 if config == nil {
8279 return nil , ErrNilConfig
8380 }
8481
85- if err := instance .Ping (); err != nil {
82+ if err := instance .Ping (context . Background () ); err != nil {
8683 return nil , err
8784 }
8885
8986 if config .DatabaseName == "" {
9087 query := `SELECT CURRENT_DATABASE()`
9188 var databaseName string
92- if err := instance .QueryRow (query ).Scan (& databaseName ); err != nil {
89+ if err := instance .QueryRow (context . Background (), query ).Scan (& databaseName ); err != nil {
9390 return nil , & database.Error {OrigErr : err , Query : []byte (query )}
9491 }
9592
@@ -103,7 +100,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
103100 if config .SchemaName == "" {
104101 query := `SELECT CURRENT_SCHEMA()`
105102 var schemaName string
106- if err := instance .QueryRow (query ).Scan (& schemaName ); err != nil {
103+ if err := instance .QueryRow (context . Background (), query ).Scan (& schemaName ); err != nil {
107104 return nil , & database.Error {OrigErr : err , Query : []byte (query )}
108105 }
109106
@@ -139,15 +136,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
139136 }
140137 }
141138
142- conn , err := instance .Conn (context .Background ())
143-
144- if err != nil {
145- return nil , err
146- }
147-
148139 px := & Postgres {
149- conn : conn ,
150- db : instance ,
140+ conn : instance ,
151141 config : config ,
152142 }
153143
@@ -173,7 +163,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
173163 // i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
174164 purl .Scheme = "postgres"
175165
176- db , err := sql . Open ( "pgx/v4" , migrate .FilterCustomQuery (purl ).String ())
166+ db , err := pgx . Connect ( context . Background () , migrate .FilterCustomQuery (purl ).String ())
177167 if err != nil {
178168 return nil , err
179169 }
@@ -240,10 +230,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
240230}
241231
242232func (p * Postgres ) Close () error {
243- connErr := p .conn .Close ()
244- dbErr := p .db .Close ()
245- if connErr != nil || dbErr != nil {
246- return fmt .Errorf ("conn: %v, db: %v" , connErr , dbErr )
233+ connErr := p .conn .Close (context .Background ())
234+ if connErr != nil {
235+ return fmt .Errorf ("conn: %w" , connErr )
247236 }
248237 return nil
249238}
@@ -283,19 +272,19 @@ func (p *Postgres) applyAdvisoryLock() error {
283272
284273 // This will wait indefinitely until the lock can be acquired.
285274 query := `SELECT pg_advisory_lock($1)`
286- if _ , err := p .conn .ExecContext (context .Background (), query , aid ); err != nil {
275+ if _ , err := p .conn .Exec (context .Background (), query , aid ); err != nil {
287276 return & database.Error {OrigErr : err , Err : "try lock failed" , Query : []byte (query )}
288277 }
289278 return nil
290279}
291280
292281func (p * Postgres ) applyTableLock () error {
293- tx , err := p .conn .BeginTx (context .Background (), & sql .TxOptions {})
282+ tx , err := p .conn .BeginTx (context .Background (), pgx .TxOptions {})
294283 if err != nil {
295284 return & database.Error {OrigErr : err , Err : "transaction start failed" }
296285 }
297286 defer func () {
298- errRollback := tx .Rollback ()
287+ errRollback := tx .Rollback (context . Background () )
299288 if errRollback != nil {
300289 err = multierror .Append (err , errRollback )
301290 }
@@ -306,30 +295,25 @@ func (p *Postgres) applyTableLock() error {
306295 return err
307296 }
308297
309- query := "SELECT * FROM " + pq . QuoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
310- rows , err := tx .Query (query , aid )
298+ query := "SELECT * FROM " + quoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
299+ rows , err := tx .Query (context . Background (), query , aid )
311300 if err != nil {
312301 return database.Error {OrigErr : err , Err : "failed to fetch migration lock" , Query : []byte (query )}
313302 }
314-
315- defer func () {
316- if errClose := rows .Close (); errClose != nil {
317- err = multierror .Append (err , errClose )
318- }
319- }()
303+ defer rows .Close ()
320304
321305 // If row exists at all, lock is present
322306 locked := rows .Next ()
323307 if locked {
324308 return database .ErrLocked
325309 }
326310
327- query = "INSERT INTO " + pq . QuoteIdentifier (p .config .LockTable ) + " (lock_id) VALUES ($1)"
328- if _ , err := tx .Exec (query , aid ); err != nil {
311+ query = "INSERT INTO " + quoteIdentifier (p .config .LockTable ) + " (lock_id) VALUES ($1)"
312+ if _ , err := tx .Exec (context . Background (), query , aid ); err != nil {
329313 return database.Error {OrigErr : err , Err : "failed to set migration lock" , Query : []byte (query )}
330314 }
331315
332- return tx .Commit ()
316+ return tx .Commit (context . Background () )
333317}
334318
335319func (p * Postgres ) releaseAdvisoryLock () error {
@@ -339,7 +323,7 @@ func (p *Postgres) releaseAdvisoryLock() error {
339323 }
340324
341325 query := `SELECT pg_advisory_unlock($1)`
342- if _ , err := p .conn .ExecContext (context .Background (), query , aid ); err != nil {
326+ if _ , err := p .conn .Exec (context .Background (), query , aid ); err != nil {
343327 return & database.Error {OrigErr : err , Query : []byte (query )}
344328 }
345329
@@ -352,8 +336,8 @@ func (p *Postgres) releaseTableLock() error {
352336 return err
353337 }
354338
355- query := "DELETE FROM " + pq . QuoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
356- if _ , err := p .db .Exec (query , aid ); err != nil {
339+ query := "DELETE FROM " + quoteIdentifier (p .config .LockTable ) + " WHERE lock_id = $1"
340+ if _ , err := p .conn .Exec (context . Background (), query , aid ); err != nil {
357341 return database.Error {OrigErr : err , Err : "failed to release migration lock" , Query : []byte (query )}
358342 }
359343
@@ -391,7 +375,7 @@ func (p *Postgres) runStatement(statement []byte) error {
391375 if strings .TrimSpace (query ) == "" {
392376 return nil
393377 }
394- if _ , err := p .conn .ExecContext (ctx , query ); err != nil {
378+ if _ , err := p .conn .Exec (ctx , query ); err != nil {
395379
396380 if pgErr , ok := err .(* pgconn.PgError ); ok {
397381 var line uint
@@ -448,14 +432,14 @@ func runesLastIndex(input []rune, target rune) int {
448432}
449433
450434func (p * Postgres ) SetVersion (version int , dirty bool ) error {
451- tx , err := p .conn .BeginTx (context .Background (), & sql .TxOptions {})
435+ tx , err := p .conn .BeginTx (context .Background (), pgx .TxOptions {})
452436 if err != nil {
453437 return & database.Error {OrigErr : err , Err : "transaction start failed" }
454438 }
455439
456440 query := `TRUNCATE ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName )
457- if _ , err := tx .Exec (query ); err != nil {
458- if errRollback := tx .Rollback (); errRollback != nil {
441+ if _ , err := tx .Exec (context . Background (), query ); err != nil {
442+ if errRollback := tx .Rollback (context . Background () ); errRollback != nil {
459443 err = multierror .Append (err , errRollback )
460444 }
461445 return & database.Error {OrigErr : err , Query : []byte (query )}
@@ -466,15 +450,15 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
466450 // See: https://github.com/golang-migrate/migrate/issues/330
467451 if version >= 0 || (version == database .NilVersion && dirty ) {
468452 query = `INSERT INTO ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` (version, dirty) VALUES ($1, $2)`
469- if _ , err := tx .Exec (query , version , dirty ); err != nil {
470- if errRollback := tx .Rollback (); errRollback != nil {
453+ if _ , err := tx .Exec (context . Background (), query , version , dirty ); err != nil {
454+ if errRollback := tx .Rollback (context . Background () ); errRollback != nil {
471455 err = multierror .Append (err , errRollback )
472456 }
473457 return & database.Error {OrigErr : err , Query : []byte (query )}
474458 }
475459 }
476460
477- if err := tx .Commit (); err != nil {
461+ if err := tx .Commit (context . Background () ); err != nil {
478462 return & database.Error {OrigErr : err , Err : "transaction commit failed" }
479463 }
480464
@@ -483,9 +467,9 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
483467
484468func (p * Postgres ) Version () (version int , dirty bool , err error ) {
485469 query := `SELECT version, dirty FROM ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` LIMIT 1`
486- err = p .conn .QueryRowContext (context .Background (), query ).Scan (& version , & dirty )
470+ err = p .conn .QueryRow (context .Background (), query ).Scan (& version , & dirty )
487471 switch {
488- case err == sql .ErrNoRows :
472+ case err == pgx .ErrNoRows :
489473 return database .NilVersion , false , nil
490474
491475 case err != nil :
@@ -504,15 +488,11 @@ func (p *Postgres) Version() (version int, dirty bool, err error) {
504488func (p * Postgres ) Drop () (err error ) {
505489 // select all tables in current schema
506490 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
507- tables , err := p .conn .QueryContext (context .Background (), query )
491+ tables , err := p .conn .Query (context .Background (), query )
508492 if err != nil {
509493 return & database.Error {OrigErr : err , Query : []byte (query )}
510494 }
511- defer func () {
512- if errClose := tables .Close (); errClose != nil {
513- err = multierror .Append (err , errClose )
514- }
515- }()
495+ defer tables .Close ()
516496
517497 // delete one table after another
518498 tableNames := make ([]string , 0 )
@@ -539,7 +519,7 @@ func (p *Postgres) Drop() (err error) {
539519 // delete one by one ...
540520 for _ , t := range tableNames {
541521 query = `DROP TABLE IF EXISTS ` + quoteIdentifier (t ) + ` CASCADE`
542- if _ , err := p .conn .ExecContext (context .Background (), query ); err != nil {
522+ if _ , err := p .conn .Exec (context .Background (), query ); err != nil {
543523 return & database.Error {OrigErr : err , Query : []byte (query )}
544524 }
545525 }
@@ -571,7 +551,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
571551 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
572552 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
573553 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
574- row := p .conn .QueryRowContext (context .Background (), query , p .config .migrationsSchemaName , p .config .migrationsTableName )
554+ row := p .conn .QueryRow (context .Background (), query , p .config .migrationsSchemaName , p .config .migrationsTableName )
575555
576556 var count int
577557 err = row .Scan (& count )
@@ -584,7 +564,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
584564 }
585565
586566 query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier (p .config .migrationsSchemaName ) + `.` + quoteIdentifier (p .config .migrationsTableName ) + ` (version bigint not null primary key, dirty boolean not null)`
587- if _ , err = p .conn .ExecContext (context .Background (), query ); err != nil {
567+ if _ , err = p .conn .Exec (context .Background (), query ); err != nil {
588568 return & database.Error {OrigErr : err , Query : []byte (query )}
589569 }
590570
@@ -598,15 +578,15 @@ func (p *Postgres) ensureLockTable() error {
598578
599579 var count int
600580 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
601- if err := p .db .QueryRow (query , p .config .LockTable ).Scan (& count ); err != nil {
581+ if err := p .conn .QueryRow (context . Background (), query , p .config .LockTable ).Scan (& count ); err != nil {
602582 return & database.Error {OrigErr : err , Query : []byte (query )}
603583 }
604584 if count == 1 {
605585 return nil
606586 }
607587
608- query = `CREATE TABLE ` + pq . QuoteIdentifier (p .config .LockTable ) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
609- if _ , err := p .db .Exec (query ); err != nil {
588+ query = `CREATE TABLE ` + quoteIdentifier (p .config .LockTable ) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
589+ if _ , err := p .conn .Exec (context . Background (), query ); err != nil {
610590 return & database.Error {OrigErr : err , Query : []byte (query )}
611591 }
612592
@@ -615,9 +595,5 @@ func (p *Postgres) ensureLockTable() error {
615595
616596// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
617597func quoteIdentifier (name string ) string {
618- end := strings .IndexRune (name , 0 )
619- if end > - 1 {
620- name = name [:end ]
621- }
622- return `"` + strings .Replace (name , `"` , `""` , - 1 ) + `"`
598+ return pgx .Identifier ([]string {name }).Sanitize ()
623599}
0 commit comments