diff --git a/postgresql/helpers.go b/postgresql/helpers.go index 1cc0cd1d..5e7e41a7 100644 --- a/postgresql/helpers.go +++ b/postgresql/helpers.go @@ -1,12 +1,14 @@ package postgresql import ( + "context" "database/sql" "fmt" "log" "regexp" "strings" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/lib/pq" ) @@ -37,6 +39,23 @@ func PGResourceExistsFunc(fn func(*DBConnection, *schema.ResourceData) (bool, er } } +func PGResourceContextFunc(fn func(context.Context, *DBConnection, *schema.ResourceData) diag.Diagnostics) func(context.Context, *schema.ResourceData, interface{}) diag.Diagnostics { + return func(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { + client := meta.(*Client) + + db, err := client.Connect() + if err != nil { + return diag.Diagnostics{diag.Diagnostic{ + Severity: diag.Error, + Summary: "Failled to connext", + Detail: err.Error(), + }} + } + + return fn(ctx, db, d) + } +} + // QueryAble is a DB connection (sql.DB/Tx) type QueryAble interface { Exec(query string, args ...interface{}) (sql.Result, error) diff --git a/postgresql/resource_postgresql_script.go b/postgresql/resource_postgresql_script.go index bdc0b31e..af3348ee 100644 --- a/postgresql/resource_postgresql_script.go +++ b/postgresql/resource_postgresql_script.go @@ -1,12 +1,14 @@ package postgresql import ( + "context" "crypto/sha1" "encoding/hex" "fmt" "log" "time" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -14,15 +16,16 @@ const ( scriptCommandsAttr = "commands" scriptTriesAttr = "tries" scriptBackoffDelayAttr = "backoff_delay" + scriptTimeoutAttr = "timeout" scriptShasumAttr = "shasum" ) func resourcePostgreSQLScript() *schema.Resource { return &schema.Resource{ - Create: PGResourceFunc(resourcePostgreSQLScriptCreateOrUpdate), - Read: PGResourceFunc(resourcePostgreSQLScriptRead), - Update: PGResourceFunc(resourcePostgreSQLScriptCreateOrUpdate), - Delete: PGResourceFunc(resourcePostgreSQLScriptDelete), + CreateContext: PGResourceContextFunc(resourcePostgreSQLScriptCreateOrUpdate), + Read: PGResourceFunc(resourcePostgreSQLScriptRead), + UpdateContext: PGResourceContextFunc(resourcePostgreSQLScriptCreateOrUpdate), + Delete: PGResourceFunc(resourcePostgreSQLScriptDelete), Schema: map[string]*schema.Schema{ scriptCommandsAttr: { @@ -45,6 +48,12 @@ func resourcePostgreSQLScript() *schema.Resource { Default: 1, Description: "Number of seconds between two tries of the batch of commands", }, + scriptTimeoutAttr: { + Type: schema.TypeInt, + Optional: true, + Default: 5 * 60, + Description: "Number of seconds for a batch of command to timeout", + }, scriptShasumAttr: { Type: schema.TypeString, Computed: true, @@ -54,19 +63,28 @@ func resourcePostgreSQLScript() *schema.Resource { } } -func resourcePostgreSQLScriptCreateOrUpdate(db *DBConnection, d *schema.ResourceData) error { +func resourcePostgreSQLScriptCreateOrUpdate(ctx context.Context, db *DBConnection, d *schema.ResourceData) diag.Diagnostics { commands, err := toStringArray(d.Get(scriptCommandsAttr).([]any)) tries := d.Get(scriptTriesAttr).(int) backoffDelay := d.Get(scriptBackoffDelayAttr).(int) + timeout := d.Get(scriptTimeoutAttr).(int) if err != nil { - return err + return diag.Diagnostics{diag.Diagnostic{ + Severity: diag.Error, + Summary: "Commands input is not valid", + Detail: err.Error(), + }} } sum := shasumCommands(commands) - if err := executeCommands(db, commands, tries, backoffDelay); err != nil { - return err + if err := executeCommands(ctx, db, commands, tries, backoffDelay, timeout); err != nil { + return diag.Diagnostics{diag.Diagnostic{ + Severity: diag.Error, + Summary: "Commands execution failed", + Detail: err.Error(), + }} } d.Set(scriptShasumAttr, sum) @@ -89,9 +107,9 @@ func resourcePostgreSQLScriptDelete(db *DBConnection, d *schema.ResourceData) er return nil } -func executeCommands(db *DBConnection, commands []string, tries int, backoffDelay int) error { +func executeCommands(ctx context.Context, db *DBConnection, commands []string, tries int, backoffDelay int, timeout int) error { for try := 1; ; try++ { - err := executeBatch(db, commands) + err := executeBatch(ctx, db, commands, timeout) if err == nil { return nil } else { @@ -103,10 +121,12 @@ func executeCommands(db *DBConnection, commands []string, tries int, backoffDela } } -func executeBatch(db *DBConnection, commands []string) error { +func executeBatch(ctx context.Context, db *DBConnection, commands []string, timeout int) error { + timeoutContext, timeoutCancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer timeoutCancel() for _, command := range commands { log.Printf("[DEBUG] Executing %s", command) - _, err := db.Exec(command) + _, err := db.ExecContext(timeoutContext, command) log.Printf("[DEBUG] Result %s: %v", command, err) if err != nil { log.Println("[DEBUG] Error catched:", err) diff --git a/postgresql/resource_postgresql_script_test.go b/postgresql/resource_postgresql_script_test.go index 378eb56e..8e624237 100644 --- a/postgresql/resource_postgresql_script_test.go +++ b/postgresql/resource_postgresql_script_test.go @@ -203,3 +203,27 @@ func TestAccPostgresqlScript_failMultiple(t *testing.T) { }, }) } + +func TestAccPostgresqlScript_timeout(t *testing.T) { + config := ` + resource "postgresql_script" "invalid" { + commands = [ + "BEGIN", + "SELECT pg_sleep(2);", + "COMMIT" + ] + timeout = 1 + } + ` + + resource.Test(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + Steps: []resource.TestStep{ + { + Config: config, + ExpectError: regexp.MustCompile("canceling statement"), + }, + }, + }) +}