Skip to content

Commit fc25522

Browse files
authored
Use lazy connections. (#5)
* Use lazy connections. This avoids the provider to try to connect even if no resources needed. Also it also avoids the provider to try to connect when defining an RDS instance (for example) and some postgres resources in the same state. (otherwise, at the first creation, provider will try to connect on a server which does not exist yet) * fixup! Merge remote-tracking branch 'origin/master' into lazy-connections
1 parent f5da1e2 commit fc25522

15 files changed

+387
-481
lines changed

postgresql/config.go

Lines changed: 77 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,9 @@ const (
2929
featurePid
3030
)
3131

32-
type dbRegistryEntry struct {
33-
db *sql.DB
34-
version semver.Version
35-
}
36-
3732
var (
3833
dbRegistryLock sync.Mutex
39-
dbRegistry map[string]dbRegistryEntry = make(map[string]dbRegistryEntry, 1)
34+
dbRegistry map[string]*DBConnection = make(map[string]*DBConnection, 1)
4035

4136
// Mapping of feature flags to versions
4237
featureSupported = map[featureName]semver.Range{
@@ -78,6 +73,40 @@ var (
7873
}
7974
)
8075

76+
type DBConnection struct {
77+
*sql.DB
78+
79+
client *Client
80+
81+
// version is the version number of the database as determined by parsing the
82+
// output of `SELECT VERSION()`.x
83+
version semver.Version
84+
}
85+
86+
// featureSupported returns true if a given feature is supported or not. This is
87+
// slightly different from Config's featureSupported in that here we're
88+
// evaluating against the fingerprinted version, not the expected version.
89+
func (db *DBConnection) featureSupported(name featureName) bool {
90+
fn, found := featureSupported[name]
91+
if !found {
92+
// panic'ing because this is a provider-only bug
93+
panic(fmt.Sprintf("unknown feature flag %v", name))
94+
}
95+
96+
return fn(db.version)
97+
}
98+
99+
// isSuperuser returns true if connected user is a Postgres SUPERUSER
100+
func (db *DBConnection) isSuperuser() (bool, error) {
101+
var superuser bool
102+
103+
if err := db.QueryRow("SELECT rolsuper FROM pg_roles WHERE rolname = CURRENT_USER").Scan(&superuser); err != nil {
104+
return false, fmt.Errorf("could not check if current user is superuser: %w", err)
105+
}
106+
107+
return superuser, nil
108+
}
109+
81110
type ClientCertificateConfig struct {
82111
CertificatePath string
83112
KeyPath string
@@ -108,14 +137,6 @@ type Client struct {
108137

109138
databaseName string
110139

111-
// db is a pointer to the DB connection. Callers are responsible for
112-
// releasing their connections.
113-
db *sql.DB
114-
115-
// version is the version number of the database as determined by parsing the
116-
// output of `SELECT VERSION()`.x
117-
version semver.Version
118-
119140
// PostgreSQL lock on pg_catalog. Many of the operations that Terraform
120141
// performs are not permitted to be concurrent. Unlike traditional
121142
// PostgreSQL tables that use MVCC, many of the PostgreSQL system
@@ -125,50 +146,11 @@ type Client struct {
125146
}
126147

127148
// NewClient returns client config for the specified database.
128-
func (c *Config) NewClient(database string) (*Client, error) {
129-
dbRegistryLock.Lock()
130-
defer dbRegistryLock.Unlock()
131-
132-
dsn := c.connStr(database)
133-
dbEntry, found := dbRegistry[dsn]
134-
if !found {
135-
db, err := sql.Open("postgres", dsn)
136-
if err != nil {
137-
return nil, fmt.Errorf("Error connecting to PostgreSQL server: %w", err)
138-
}
139-
140-
// We don't want to retain connection
141-
// So when we connect on a specific database which might be managed by terraform,
142-
// we don't keep opened connection in case of the db has to be dopped in the plan.
143-
db.SetMaxIdleConns(0)
144-
db.SetMaxOpenConns(c.MaxConns)
145-
146-
defaultVersion, _ := semver.Parse(defaultExpectedPostgreSQLVersion)
147-
version := &c.ExpectedVersion
148-
if defaultVersion.Equals(c.ExpectedVersion) {
149-
// Version hint not set by user, need to fingerprint
150-
version, err = fingerprintCapabilities(db)
151-
if err != nil {
152-
db.Close()
153-
return nil, fmt.Errorf("error detecting capabilities: %w", err)
154-
}
155-
}
156-
157-
dbEntry = dbRegistryEntry{
158-
db: db,
159-
version: *version,
160-
}
161-
dbRegistry[dsn] = dbEntry
162-
}
163-
164-
client := Client{
149+
func (c *Config) NewClient(database string) *Client {
150+
return &Client{
165151
config: *c,
166152
databaseName: database,
167-
db: dbEntry.db,
168-
version: dbEntry.version,
169153
}
170-
171-
return &client, nil
172154
}
173155

174156
// featureSupported returns true if a given feature is supported or not. This
@@ -311,11 +293,47 @@ func (c *Config) getDatabaseUsername() string {
311293
return c.Username
312294
}
313295

314-
// DB returns a copy to an sql.Open()'ed database connection. Callers must
315-
// return their database resources. Use of QueryRow() or Exec() is encouraged.
296+
// Connect returns a copy to an sql.Open()'ed database connection wrapped in a DBConnection struct.
297+
// Callers must return their database resources. Use of QueryRow() or Exec() is encouraged.
316298
// Query() must have their rows.Close()'ed.
317-
func (c *Client) DB() *sql.DB {
318-
return c.db
299+
func (c *Client) Connect() (*DBConnection, error) {
300+
dbRegistryLock.Lock()
301+
defer dbRegistryLock.Unlock()
302+
303+
dsn := c.config.connStr(c.databaseName)
304+
conn, found := dbRegistry[dsn]
305+
if !found {
306+
db, err := sql.Open("postgres", dsn)
307+
if err != nil {
308+
return nil, fmt.Errorf("Error connecting to PostgreSQL server: %w", err)
309+
}
310+
311+
// We don't want to retain connection
312+
// So when we connect on a specific database which might be managed by terraform,
313+
// we don't keep opened connection in case of the db has to be dopped in the plan.
314+
db.SetMaxIdleConns(0)
315+
db.SetMaxOpenConns(c.config.MaxConns)
316+
317+
defaultVersion, _ := semver.Parse(defaultExpectedPostgreSQLVersion)
318+
version := &c.config.ExpectedVersion
319+
if defaultVersion.Equals(c.config.ExpectedVersion) {
320+
// Version hint not set by user, need to fingerprint
321+
version, err = fingerprintCapabilities(db)
322+
if err != nil {
323+
db.Close()
324+
return nil, fmt.Errorf("error detecting capabilities: %w", err)
325+
}
326+
}
327+
328+
conn = &DBConnection{
329+
db,
330+
c,
331+
*version,
332+
}
333+
dbRegistry[dsn] = conn
334+
}
335+
336+
return conn, nil
319337
}
320338

321339
// fingerprintCapabilities queries PostgreSQL to populate a local catalog of
@@ -343,27 +361,3 @@ func fingerprintCapabilities(db *sql.DB) (*semver.Version, error) {
343361

344362
return &version, nil
345363
}
346-
347-
// featureSupported returns true if a given feature is supported or not. This is
348-
// slightly different from Config's featureSupported in that here we're
349-
// evaluating against the fingerprinted version, not the expected version.
350-
func (c *Client) featureSupported(name featureName) bool {
351-
fn, found := featureSupported[name]
352-
if !found {
353-
// panic'ing because this is a provider-only bug
354-
panic(fmt.Sprintf("unknown feature flag %v", name))
355-
}
356-
357-
return fn(c.version)
358-
}
359-
360-
// isSuperuser returns true if connected user is a Postgres SUPERUSER
361-
func (c *Client) isSuperuser() (bool, error) {
362-
var superuser bool
363-
364-
if err := c.db.QueryRow("SELECT rolsuper FROM pg_roles WHERE rolname = CURRENT_USER").Scan(&superuser); err != nil {
365-
return false, fmt.Errorf("could not check if current user is superuser: %w", err)
366-
}
367-
368-
return superuser, nil
369-
}

postgresql/helpers.go

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,38 @@ import (
1010
"github.com/lib/pq"
1111
)
1212

13+
func PGResourceFunc(fn func(*DBConnection, *schema.ResourceData) error) func(*schema.ResourceData, interface{}) error {
14+
return func(d *schema.ResourceData, meta interface{}) error {
15+
client := meta.(*Client)
16+
17+
client.catalogLock.Lock()
18+
defer client.catalogLock.Unlock()
19+
20+
db, err := client.Connect()
21+
if err != nil {
22+
return err
23+
}
24+
25+
return fn(db, d)
26+
}
27+
}
28+
29+
func PGResourceExistsFunc(fn func(*DBConnection, *schema.ResourceData) (bool, error)) func(*schema.ResourceData, interface{}) (bool, error) {
30+
return func(d *schema.ResourceData, meta interface{}) (bool, error) {
31+
client := meta.(*Client)
32+
33+
client.catalogLock.Lock()
34+
defer client.catalogLock.Unlock()
35+
36+
db, err := client.Connect()
37+
if err != nil {
38+
return false, err
39+
}
40+
41+
return fn(db, d)
42+
}
43+
}
44+
1345
// QueryAble is a DB connection (sql.DB/Tx)
1446
type QueryAble interface {
1547
Exec(query string, args ...interface{}) (sql.Result, error)
@@ -252,13 +284,13 @@ func pgArrayToSet(arr pq.ByteaArray) *schema.Set {
252284
// it will create a new connection pool if needed.
253285
func startTransaction(client *Client, database string) (*sql.Tx, error) {
254286
if database != "" && database != client.databaseName {
255-
var err error
256-
client, err = client.config.NewClient(database)
257-
if err != nil {
258-
return nil, err
259-
}
287+
client = client.config.NewClient(database)
288+
}
289+
db, err := client.Connect()
290+
if err != nil {
291+
return nil, err
260292
}
261-
db := client.DB()
293+
262294
txn, err := db.Begin()
263295
if err != nil {
264296
return nil, fmt.Errorf("could not start transaction: %w", err)
@@ -328,14 +360,12 @@ func deferredRollback(txn *sql.Tx) {
328360
}
329361
}
330362

331-
func getDatabase(d *schema.ResourceData, client *Client) string {
332-
database := client.databaseName
333-
363+
func getDatabase(d *schema.ResourceData, databaseName string) string {
334364
if v, ok := d.GetOk(extDatabaseAttr); ok {
335-
database = v.(string)
365+
databaseName = v.(string)
336366
}
337367

338-
return database
368+
return databaseName
339369
}
340370

341371
func getDatabaseOwner(db QueryAble, database string) (string, error) {

postgresql/provider.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,6 @@ func providerConfigure(d *schema.ResourceData) (interface{}, error) {
183183
}
184184
}
185185

186-
client, err := config.NewClient(d.Get("database").(string))
187-
if err != nil {
188-
return nil, fmt.Errorf("Error initializing PostgreSQL client: %w", err)
189-
}
190-
186+
client := config.NewClient(d.Get("database").(string))
191187
return client, nil
192188
}

0 commit comments

Comments
 (0)