@@ -29,14 +29,9 @@ const (
2929 featurePid
3030)
3131
32- type dbRegistryEntry struct {
33- db * sql.DB
34- version semver.Version
35- }
36-
3732var (
3833 dbRegistryLock sync.Mutex
39- dbRegistry map [string ]dbRegistryEntry = make (map [string ]dbRegistryEntry , 1 )
34+ dbRegistry map [string ]* DBConnection = make (map [string ]* DBConnection , 1 )
4035
4136 // Mapping of feature flags to versions
4237 featureSupported = map [featureName ]semver.Range {
7873 }
7974)
8075
76+ type DBConnection struct {
77+ * sql.DB
78+
79+ client * Client
80+
81+ // version is the version number of the database as determined by parsing the
82+ // output of `SELECT VERSION()`.x
83+ version semver.Version
84+ }
85+
86+ // featureSupported returns true if a given feature is supported or not. This is
87+ // slightly different from Config's featureSupported in that here we're
88+ // evaluating against the fingerprinted version, not the expected version.
89+ func (db * DBConnection ) featureSupported (name featureName ) bool {
90+ fn , found := featureSupported [name ]
91+ if ! found {
92+ // panic'ing because this is a provider-only bug
93+ panic (fmt .Sprintf ("unknown feature flag %v" , name ))
94+ }
95+
96+ return fn (db .version )
97+ }
98+
99+ // isSuperuser returns true if connected user is a Postgres SUPERUSER
100+ func (db * DBConnection ) isSuperuser () (bool , error ) {
101+ var superuser bool
102+
103+ if err := db .QueryRow ("SELECT rolsuper FROM pg_roles WHERE rolname = CURRENT_USER" ).Scan (& superuser ); err != nil {
104+ return false , fmt .Errorf ("could not check if current user is superuser: %w" , err )
105+ }
106+
107+ return superuser , nil
108+ }
109+
81110type ClientCertificateConfig struct {
82111 CertificatePath string
83112 KeyPath string
@@ -108,14 +137,6 @@ type Client struct {
108137
109138 databaseName string
110139
111- // db is a pointer to the DB connection. Callers are responsible for
112- // releasing their connections.
113- db * sql.DB
114-
115- // version is the version number of the database as determined by parsing the
116- // output of `SELECT VERSION()`.x
117- version semver.Version
118-
119140 // PostgreSQL lock on pg_catalog. Many of the operations that Terraform
120141 // performs are not permitted to be concurrent. Unlike traditional
121142 // PostgreSQL tables that use MVCC, many of the PostgreSQL system
@@ -125,50 +146,11 @@ type Client struct {
125146}
126147
127148// NewClient returns client config for the specified database.
128- func (c * Config ) NewClient (database string ) (* Client , error ) {
129- dbRegistryLock .Lock ()
130- defer dbRegistryLock .Unlock ()
131-
132- dsn := c .connStr (database )
133- dbEntry , found := dbRegistry [dsn ]
134- if ! found {
135- db , err := sql .Open ("postgres" , dsn )
136- if err != nil {
137- return nil , fmt .Errorf ("Error connecting to PostgreSQL server: %w" , err )
138- }
139-
140- // We don't want to retain connection
141- // So when we connect on a specific database which might be managed by terraform,
142- // we don't keep opened connection in case of the db has to be dopped in the plan.
143- db .SetMaxIdleConns (0 )
144- db .SetMaxOpenConns (c .MaxConns )
145-
146- defaultVersion , _ := semver .Parse (defaultExpectedPostgreSQLVersion )
147- version := & c .ExpectedVersion
148- if defaultVersion .Equals (c .ExpectedVersion ) {
149- // Version hint not set by user, need to fingerprint
150- version , err = fingerprintCapabilities (db )
151- if err != nil {
152- db .Close ()
153- return nil , fmt .Errorf ("error detecting capabilities: %w" , err )
154- }
155- }
156-
157- dbEntry = dbRegistryEntry {
158- db : db ,
159- version : * version ,
160- }
161- dbRegistry [dsn ] = dbEntry
162- }
163-
164- client := Client {
149+ func (c * Config ) NewClient (database string ) * Client {
150+ return & Client {
165151 config : * c ,
166152 databaseName : database ,
167- db : dbEntry .db ,
168- version : dbEntry .version ,
169153 }
170-
171- return & client , nil
172154}
173155
174156// featureSupported returns true if a given feature is supported or not. This
@@ -311,11 +293,47 @@ func (c *Config) getDatabaseUsername() string {
311293 return c .Username
312294}
313295
314- // DB returns a copy to an sql.Open()'ed database connection. Callers must
315- // return their database resources. Use of QueryRow() or Exec() is encouraged.
296+ // Connect returns a copy to an sql.Open()'ed database connection wrapped in a DBConnection struct.
297+ // Callers must return their database resources. Use of QueryRow() or Exec() is encouraged.
316298// Query() must have their rows.Close()'ed.
317- func (c * Client ) DB () * sql.DB {
318- return c .db
299+ func (c * Client ) Connect () (* DBConnection , error ) {
300+ dbRegistryLock .Lock ()
301+ defer dbRegistryLock .Unlock ()
302+
303+ dsn := c .config .connStr (c .databaseName )
304+ conn , found := dbRegistry [dsn ]
305+ if ! found {
306+ db , err := sql .Open ("postgres" , dsn )
307+ if err != nil {
308+ return nil , fmt .Errorf ("Error connecting to PostgreSQL server: %w" , err )
309+ }
310+
311+ // We don't want to retain connection
312+ // So when we connect on a specific database which might be managed by terraform,
313+ // we don't keep opened connection in case of the db has to be dopped in the plan.
314+ db .SetMaxIdleConns (0 )
315+ db .SetMaxOpenConns (c .config .MaxConns )
316+
317+ defaultVersion , _ := semver .Parse (defaultExpectedPostgreSQLVersion )
318+ version := & c .config .ExpectedVersion
319+ if defaultVersion .Equals (c .config .ExpectedVersion ) {
320+ // Version hint not set by user, need to fingerprint
321+ version , err = fingerprintCapabilities (db )
322+ if err != nil {
323+ db .Close ()
324+ return nil , fmt .Errorf ("error detecting capabilities: %w" , err )
325+ }
326+ }
327+
328+ conn = & DBConnection {
329+ db ,
330+ c ,
331+ * version ,
332+ }
333+ dbRegistry [dsn ] = conn
334+ }
335+
336+ return conn , nil
319337}
320338
321339// fingerprintCapabilities queries PostgreSQL to populate a local catalog of
@@ -343,27 +361,3 @@ func fingerprintCapabilities(db *sql.DB) (*semver.Version, error) {
343361
344362 return & version , nil
345363}
346-
347- // featureSupported returns true if a given feature is supported or not. This is
348- // slightly different from Config's featureSupported in that here we're
349- // evaluating against the fingerprinted version, not the expected version.
350- func (c * Client ) featureSupported (name featureName ) bool {
351- fn , found := featureSupported [name ]
352- if ! found {
353- // panic'ing because this is a provider-only bug
354- panic (fmt .Sprintf ("unknown feature flag %v" , name ))
355- }
356-
357- return fn (c .version )
358- }
359-
360- // isSuperuser returns true if connected user is a Postgres SUPERUSER
361- func (c * Client ) isSuperuser () (bool , error ) {
362- var superuser bool
363-
364- if err := c .db .QueryRow ("SELECT rolsuper FROM pg_roles WHERE rolname = CURRENT_USER" ).Scan (& superuser ); err != nil {
365- return false , fmt .Errorf ("could not check if current user is superuser: %w" , err )
366- }
367-
368- return superuser , nil
369- }
0 commit comments