@@ -3,11 +3,15 @@ package clients
33import (
44 "context"
55 "crypto/tls"
6+ "database/sql"
67 "errors"
8+ "fmt"
9+ "net/url"
710 "strings"
811 "time"
912
1013 "github.com/go-pg/pg/v10"
14+ _ "github.com/lib/pq"
1115 "github.com/sirupsen/logrus"
1216 "github.com/spf13/viper"
1317)
@@ -18,6 +22,16 @@ type (
1822 }
1923
2024 ctxKey int
25+
26+ dbConfig struct {
27+ addr string
28+ user string
29+ password string
30+ database string
31+ poolSize int
32+ ssl bool
33+ debug bool
34+ }
2135)
2236
2337const ctxRequestStartKey ctxKey = 1 + iota
@@ -47,25 +61,21 @@ func (d dbQueryHook) AfterQuery(ctx context.Context, event *pg.QueryEvent) error
4761}
4862
4963func NewPostgreSQL (config * viper.Viper , logger * logrus.Logger ) (* pg.DB , error ) {
50- config .SetDefault ("database.pool" , 10 )
51- config .SetDefault ("database.debug" , false )
52- config .SetDefault ("database.ssl" , false )
53-
54- dbAddr := config .GetString ("database.addr" )
55- if dbAddr == "" {
56- return nil , errors .New ("missing database address" )
64+ cfg , err := parseDBConfig (config )
65+ if err != nil {
66+ return nil , err
5767 }
5868
5969 opts := & pg.Options {
60- Addr : config . GetString ( "database. addr" ) ,
61- User : config . GetString ( "database. user" ) ,
62- Password : config . GetString ( "database. password" ) ,
63- Database : config . GetString ( " database.database" ) ,
64- PoolSize : config . GetInt ( "database.pool" ) ,
70+ Addr : cfg . addr ,
71+ User : cfg . user ,
72+ Password : cfg . password ,
73+ Database : cfg . database ,
74+ PoolSize : cfg . poolSize ,
6575 }
6676
67- if config . GetBool ( "database. ssl" ) {
68- hp := strings .Split (dbAddr , ":" )
77+ if cfg . ssl {
78+ hp := strings .Split (cfg . addr , ":" )
6979 if len (hp ) != 2 {
7080 return nil , errors .New ("database address has wrong format" )
7181 }
@@ -78,7 +88,7 @@ func NewPostgreSQL(config *viper.Viper, logger *logrus.Logger) (*pg.DB, error) {
7888
7989 connection := pg .Connect (opts )
8090
81- if config . GetBool ( "database. debug" ) {
91+ if cfg . debug {
8292 entry := logger .WithField ("module" , "db" )
8393 connection .AddQueryHook (dbQueryHook {
8494 logger : entry ,
@@ -87,3 +97,49 @@ func NewPostgreSQL(config *viper.Viper, logger *logrus.Logger) (*pg.DB, error) {
8797
8898 return connection , nil
8999}
100+
101+ // NewPostgreSQLForMigrations is a connection that is used for migrations.
102+ // Migrations are implemented with `goose`, which supports only `*sql.DB`.
103+ func NewPostgreSQLForMigrations (config * viper.Viper ) (* sql.DB , error ) {
104+ cfg , err := parseDBConfig (config )
105+ if err != nil {
106+ return nil , err
107+ }
108+
109+ dsn := fmt .Sprintf (
110+ "postgres://%s:%s@%s/%s" ,
111+ cfg .user ,
112+ strings .ReplaceAll (url .QueryEscape (cfg .password ), ":" , "%3A" ),
113+ cfg .addr ,
114+ cfg .database ,
115+ )
116+
117+ if cfg .ssl {
118+ dsn += "?sslmode=verify-ca"
119+ } else {
120+ dsn += "?sslmode=disable"
121+ }
122+
123+ return sql .Open ("postgres" , dsn )
124+ }
125+
126+ func parseDBConfig (config * viper.Viper ) (dbConfig , error ) {
127+ config .SetDefault ("database.pool" , 10 )
128+ config .SetDefault ("database.debug" , false )
129+ config .SetDefault ("database.ssl" , false )
130+
131+ dbAddr := config .GetString ("database.addr" )
132+ if dbAddr == "" {
133+ return dbConfig {}, errors .New ("missing database address" )
134+ }
135+
136+ return dbConfig {
137+ addr : dbAddr ,
138+ user : config .GetString ("database.user" ),
139+ password : config .GetString ("database.password" ),
140+ database : config .GetString ("database.database" ),
141+ poolSize : config .GetInt ("database.pool" ),
142+ ssl : config .GetBool ("database.ssl" ),
143+ debug : config .GetBool ("database.debug" ),
144+ }, nil
145+ }
0 commit comments