@@ -2,9 +2,16 @@ package clients
22
33import (
44 "context"
5+ "crypto/tls"
6+ "database/sql"
7+ "errors"
8+ "fmt"
9+ "net/url"
10+ "strings"
511 "time"
612
7- "github.com/go-pg/pg/v10"
13+ pg "github.com/go-pg/pg/v10"
14+ _ "github.com/lib/pq"
815 "github.com/sirupsen/logrus"
916 "github.com/spf13/viper"
1017)
@@ -15,6 +22,16 @@ type (
1522 }
1623
1724 ctxKey int
25+
26+ dbConfig struct {
27+ addr string
28+ user string
29+ password string
30+ database string
31+ poolSize int
32+ ssl bool
33+ debug bool
34+ }
1835)
1936
2037const ctxRequestStartKey ctxKey = 1 + iota
@@ -43,24 +60,86 @@ func (d dbQueryHook) AfterQuery(ctx context.Context, event *pg.QueryEvent) error
4360 return nil
4461}
4562
46- func NewPostgreSQL (config * viper.Viper , logger * logrus.Logger ) * pg.DB {
47- config .SetDefault ("database.pool" , 10 )
48- config .SetDefault ("database.debug" , false )
63+ func NewPostgreSQL (config * viper.Viper , logger * logrus.Logger ) (* pg.DB , error ) {
64+ cfg , err := parseDBConfig (config )
65+ if err != nil {
66+ return nil , err
67+ }
4968
50- connection := pg . Connect ( & pg.Options {
51- Addr : config . GetString ( "database. addr" ) ,
52- User : config . GetString ( "database. user" ) ,
53- Password : config . GetString ( "database. password" ) ,
54- Database : config . GetString ( " database.database" ) ,
55- PoolSize : config . GetInt ( "database.pool" ) ,
56- })
69+ opts := & pg.Options {
70+ Addr : cfg . addr ,
71+ User : cfg . user ,
72+ Password : cfg . password ,
73+ Database : cfg . database ,
74+ PoolSize : cfg . poolSize ,
75+ }
5776
58- if config .GetBool ("database.debug" ) {
77+ if cfg .ssl {
78+ hp := strings .Split (cfg .addr , ":" )
79+ if len (hp ) != 2 {
80+ return nil , errors .New ("database address has wrong format" )
81+ }
82+
83+ opts .TLSConfig = & tls.Config {
84+ InsecureSkipVerify : false ,
85+ ServerName : hp [0 ],
86+ }
87+ }
88+
89+ connection := pg .Connect (opts )
90+
91+ if cfg .debug {
5992 entry := logger .WithField ("module" , "db" )
6093 connection .AddQueryHook (dbQueryHook {
6194 logger : entry ,
6295 })
6396 }
6497
65- return connection
98+ return connection , nil
99+ }
100+
101+ // NewPostgreSQLForMigrations is a connection that is used for migrations.
102+ // Migrations are implemented with `goose`, which supports only `*sql.DB`.
103+ func NewPostgreSQLForMigrations (config * viper.Viper ) (* sql.DB , error ) {
104+ cfg , err := parseDBConfig (config )
105+ if err != nil {
106+ return nil , err
107+ }
108+
109+ dsn := fmt .Sprintf (
110+ "postgres://%s:%s@%s/%s" ,
111+ cfg .user ,
112+ strings .ReplaceAll (url .QueryEscape (cfg .password ), ":" , "%3A" ),
113+ cfg .addr ,
114+ cfg .database ,
115+ )
116+
117+ if cfg .ssl {
118+ dsn += "?sslmode=verify-ca"
119+ } else {
120+ dsn += "?sslmode=disable"
121+ }
122+
123+ return sql .Open ("postgres" , dsn )
124+ }
125+
126+ func parseDBConfig (config * viper.Viper ) (dbConfig , error ) {
127+ config .SetDefault ("database.pool" , 10 )
128+ config .SetDefault ("database.debug" , false )
129+ config .SetDefault ("database.ssl" , false )
130+
131+ dbAddr := config .GetString ("database.addr" )
132+ if dbAddr == "" {
133+ return dbConfig {}, errors .New ("missing database address" )
134+ }
135+
136+ return dbConfig {
137+ addr : dbAddr ,
138+ user : config .GetString ("database.user" ),
139+ password : config .GetString ("database.password" ),
140+ database : config .GetString ("database.database" ),
141+ poolSize : config .GetInt ("database.pool" ),
142+ ssl : config .GetBool ("database.ssl" ),
143+ debug : config .GetBool ("database.debug" ),
144+ }, nil
66145}
0 commit comments