Skip to content

Commit ebb59f5

Browse files
committed
[B2B-7314] Postgres client for migrations
1 parent cbb1478 commit ebb59f5

File tree

3 files changed

+112
-38
lines changed

3 files changed

+112
-38
lines changed

clients/pg.go

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@ package clients
33
import (
44
"context"
55
"crypto/tls"
6+
"database/sql"
67
"errors"
8+
"fmt"
9+
"net/url"
710
"strings"
811
"time"
912

1013
"github.com/go-pg/pg/v10"
14+
_ "github.com/lib/pq"
1115
"github.com/sirupsen/logrus"
1216
"github.com/spf13/viper"
1317
)
@@ -18,6 +22,16 @@ type (
1822
}
1923

2024
ctxKey int
25+
26+
dbConfig struct {
27+
addr string
28+
user string
29+
password string
30+
database string
31+
poolSize int
32+
ssl bool
33+
debug bool
34+
}
2135
)
2236

2337
const ctxRequestStartKey ctxKey = 1 + iota
@@ -47,25 +61,21 @@ func (d dbQueryHook) AfterQuery(ctx context.Context, event *pg.QueryEvent) error
4761
}
4862

4963
func NewPostgreSQL(config *viper.Viper, logger *logrus.Logger) (*pg.DB, error) {
50-
config.SetDefault("database.pool", 10)
51-
config.SetDefault("database.debug", false)
52-
config.SetDefault("database.ssl", false)
53-
54-
dbAddr := config.GetString("database.addr")
55-
if dbAddr == "" {
56-
return nil, errors.New("missing database address")
64+
cfg, err := parseDBConfig(config)
65+
if err != nil {
66+
return nil, err
5767
}
5868

5969
opts := &pg.Options{
60-
Addr: config.GetString("database.addr"),
61-
User: config.GetString("database.user"),
62-
Password: config.GetString("database.password"),
63-
Database: config.GetString("database.database"),
64-
PoolSize: config.GetInt("database.pool"),
70+
Addr: cfg.addr,
71+
User: cfg.user,
72+
Password: cfg.password,
73+
Database: cfg.database,
74+
PoolSize: cfg.poolSize,
6575
}
6676

67-
if config.GetBool("database.ssl") {
68-
hp := strings.Split(dbAddr, ":")
77+
if cfg.ssl {
78+
hp := strings.Split(cfg.addr, ":")
6979
if len(hp) != 2 {
7080
return nil, errors.New("database address has wrong format")
7181
}
@@ -78,7 +88,7 @@ func NewPostgreSQL(config *viper.Viper, logger *logrus.Logger) (*pg.DB, error) {
7888

7989
connection := pg.Connect(opts)
8090

81-
if config.GetBool("database.debug") {
91+
if cfg.debug {
8292
entry := logger.WithField("module", "db")
8393
connection.AddQueryHook(dbQueryHook{
8494
logger: entry,
@@ -87,3 +97,49 @@ func NewPostgreSQL(config *viper.Viper, logger *logrus.Logger) (*pg.DB, error) {
8797

8898
return connection, nil
8999
}
100+
101+
// NewPostgreSQLForMigrations is a connection that is used for migrations.
102+
// Migrations are implemented with `goose`, which supports only `*sql.DB`.
103+
func NewPostgreSQLForMigrations(config *viper.Viper) (*sql.DB, error) {
104+
cfg, err := parseDBConfig(config)
105+
if err != nil {
106+
return nil, err
107+
}
108+
109+
dsn := fmt.Sprintf(
110+
"postgres://%s:%s@%s/%s",
111+
cfg.user,
112+
strings.ReplaceAll(url.QueryEscape(cfg.password), ":", "%3A"),
113+
cfg.addr,
114+
cfg.database,
115+
)
116+
117+
if cfg.ssl {
118+
dsn += "?sslmode=verify-ca"
119+
} else {
120+
dsn += "?sslmode=disable"
121+
}
122+
123+
return sql.Open("postgres", dsn)
124+
}
125+
126+
func parseDBConfig(config *viper.Viper) (dbConfig, error) {
127+
config.SetDefault("database.pool", 10)
128+
config.SetDefault("database.debug", false)
129+
config.SetDefault("database.ssl", false)
130+
131+
dbAddr := config.GetString("database.addr")
132+
if dbAddr == "" {
133+
return dbConfig{}, errors.New("missing database address")
134+
}
135+
136+
return dbConfig{
137+
addr: dbAddr,
138+
user: config.GetString("database.user"),
139+
password: config.GetString("database.password"),
140+
database: config.GetString("database.database"),
141+
poolSize: config.GetInt("database.pool"),
142+
ssl: config.GetBool("database.ssl"),
143+
debug: config.GetBool("database.debug"),
144+
}, nil
145+
}

clients/pg_test.go

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,7 @@ func TestNewPostgreSQL(t *testing.T) {
1313
t.Skip("skipping test in short mode")
1414
}
1515

16-
cfg := viper.New()
17-
cfg.Set("database.addr", os.Getenv("DATABASE_ADDR"))
18-
cfg.Set("database.user", os.Getenv("DATABASE_USER"))
19-
cfg.Set("database.password", os.Getenv("DATABASE_PASSWORD"))
20-
cfg.Set("database.database", os.Getenv("DATABASE_DATABASE"))
21-
cfg.Set("database.ssl", os.Getenv("DATABASE_SSL"))
22-
16+
cfg := setupConfig()
2317
logger := logrus.New()
2418

2519
db, err := NewPostgreSQL(cfg, logger)
@@ -40,3 +34,40 @@ func TestNewPostgreSQL(t *testing.T) {
4034
t.Error("unexpected message")
4135
}
4236
}
37+
38+
func TestNewPostgreSQLForMigrations(t *testing.T) {
39+
if testing.Short() {
40+
t.Skip("skipping test in short mode")
41+
}
42+
43+
cfg := setupConfig()
44+
45+
db, err := NewPostgreSQLForMigrations(cfg)
46+
if err != nil {
47+
t.Fatal(err)
48+
}
49+
50+
type StringResult struct {
51+
Message string
52+
}
53+
var res StringResult
54+
err = db.QueryRow("SELECT 'hello' AS message").Scan(&res.Message)
55+
if err != nil {
56+
t.Fatal(err)
57+
}
58+
59+
if res.Message != "hello" {
60+
t.Error("unexpected message")
61+
}
62+
}
63+
64+
func setupConfig() *viper.Viper {
65+
cfg := viper.New()
66+
cfg.Set("database.addr", os.Getenv("DATABASE_ADDR"))
67+
cfg.Set("database.user", os.Getenv("DATABASE_USER"))
68+
cfg.Set("database.password", os.Getenv("DATABASE_PASSWORD"))
69+
cfg.Set("database.database", os.Getenv("DATABASE_DATABASE"))
70+
cfg.Set("database.ssl", os.Getenv("DATABASE_SSL"))
71+
72+
return cfg
73+
}

narada/commands/migrations.go

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
package commands
22

33
import (
4-
"database/sql"
54
"errors"
6-
"fmt"
75

86
"github.com/cryptopay-dev/narada"
7+
"github.com/cryptopay-dev/narada/clients"
98

109
"github.com/pressly/goose"
1110
"github.com/sirupsen/logrus"
@@ -31,7 +30,7 @@ func MigrateUp(p *narada.Narada) *cli.Command {
3130
logger.Println("starting migrations")
3231
dir := c.String("dir")
3332

34-
db, err := connect(v)
33+
db, err := clients.NewPostgreSQLForMigrations(v)
3534
if err != nil {
3635
return err
3736
}
@@ -61,7 +60,7 @@ func MigrateDown(p *narada.Narada) *cli.Command {
6160
logger.Println("rolling back migration")
6261
dir := c.String("dir")
6362

64-
db, err := connect(v)
63+
db, err := clients.NewPostgreSQLForMigrations(v)
6564
if err != nil {
6665
return err
6766
}
@@ -107,15 +106,3 @@ func CreateMigration(p *narada.Narada) *cli.Command {
107106
},
108107
}
109108
}
110-
111-
func connect(v *viper.Viper) (*sql.DB, error) {
112-
dsn := fmt.Sprintf(
113-
"postgres://%s:%s@%s/%s?sslmode=disable",
114-
v.GetString("database.user"),
115-
v.GetString("database.password"),
116-
v.GetString("database.addr"),
117-
v.GetString("database.database"),
118-
)
119-
120-
return sql.Open("postgres", dsn)
121-
}

0 commit comments

Comments
 (0)