Skip to content

Commit 04541ab

Browse files
committed
Use lazy connections.
This avoids the provider to try to connect even if no resources needed. Also it also avoids the provider to try to connect when defining an RDS instance (for example) and some postgres resources in the same state. (otherwise, at the first creation, provider will try to connect on a server which does not exist yet)
1 parent 4465ed5 commit 04541ab

14 files changed

+364
-442
lines changed

postgresql/config.go

Lines changed: 77 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,9 @@ const (
2727
featurePrivileges
2828
)
2929

30-
type dbRegistryEntry struct {
31-
db *sql.DB
32-
version semver.Version
33-
}
34-
3530
var (
3631
dbRegistryLock sync.Mutex
37-
dbRegistry map[string]dbRegistryEntry = make(map[string]dbRegistryEntry, 1)
32+
dbRegistry map[string]*DBConnection = make(map[string]*DBConnection, 1)
3833

3934
// Mapping of feature flags to versions
4035
featureSupported = map[featureName]semver.Range{
@@ -68,6 +63,40 @@ var (
6863
}
6964
)
7065

66+
type DBConnection struct {
67+
*sql.DB
68+
69+
client *Client
70+
71+
// version is the version number of the database as determined by parsing the
72+
// output of `SELECT VERSION()`.x
73+
version semver.Version
74+
}
75+
76+
// featureSupported returns true if a given feature is supported or not. This is
77+
// slightly different from Config's featureSupported in that here we're
78+
// evaluating against the fingerprinted version, not the expected version.
79+
func (db *DBConnection) featureSupported(name featureName) bool {
80+
fn, found := featureSupported[name]
81+
if !found {
82+
// panic'ing because this is a provider-only bug
83+
panic(fmt.Sprintf("unknown feature flag %v", name))
84+
}
85+
86+
return fn(db.version)
87+
}
88+
89+
// isSuperuser returns true if connected user is a Postgres SUPERUSER
90+
func (db *DBConnection) isSuperuser() (bool, error) {
91+
var superuser bool
92+
93+
if err := db.QueryRow("SELECT rolsuper FROM pg_roles WHERE rolname = CURRENT_USER").Scan(&superuser); err != nil {
94+
return false, fmt.Errorf("could not check if current user is superuser: %w", err)
95+
}
96+
97+
return superuser, nil
98+
}
99+
71100
type ClientCertificateConfig struct {
72101
CertificatePath string
73102
KeyPath string
@@ -98,14 +127,6 @@ type Client struct {
98127

99128
databaseName string
100129

101-
// db is a pointer to the DB connection. Callers are responsible for
102-
// releasing their connections.
103-
db *sql.DB
104-
105-
// version is the version number of the database as determined by parsing the
106-
// output of `SELECT VERSION()`.x
107-
version semver.Version
108-
109130
// PostgreSQL lock on pg_catalog. Many of the operations that Terraform
110131
// performs are not permitted to be concurrent. Unlike traditional
111132
// PostgreSQL tables that use MVCC, many of the PostgreSQL system
@@ -115,50 +136,11 @@ type Client struct {
115136
}
116137

117138
// NewClient returns client config for the specified database.
118-
func (c *Config) NewClient(database string) (*Client, error) {
119-
dbRegistryLock.Lock()
120-
defer dbRegistryLock.Unlock()
121-
122-
dsn := c.connStr(database)
123-
dbEntry, found := dbRegistry[dsn]
124-
if !found {
125-
db, err := sql.Open("postgres", dsn)
126-
if err != nil {
127-
return nil, fmt.Errorf("Error connecting to PostgreSQL server: %w", err)
128-
}
129-
130-
// We don't want to retain connection
131-
// So when we connect on a specific database which might be managed by terraform,
132-
// we don't keep opened connection in case of the db has to be dopped in the plan.
133-
db.SetMaxIdleConns(0)
134-
db.SetMaxOpenConns(c.MaxConns)
135-
136-
defaultVersion, _ := semver.Parse(defaultExpectedPostgreSQLVersion)
137-
version := &c.ExpectedVersion
138-
if defaultVersion.Equals(c.ExpectedVersion) {
139-
// Version hint not set by user, need to fingerprint
140-
version, err = fingerprintCapabilities(db)
141-
if err != nil {
142-
db.Close()
143-
return nil, fmt.Errorf("error detecting capabilities: %w", err)
144-
}
145-
}
146-
147-
dbEntry = dbRegistryEntry{
148-
db: db,
149-
version: *version,
150-
}
151-
dbRegistry[dsn] = dbEntry
152-
}
153-
154-
client := Client{
139+
func (c *Config) NewClient(database string) *Client {
140+
return &Client{
155141
config: *c,
156142
databaseName: database,
157-
db: dbEntry.db,
158-
version: dbEntry.version,
159143
}
160-
161-
return &client, nil
162144
}
163145

164146
// featureSupported returns true if a given feature is supported or not. This
@@ -301,11 +283,47 @@ func (c *Config) getDatabaseUsername() string {
301283
return c.Username
302284
}
303285

304-
// DB returns a copy to an sql.Open()'ed database connection. Callers must
305-
// return their database resources. Use of QueryRow() or Exec() is encouraged.
286+
// Connect returns a copy to an sql.Open()'ed database connection wrapped in a DBConnection struct.
287+
// Callers must return their database resources. Use of QueryRow() or Exec() is encouraged.
306288
// Query() must have their rows.Close()'ed.
307-
func (c *Client) DB() *sql.DB {
308-
return c.db
289+
func (c *Client) Connect() (*DBConnection, error) {
290+
dbRegistryLock.Lock()
291+
defer dbRegistryLock.Unlock()
292+
293+
dsn := c.config.connStr(c.databaseName)
294+
conn, found := dbRegistry[dsn]
295+
if !found {
296+
db, err := sql.Open("postgres", dsn)
297+
if err != nil {
298+
return nil, fmt.Errorf("Error connecting to PostgreSQL server: %w", err)
299+
}
300+
301+
// We don't want to retain connection
302+
// So when we connect on a specific database which might be managed by terraform,
303+
// we don't keep opened connection in case of the db has to be dopped in the plan.
304+
db.SetMaxIdleConns(0)
305+
db.SetMaxOpenConns(c.config.MaxConns)
306+
307+
defaultVersion, _ := semver.Parse(defaultExpectedPostgreSQLVersion)
308+
version := &c.config.ExpectedVersion
309+
if defaultVersion.Equals(c.config.ExpectedVersion) {
310+
// Version hint not set by user, need to fingerprint
311+
version, err = fingerprintCapabilities(db)
312+
if err != nil {
313+
db.Close()
314+
return nil, fmt.Errorf("error detecting capabilities: %w", err)
315+
}
316+
}
317+
318+
conn = &DBConnection{
319+
db,
320+
c,
321+
*version,
322+
}
323+
dbRegistry[dsn] = conn
324+
}
325+
326+
return conn, nil
309327
}
310328

311329
// fingerprintCapabilities queries PostgreSQL to populate a local catalog of
@@ -333,27 +351,3 @@ func fingerprintCapabilities(db *sql.DB) (*semver.Version, error) {
333351

334352
return &version, nil
335353
}
336-
337-
// featureSupported returns true if a given feature is supported or not. This is
338-
// slightly different from Config's featureSupported in that here we're
339-
// evaluating against the fingerprinted version, not the expected version.
340-
func (c *Client) featureSupported(name featureName) bool {
341-
fn, found := featureSupported[name]
342-
if !found {
343-
// panic'ing because this is a provider-only bug
344-
panic(fmt.Sprintf("unknown feature flag %v", name))
345-
}
346-
347-
return fn(c.version)
348-
}
349-
350-
// isSuperuser returns true if connected user is a Postgres SUPERUSER
351-
func (c *Client) isSuperuser() (bool, error) {
352-
var superuser bool
353-
354-
if err := c.db.QueryRow("SELECT rolsuper FROM pg_roles WHERE rolname = CURRENT_USER").Scan(&superuser); err != nil {
355-
return false, fmt.Errorf("could not check if current user is superuser: %w", err)
356-
}
357-
358-
return superuser, nil
359-
}

postgresql/helpers.go

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,38 @@ import (
1010
"github.com/lib/pq"
1111
)
1212

13+
func PGResourceFunc(fn func(*DBConnection, *schema.ResourceData) error) func(*schema.ResourceData, interface{}) error {
14+
return func(d *schema.ResourceData, meta interface{}) error {
15+
client := meta.(*Client)
16+
17+
client.catalogLock.Lock()
18+
defer client.catalogLock.Unlock()
19+
20+
db, err := client.Connect()
21+
if err != nil {
22+
return err
23+
}
24+
25+
return fn(db, d)
26+
}
27+
}
28+
29+
func PGResourceExistsFunc(fn func(*DBConnection, *schema.ResourceData) (bool, error)) func(*schema.ResourceData, interface{}) (bool, error) {
30+
return func(d *schema.ResourceData, meta interface{}) (bool, error) {
31+
client := meta.(*Client)
32+
33+
client.catalogLock.Lock()
34+
defer client.catalogLock.Unlock()
35+
36+
db, err := client.Connect()
37+
if err != nil {
38+
return false, err
39+
}
40+
41+
return fn(db, d)
42+
}
43+
}
44+
1345
// QueryAble is a DB connection (sql.DB/Tx)
1446
type QueryAble interface {
1547
Exec(query string, args ...interface{}) (sql.Result, error)
@@ -252,13 +284,13 @@ func pgArrayToSet(arr pq.ByteaArray) *schema.Set {
252284
// it will create a new connection pool if needed.
253285
func startTransaction(client *Client, database string) (*sql.Tx, error) {
254286
if database != "" && database != client.databaseName {
255-
var err error
256-
client, err = client.config.NewClient(database)
257-
if err != nil {
258-
return nil, err
259-
}
287+
client = client.config.NewClient(database)
288+
}
289+
db, err := client.Connect()
290+
if err != nil {
291+
return nil, err
260292
}
261-
db := client.DB()
293+
262294
txn, err := db.Begin()
263295
if err != nil {
264296
return nil, fmt.Errorf("could not start transaction: %w", err)
@@ -328,14 +360,12 @@ func deferredRollback(txn *sql.Tx) {
328360
}
329361
}
330362

331-
func getDatabase(d *schema.ResourceData, client *Client) string {
332-
database := client.databaseName
333-
363+
func getDatabase(d *schema.ResourceData, databaseName string) string {
334364
if v, ok := d.GetOk(extDatabaseAttr); ok {
335-
database = v.(string)
365+
databaseName = v.(string)
336366
}
337367

338-
return database
368+
return databaseName
339369
}
340370

341371
func getDatabaseOwner(db QueryAble, database string) (string, error) {

postgresql/provider.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,6 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) {
182182
}
183183
}
184184

185-
client, err := config.NewClient(d.Get("database").(string))
186-
if err != nil {
187-
return nil, fmt.Errorf("Error initializing PostgreSQL client: %w", err)
188-
}
189-
185+
client := config.NewClient(d.Get("database").(string))
190186
return client, nil
191187
}

0 commit comments

Comments
 (0)