From 13060a5879aa5eda55fd0ffcd5a13b9c4f0e2099 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Torres=20Cogollo?= Date: Tue, 30 Sep 2025 13:42:19 +0200 Subject: [PATCH] fix: Connection retry logic --- .go-version | 2 +- postgresql/config.go | 44 ++++++++++++++++++++++++++++++------------ postgresql/provider.go | 34 +++++++++++++++++++++++++++++--- 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/.go-version b/.go-version index 4e00d0ac..53cc1a6f 100644 --- a/.go-version +++ b/.go-version @@ -1 +1 @@ -1.14.4 +1.24.0 diff --git a/postgresql/config.go b/postgresql/config.go index d19ca762..ebf2abc1 100644 --- a/postgresql/config.go +++ b/postgresql/config.go @@ -8,9 +8,11 @@ import ( "strconv" "strings" "sync" + "time" "unicode" "github.com/blang/semver" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry" _ "github.com/lib/pq" // PostgreSQL db "gocloud.dev/gcp" "gocloud.dev/gcp/cloudsql" @@ -178,6 +180,9 @@ type Config struct { ApplicationName string Timeout int ConnectTimeoutSec int + MaxConnRetries int + ConnectionRetryTimeoutSeconds int + ConnMaxLifetimeSeconds int MaxConns int ExpectedVersion semver.Version SSLClientCert *ClientCertificateConfig @@ -282,6 +287,7 @@ func (c *Config) getDatabaseUsername() string { func (c *Client) Connect() (*DBConnection, error) { dbRegistryLock.Lock() defer dbRegistryLock.Unlock() + ctx := context.Background() dsn := c.config.connStr(c.databaseName) conn, found := dbRegistry[dsn] @@ -289,19 +295,32 @@ func (c *Client) Connect() (*DBConnection, error) { var db *sql.DB var err error - if c.config.Scheme == "postgres" { - db, err = sql.Open(proxyDriverName, dsn) - } else if c.config.Scheme == "gcppostgres" && c.config.GCPIAMImpersonateServiceAccount != "" { - db, err = openImpersonatedGCPDBConnection(context.Background(), dsn, c.config.GCPIAMImpersonateServiceAccount) - } else { - db, err = postgres.Open(context.Background(), dsn) - } + retryCount := 0 + + connectRetryTimeout := time.Duration(c.config.ConnectionRetryTimeoutSeconds) * time.Second + retryError := retry.RetryContext(ctx, connectRetryTimeout, func() *retry.RetryError { + if c.config.Scheme == "postgres" { + db, err = sql.Open(proxyDriverName, dsn) + } else if c.config.Scheme == "gcppostgres" && c.config.GCPIAMImpersonateServiceAccount != "" { + db, err = openImpersonatedGCPDBConnection(ctx, dsn, c.config.GCPIAMImpersonateServiceAccount) + } else { + db, err = postgres.Open(ctx, dsn) + } + if err == nil { + err = db.PingContext(ctx) + } - if err == nil { - err = db.Ping() - } - if err != nil { - errString := strings.Replace(err.Error(), c.config.Password, "XXXX", 2) + retryCount++ + if err != nil { + if retryCount >= c.config.MaxConnRetries { + return retry.NonRetryableError(err) + } + return retry.RetryableError(err) + } + return nil + }) + if retryError != nil { + errString := strings.Replace(retryError.Error(), c.config.Password, "XXXX", 2) return nil, fmt.Errorf("error connecting to PostgreSQL server %s (scheme: %s): %s", c.config.Host, c.config.Scheme, errString) } @@ -310,6 +329,7 @@ func (c *Client) Connect() (*DBConnection, error) { // we don't keep opened connection in case of the db has to be dropped in the plan. db.SetMaxIdleConns(0) db.SetMaxOpenConns(c.config.MaxConns) + db.SetConnMaxLifetime(time.Duration(c.config.ConnMaxLifetimeSeconds) * time.Second) defaultVersion, _ := semver.Parse(defaultExpectedPostgreSQLVersion) version := &c.config.ExpectedVersion diff --git a/postgresql/provider.go b/postgresql/provider.go index 778778be..041fabce 100644 --- a/postgresql/provider.go +++ b/postgresql/provider.go @@ -3,9 +3,10 @@ package postgresql import ( "context" "fmt" + "os" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/sts" - "os" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" @@ -21,8 +22,11 @@ import ( ) const ( - defaultProviderMaxOpenConnections = 20 - defaultExpectedPostgreSQLVersion = "9.0.0" + defaultProviderMaxOpenConnections = 20 + defaultProviderConnMaxLifetimeSeconds = 300 + defaultProviderMaxConnRetries = 5 + defaultProviderConnectionRetryTimeoutSeconds = 5 + defaultExpectedPostgreSQLVersion = "9.0.0" ) // Provider returns a terraform.ResourceProvider. @@ -185,6 +189,20 @@ func Provider() *schema.Provider { Description: "Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely.", ValidateFunc: validation.IntAtLeast(-1), }, + "max_conn_retries": { + Type: schema.TypeInt, + Optional: true, + Default: defaultProviderMaxConnRetries, + Description: "Maximum number of connection retries.", + ValidateFunc: validation.IntAtLeast(0), + }, + "connection_retry_timeout_seconds": { + Type: schema.TypeInt, + Optional: true, + Default: defaultProviderConnectionRetryTimeoutSeconds, + Description: "Maximum wait for connection retries, in seconds.", + ValidateFunc: validation.IntAtLeast(0), + }, "max_connections": { Type: schema.TypeInt, Optional: true, @@ -192,6 +210,13 @@ func Provider() *schema.Provider { Description: "Maximum number of connections to establish to the database. Zero means unlimited.", ValidateFunc: validation.IntAtLeast(-1), }, + "conn_max_lifetime_seconds": { + Type: schema.TypeInt, + Optional: true, + Default: defaultProviderConnMaxLifetimeSeconds, + Description: "Maximum lifetime of a connection, in seconds. Zero means unlimited.", + ValidateFunc: validation.IntAtLeast(-1), + }, "expected_version": { Type: schema.TypeString, Optional: true, @@ -382,7 +407,10 @@ func providerConfigure(d *schema.ResourceData) (any, error) { SSLMode: sslMode, ApplicationName: "Terraform provider", ConnectTimeoutSec: d.Get("connect_timeout").(int), + MaxConnRetries: d.Get("max_conn_retries").(int), + ConnectionRetryTimeoutSeconds: d.Get("connection_retry_timeout_seconds").(int), MaxConns: d.Get("max_connections").(int), + ConnMaxLifetimeSeconds: d.Get("conn_max_lifetime_seconds").(int), ExpectedVersion: version, SSLRootCertPath: d.Get("sslrootcert").(string), GCPIAMImpersonateServiceAccount: d.Get("gcp_iam_impersonate_service_account").(string),