Skip to content

Commit 62b3399

Browse files
committed
pgx: don't use database/sql interface
it'd be nice to have contexts for many of these methods, but that'd be a much wider change
1 parent c378583 commit 62b3399

File tree

4 files changed

+117
-174
lines changed

4 files changed

+117
-174
lines changed

database/pgx/pgx.go

Lines changed: 41 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package pgx
55

66
import (
77
"context"
8-
"database/sql"
98
"fmt"
109
"io"
1110
nurl "net/url"
@@ -22,8 +21,7 @@ import (
2221
"github.com/hashicorp/go-multierror"
2322
"github.com/jackc/pgconn"
2423
"github.com/jackc/pgerrcode"
25-
_ "github.com/jackc/pgx/v4/stdlib"
26-
"github.com/lib/pq"
24+
"github.com/jackc/pgx/v4"
2725
)
2826

2927
const (
@@ -69,27 +67,26 @@ type Config struct {
6967

7068
type Postgres struct {
7169
// Locking and unlocking need to use the same connection
72-
conn *sql.Conn
73-
db *sql.DB
70+
conn *pgx.Conn
7471
isLocked atomic.Bool
7572

7673
// Open and WithInstance need to guarantee that config is never nil
7774
config *Config
7875
}
7976

80-
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
77+
func WithInstance(instance *pgx.Conn, config *Config) (database.Driver, error) {
8178
if config == nil {
8279
return nil, ErrNilConfig
8380
}
8481

85-
if err := instance.Ping(); err != nil {
82+
if err := instance.Ping(context.Background()); err != nil {
8683
return nil, err
8784
}
8885

8986
if config.DatabaseName == "" {
9087
query := `SELECT CURRENT_DATABASE()`
9188
var databaseName string
92-
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
89+
if err := instance.QueryRow(context.Background(), query).Scan(&databaseName); err != nil {
9390
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
9491
}
9592

@@ -103,7 +100,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
103100
if config.SchemaName == "" {
104101
query := `SELECT CURRENT_SCHEMA()`
105102
var schemaName string
106-
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
103+
if err := instance.QueryRow(context.Background(), query).Scan(&schemaName); err != nil {
107104
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
108105
}
109106

@@ -139,15 +136,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
139136
}
140137
}
141138

142-
conn, err := instance.Conn(context.Background())
143-
144-
if err != nil {
145-
return nil, err
146-
}
147-
148139
px := &Postgres{
149-
conn: conn,
150-
db: instance,
140+
conn: instance,
151141
config: config,
152142
}
153143

@@ -173,7 +163,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
173163
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
174164
purl.Scheme = "postgres"
175165

176-
db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String())
166+
db, err := pgx.Connect(context.Background(), migrate.FilterCustomQuery(purl).String())
177167
if err != nil {
178168
return nil, err
179169
}
@@ -240,10 +230,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
240230
}
241231

242232
func (p *Postgres) Close() error {
243-
connErr := p.conn.Close()
244-
dbErr := p.db.Close()
245-
if connErr != nil || dbErr != nil {
246-
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
233+
connErr := p.conn.Close(context.Background())
234+
if connErr != nil {
235+
return fmt.Errorf("conn: %w", connErr)
247236
}
248237
return nil
249238
}
@@ -283,19 +272,19 @@ func (p *Postgres) applyAdvisoryLock() error {
283272

284273
// This will wait indefinitely until the lock can be acquired.
285274
query := `SELECT pg_advisory_lock($1)`
286-
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
275+
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
287276
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
288277
}
289278
return nil
290279
}
291280

292281
func (p *Postgres) applyTableLock() error {
293-
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
282+
tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{})
294283
if err != nil {
295284
return &database.Error{OrigErr: err, Err: "transaction start failed"}
296285
}
297286
defer func() {
298-
errRollback := tx.Rollback()
287+
errRollback := tx.Rollback(context.Background())
299288
if errRollback != nil {
300289
err = multierror.Append(err, errRollback)
301290
}
@@ -306,30 +295,25 @@ func (p *Postgres) applyTableLock() error {
306295
return err
307296
}
308297

309-
query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
310-
rows, err := tx.Query(query, aid)
298+
query := "SELECT * FROM " + quoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
299+
rows, err := tx.Query(context.Background(), query, aid)
311300
if err != nil {
312301
return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
313302
}
314-
315-
defer func() {
316-
if errClose := rows.Close(); errClose != nil {
317-
err = multierror.Append(err, errClose)
318-
}
319-
}()
303+
defer rows.Close()
320304

321305
// If row exists at all, lock is present
322306
locked := rows.Next()
323307
if locked {
324308
return database.ErrLocked
325309
}
326310

327-
query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
328-
if _, err := tx.Exec(query, aid); err != nil {
311+
query = "INSERT INTO " + quoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
312+
if _, err := tx.Exec(context.Background(), query, aid); err != nil {
329313
return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
330314
}
331315

332-
return tx.Commit()
316+
return tx.Commit(context.Background())
333317
}
334318

335319
func (p *Postgres) releaseAdvisoryLock() error {
@@ -339,7 +323,7 @@ func (p *Postgres) releaseAdvisoryLock() error {
339323
}
340324

341325
query := `SELECT pg_advisory_unlock($1)`
342-
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
326+
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
343327
return &database.Error{OrigErr: err, Query: []byte(query)}
344328
}
345329

@@ -352,8 +336,8 @@ func (p *Postgres) releaseTableLock() error {
352336
return err
353337
}
354338

355-
query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
356-
if _, err := p.db.Exec(query, aid); err != nil {
339+
query := "DELETE FROM " + quoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
340+
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
357341
return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
358342
}
359343

@@ -391,7 +375,7 @@ func (p *Postgres) runStatement(statement []byte) error {
391375
if strings.TrimSpace(query) == "" {
392376
return nil
393377
}
394-
if _, err := p.conn.ExecContext(ctx, query); err != nil {
378+
if _, err := p.conn.Exec(ctx, query); err != nil {
395379

396380
if pgErr, ok := err.(*pgconn.PgError); ok {
397381
var line uint
@@ -448,14 +432,14 @@ func runesLastIndex(input []rune, target rune) int {
448432
}
449433

450434
func (p *Postgres) SetVersion(version int, dirty bool) error {
451-
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
435+
tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{})
452436
if err != nil {
453437
return &database.Error{OrigErr: err, Err: "transaction start failed"}
454438
}
455439

456440
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
457-
if _, err := tx.Exec(query); err != nil {
458-
if errRollback := tx.Rollback(); errRollback != nil {
441+
if _, err := tx.Exec(context.Background(), query); err != nil {
442+
if errRollback := tx.Rollback(context.Background()); errRollback != nil {
459443
err = multierror.Append(err, errRollback)
460444
}
461445
return &database.Error{OrigErr: err, Query: []byte(query)}
@@ -466,15 +450,15 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
466450
// See: https://github.com/golang-migrate/migrate/issues/330
467451
if version >= 0 || (version == database.NilVersion && dirty) {
468452
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
469-
if _, err := tx.Exec(query, version, dirty); err != nil {
470-
if errRollback := tx.Rollback(); errRollback != nil {
453+
if _, err := tx.Exec(context.Background(), query, version, dirty); err != nil {
454+
if errRollback := tx.Rollback(context.Background()); errRollback != nil {
471455
err = multierror.Append(err, errRollback)
472456
}
473457
return &database.Error{OrigErr: err, Query: []byte(query)}
474458
}
475459
}
476460

477-
if err := tx.Commit(); err != nil {
461+
if err := tx.Commit(context.Background()); err != nil {
478462
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
479463
}
480464

@@ -483,9 +467,9 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
483467

484468
func (p *Postgres) Version() (version int, dirty bool, err error) {
485469
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
486-
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
470+
err = p.conn.QueryRow(context.Background(), query).Scan(&version, &dirty)
487471
switch {
488-
case err == sql.ErrNoRows:
472+
case err == pgx.ErrNoRows:
489473
return database.NilVersion, false, nil
490474

491475
case err != nil:
@@ -504,15 +488,11 @@ func (p *Postgres) Version() (version int, dirty bool, err error) {
504488
func (p *Postgres) Drop() (err error) {
505489
// select all tables in current schema
506490
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
507-
tables, err := p.conn.QueryContext(context.Background(), query)
491+
tables, err := p.conn.Query(context.Background(), query)
508492
if err != nil {
509493
return &database.Error{OrigErr: err, Query: []byte(query)}
510494
}
511-
defer func() {
512-
if errClose := tables.Close(); errClose != nil {
513-
err = multierror.Append(err, errClose)
514-
}
515-
}()
495+
defer tables.Close()
516496

517497
// delete one table after another
518498
tableNames := make([]string, 0)
@@ -539,7 +519,7 @@ func (p *Postgres) Drop() (err error) {
539519
// delete one by one ...
540520
for _, t := range tableNames {
541521
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
542-
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
522+
if _, err := p.conn.Exec(context.Background(), query); err != nil {
543523
return &database.Error{OrigErr: err, Query: []byte(query)}
544524
}
545525
}
@@ -571,7 +551,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
571551
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
572552
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
573553
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
574-
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
554+
row := p.conn.QueryRow(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
575555

576556
var count int
577557
err = row.Scan(&count)
@@ -584,7 +564,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
584564
}
585565

586566
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
587-
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
567+
if _, err = p.conn.Exec(context.Background(), query); err != nil {
588568
return &database.Error{OrigErr: err, Query: []byte(query)}
589569
}
590570

@@ -598,15 +578,15 @@ func (p *Postgres) ensureLockTable() error {
598578

599579
var count int
600580
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
601-
if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil {
581+
if err := p.conn.QueryRow(context.Background(), query, p.config.LockTable).Scan(&count); err != nil {
602582
return &database.Error{OrigErr: err, Query: []byte(query)}
603583
}
604584
if count == 1 {
605585
return nil
606586
}
607587

608-
query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
609-
if _, err := p.db.Exec(query); err != nil {
588+
query = `CREATE TABLE ` + quoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
589+
if _, err := p.conn.Exec(context.Background(), query); err != nil {
610590
return &database.Error{OrigErr: err, Query: []byte(query)}
611591
}
612592

@@ -615,9 +595,5 @@ func (p *Postgres) ensureLockTable() error {
615595

616596
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
617597
func quoteIdentifier(name string) string {
618-
end := strings.IndexRune(name, 0)
619-
if end > -1 {
620-
name = name[:end]
621-
}
622-
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
598+
return pgx.Identifier([]string{name}).Sanitize()
623599
}

0 commit comments

Comments
 (0)