@@ -27,14 +27,9 @@ const (
2727 featurePrivileges
2828)
2929
30- type dbRegistryEntry struct {
31- db * sql.DB
32- version semver.Version
33- }
34-
3530var (
3631 dbRegistryLock sync.Mutex
37- dbRegistry map [string ]dbRegistryEntry = make (map [string ]dbRegistryEntry , 1 )
32+ dbRegistry map [string ]* DBConnection = make (map [string ]* DBConnection , 1 )
3833
3934 // Mapping of feature flags to versions
4035 featureSupported = map [featureName ]semver.Range {
6863 }
6964)
7065
66+ type DBConnection struct {
67+ * sql.DB
68+
69+ client * Client
70+
71+ // version is the version number of the database as determined by parsing the
72+ // output of `SELECT VERSION()`.x
73+ version semver.Version
74+ }
75+
76+ // featureSupported returns true if a given feature is supported or not. This is
77+ // slightly different from Config's featureSupported in that here we're
78+ // evaluating against the fingerprinted version, not the expected version.
79+ func (db * DBConnection ) featureSupported (name featureName ) bool {
80+ fn , found := featureSupported [name ]
81+ if ! found {
82+ // panic'ing because this is a provider-only bug
83+ panic (fmt .Sprintf ("unknown feature flag %v" , name ))
84+ }
85+
86+ return fn (db .version )
87+ }
88+
89+ // isSuperuser returns true if connected user is a Postgres SUPERUSER
90+ func (db * DBConnection ) isSuperuser () (bool , error ) {
91+ var superuser bool
92+
93+ if err := db .QueryRow ("SELECT rolsuper FROM pg_roles WHERE rolname = CURRENT_USER" ).Scan (& superuser ); err != nil {
94+ return false , fmt .Errorf ("could not check if current user is superuser: %w" , err )
95+ }
96+
97+ return superuser , nil
98+ }
99+
71100type ClientCertificateConfig struct {
72101 CertificatePath string
73102 KeyPath string
@@ -98,14 +127,6 @@ type Client struct {
98127
99128 databaseName string
100129
101- // db is a pointer to the DB connection. Callers are responsible for
102- // releasing their connections.
103- db * sql.DB
104-
105- // version is the version number of the database as determined by parsing the
106- // output of `SELECT VERSION()`.x
107- version semver.Version
108-
109130 // PostgreSQL lock on pg_catalog. Many of the operations that Terraform
110131 // performs are not permitted to be concurrent. Unlike traditional
111132 // PostgreSQL tables that use MVCC, many of the PostgreSQL system
@@ -115,50 +136,11 @@ type Client struct {
115136}
116137
117138// NewClient returns client config for the specified database.
118- func (c * Config ) NewClient (database string ) (* Client , error ) {
119- dbRegistryLock .Lock ()
120- defer dbRegistryLock .Unlock ()
121-
122- dsn := c .connStr (database )
123- dbEntry , found := dbRegistry [dsn ]
124- if ! found {
125- db , err := sql .Open ("postgres" , dsn )
126- if err != nil {
127- return nil , fmt .Errorf ("Error connecting to PostgreSQL server: %w" , err )
128- }
129-
130- // We don't want to retain connection
131- // So when we connect on a specific database which might be managed by terraform,
132- // we don't keep opened connection in case of the db has to be dopped in the plan.
133- db .SetMaxIdleConns (0 )
134- db .SetMaxOpenConns (c .MaxConns )
135-
136- defaultVersion , _ := semver .Parse (defaultExpectedPostgreSQLVersion )
137- version := & c .ExpectedVersion
138- if defaultVersion .Equals (c .ExpectedVersion ) {
139- // Version hint not set by user, need to fingerprint
140- version , err = fingerprintCapabilities (db )
141- if err != nil {
142- db .Close ()
143- return nil , fmt .Errorf ("error detecting capabilities: %w" , err )
144- }
145- }
146-
147- dbEntry = dbRegistryEntry {
148- db : db ,
149- version : * version ,
150- }
151- dbRegistry [dsn ] = dbEntry
152- }
153-
154- client := Client {
139+ func (c * Config ) NewClient (database string ) * Client {
140+ return & Client {
155141 config : * c ,
156142 databaseName : database ,
157- db : dbEntry .db ,
158- version : dbEntry .version ,
159143 }
160-
161- return & client , nil
162144}
163145
164146// featureSupported returns true if a given feature is supported or not. This
@@ -301,11 +283,47 @@ func (c *Config) getDatabaseUsername() string {
301283 return c .Username
302284}
303285
304- // DB returns a copy to an sql.Open()'ed database connection. Callers must
305- // return their database resources. Use of QueryRow() or Exec() is encouraged.
286+ // Connect returns a copy to an sql.Open()'ed database connection wrapped in a DBConnection struct.
287+ // Callers must return their database resources. Use of QueryRow() or Exec() is encouraged.
306288// Query() must have their rows.Close()'ed.
307- func (c * Client ) DB () * sql.DB {
308- return c .db
289+ func (c * Client ) Connect () (* DBConnection , error ) {
290+ dbRegistryLock .Lock ()
291+ defer dbRegistryLock .Unlock ()
292+
293+ dsn := c .config .connStr (c .databaseName )
294+ conn , found := dbRegistry [dsn ]
295+ if ! found {
296+ db , err := sql .Open ("postgres" , dsn )
297+ if err != nil {
298+ return nil , fmt .Errorf ("Error connecting to PostgreSQL server: %w" , err )
299+ }
300+
301+ // We don't want to retain connection
302+ // So when we connect on a specific database which might be managed by terraform,
303+ // we don't keep opened connection in case of the db has to be dopped in the plan.
304+ db .SetMaxIdleConns (0 )
305+ db .SetMaxOpenConns (c .config .MaxConns )
306+
307+ defaultVersion , _ := semver .Parse (defaultExpectedPostgreSQLVersion )
308+ version := & c .config .ExpectedVersion
309+ if defaultVersion .Equals (c .config .ExpectedVersion ) {
310+ // Version hint not set by user, need to fingerprint
311+ version , err = fingerprintCapabilities (db )
312+ if err != nil {
313+ db .Close ()
314+ return nil , fmt .Errorf ("error detecting capabilities: %w" , err )
315+ }
316+ }
317+
318+ conn = & DBConnection {
319+ db ,
320+ c ,
321+ * version ,
322+ }
323+ dbRegistry [dsn ] = conn
324+ }
325+
326+ return conn , nil
309327}
310328
311329// fingerprintCapabilities queries PostgreSQL to populate a local catalog of
@@ -333,27 +351,3 @@ func fingerprintCapabilities(db *sql.DB) (*semver.Version, error) {
333351
334352 return & version , nil
335353}
336-
337- // featureSupported returns true if a given feature is supported or not. This is
338- // slightly different from Config's featureSupported in that here we're
339- // evaluating against the fingerprinted version, not the expected version.
340- func (c * Client ) featureSupported (name featureName ) bool {
341- fn , found := featureSupported [name ]
342- if ! found {
343- // panic'ing because this is a provider-only bug
344- panic (fmt .Sprintf ("unknown feature flag %v" , name ))
345- }
346-
347- return fn (c .version )
348- }
349-
350- // isSuperuser returns true if connected user is a Postgres SUPERUSER
351- func (c * Client ) isSuperuser () (bool , error ) {
352- var superuser bool
353-
354- if err := c .db .QueryRow ("SELECT rolsuper FROM pg_roles WHERE rolname = CURRENT_USER" ).Scan (& superuser ); err != nil {
355- return false , fmt .Errorf ("could not check if current user is superuser: %w" , err )
356- }
357-
358- return superuser , nil
359- }
0 commit comments