Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions postgresql/helpers.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package postgresql

import (
"context"
"database/sql"
"fmt"
"log"
"regexp"
"strings"

"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/lib/pq"
)
Expand Down Expand Up @@ -37,6 +39,23 @@ func PGResourceExistsFunc(fn func(*DBConnection, *schema.ResourceData) (bool, er
}
}

func PGResourceContextFunc(fn func(context.Context, *DBConnection, *schema.ResourceData) diag.Diagnostics) func(context.Context, *schema.ResourceData, interface{}) diag.Diagnostics {
return func(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
client := meta.(*Client)

db, err := client.Connect()
if err != nil {
return diag.Diagnostics{diag.Diagnostic{
Severity: diag.Error,
Summary: "Failled to connext",
Detail: err.Error(),
}}
}

return fn(ctx, db, d)
}
}

// QueryAble is a DB connection (sql.DB/Tx)
type QueryAble interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Expand Down
44 changes: 32 additions & 12 deletions postgresql/resource_postgresql_script.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
package postgresql

import (
"context"
"crypto/sha1"
"encoding/hex"
"fmt"
"log"
"time"

"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
)

const (
scriptCommandsAttr = "commands"
scriptTriesAttr = "tries"
scriptBackoffDelayAttr = "backoff_delay"
scriptTimeoutAttr = "timeout"
scriptShasumAttr = "shasum"
)

func resourcePostgreSQLScript() *schema.Resource {
return &schema.Resource{
Create: PGResourceFunc(resourcePostgreSQLScriptCreateOrUpdate),
Read: PGResourceFunc(resourcePostgreSQLScriptRead),
Update: PGResourceFunc(resourcePostgreSQLScriptCreateOrUpdate),
Delete: PGResourceFunc(resourcePostgreSQLScriptDelete),
CreateContext: PGResourceContextFunc(resourcePostgreSQLScriptCreateOrUpdate),
Read: PGResourceFunc(resourcePostgreSQLScriptRead),
UpdateContext: PGResourceContextFunc(resourcePostgreSQLScriptCreateOrUpdate),
Delete: PGResourceFunc(resourcePostgreSQLScriptDelete),

Schema: map[string]*schema.Schema{
scriptCommandsAttr: {
Expand All @@ -45,6 +48,12 @@ func resourcePostgreSQLScript() *schema.Resource {
Default: 1,
Description: "Number of seconds between two tries of the batch of commands",
},
scriptTimeoutAttr: {
Type: schema.TypeInt,
Optional: true,
Default: 5 * 60,
Description: "Number of seconds for a batch of command to timeout",
},
scriptShasumAttr: {
Type: schema.TypeString,
Computed: true,
Expand All @@ -54,19 +63,28 @@ func resourcePostgreSQLScript() *schema.Resource {
}
}

func resourcePostgreSQLScriptCreateOrUpdate(db *DBConnection, d *schema.ResourceData) error {
func resourcePostgreSQLScriptCreateOrUpdate(ctx context.Context, db *DBConnection, d *schema.ResourceData) diag.Diagnostics {
commands, err := toStringArray(d.Get(scriptCommandsAttr).([]any))
tries := d.Get(scriptTriesAttr).(int)
backoffDelay := d.Get(scriptBackoffDelayAttr).(int)
timeout := d.Get(scriptTimeoutAttr).(int)

if err != nil {
return err
return diag.Diagnostics{diag.Diagnostic{
Severity: diag.Error,
Summary: "Commands input is not valid",
Detail: err.Error(),
}}
}

sum := shasumCommands(commands)

if err := executeCommands(db, commands, tries, backoffDelay); err != nil {
return err
if err := executeCommands(ctx, db, commands, tries, backoffDelay, timeout); err != nil {
return diag.Diagnostics{diag.Diagnostic{
Severity: diag.Error,
Summary: "Commands execution failed",
Detail: err.Error(),
}}
}

d.Set(scriptShasumAttr, sum)
Expand All @@ -89,9 +107,9 @@ func resourcePostgreSQLScriptDelete(db *DBConnection, d *schema.ResourceData) er
return nil
}

func executeCommands(db *DBConnection, commands []string, tries int, backoffDelay int) error {
func executeCommands(ctx context.Context, db *DBConnection, commands []string, tries int, backoffDelay int, timeout int) error {
for try := 1; ; try++ {
err := executeBatch(db, commands)
err := executeBatch(ctx, db, commands, timeout)
if err == nil {
return nil
} else {
Expand All @@ -103,10 +121,12 @@ func executeCommands(db *DBConnection, commands []string, tries int, backoffDela
}
}

func executeBatch(db *DBConnection, commands []string) error {
func executeBatch(ctx context.Context, db *DBConnection, commands []string, timeout int) error {
timeoutContext, timeoutCancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
defer timeoutCancel()
for _, command := range commands {
log.Printf("[DEBUG] Executing %s", command)
_, err := db.Exec(command)
_, err := db.ExecContext(timeoutContext, command)
log.Printf("[DEBUG] Result %s: %v", command, err)
if err != nil {
log.Println("[DEBUG] Error catched:", err)
Expand Down
24 changes: 24 additions & 0 deletions postgresql/resource_postgresql_script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,27 @@ func TestAccPostgresqlScript_failMultiple(t *testing.T) {
},
})
}

func TestAccPostgresqlScript_timeout(t *testing.T) {
config := `
resource "postgresql_script" "invalid" {
commands = [
"BEGIN",
"SELECT pg_sleep(2);",
"COMMIT"
]
timeout = 1
}
`

resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
Steps: []resource.TestStep{
{
Config: config,
ExpectError: regexp.MustCompile("canceling statement"),
},
},
})
}
Loading