diff --git a/postgresql/resource_postgresql_script.go b/postgresql/resource_postgresql_script.go index bd4cad99..10779ca7 100644 --- a/postgresql/resource_postgresql_script.go +++ b/postgresql/resource_postgresql_script.go @@ -3,6 +3,7 @@ package postgresql import ( "crypto/sha1" "encoding/hex" + "fmt" "log" "time" @@ -54,10 +55,14 @@ func resourcePostgreSQLScript() *schema.Resource { } func resourcePostgreSQLScriptCreateOrUpdate(db *DBConnection, d *schema.ResourceData) error { - commands := d.Get(scriptCommandsAttr).([]any) + commands, err := toStringArray(d.Get(scriptCommandsAttr).([]any)) tries := d.Get(scriptTriesAttr).(int) backoffDelay := d.Get(scriptBackoffDelayAttr).(int) + if err != nil { + return err + } + sum := shasumCommands(commands) if err := executeCommands(db, commands, tries, backoffDelay); err != nil { @@ -70,7 +75,10 @@ func resourcePostgreSQLScriptCreateOrUpdate(db *DBConnection, d *schema.Resource } func resourcePostgreSQLScriptRead(db *DBConnection, d *schema.ResourceData) error { - commands := d.Get(scriptCommandsAttr).([]any) + commands, err := toStringArray(d.Get(scriptCommandsAttr).([]any)) + if err != nil { + return err + } newSum := shasumCommands(commands) d.Set(scriptShasumAttr, newSum) @@ -81,7 +89,7 @@ func resourcePostgreSQLScriptDelete(db *DBConnection, d *schema.ResourceData) er return nil } -func executeCommands(db *DBConnection, commands []any, tries int, backoffDelay int) error { +func executeCommands(db *DBConnection, commands []string, tries int, backoffDelay int) error { for try := 1; ; try++ { err := executeBatch(db, commands) if err == nil { @@ -95,15 +103,15 @@ func executeCommands(db *DBConnection, commands []any, tries int, backoffDelay i } } -func executeBatch(db *DBConnection, commands []any) error { +func executeBatch(db *DBConnection, commands []string) error { for _, command := range commands { - log.Printf("[ERROR] Executing %s", command.(string)) - _, err := db.Query(command.(string)) + log.Printf("[DEBUG] Executing %s", command) + _, err := db.Query(command) if err != nil { - log.Println("[ERROR] Error catched:", err) + log.Println("[DEBUG] Error catched:", err) if _, rollbackError := db.Query("ROLLBACK"); rollbackError != nil { - log.Println("[ERROR] Rollback raised an error:", rollbackError) + log.Println("[DEBUG] Rollback raised an error:", rollbackError) } return err } @@ -111,10 +119,22 @@ func executeBatch(db *DBConnection, commands []any) error { return nil } -func shasumCommands(commands []any) string { +func shasumCommands(commands []string) string { sha := sha1.New() for _, command := range commands { - sha.Write([]byte(command.(string))) + sha.Write([]byte(command)) } return hex.EncodeToString(sha.Sum(nil)) } + +func toStringArray(array []any) ([]string, error) { + strings := make([]string, 0, len(array)) + for _, elem := range array { + str, ok := elem.(string) + if !ok { + return nil, fmt.Errorf("element %v is not a string", elem) + } + strings = append(strings, str) + } + return strings, nil +} diff --git a/postgresql/resource_postgresql_script_test.go b/postgresql/resource_postgresql_script_test.go index 24660f0d..378eb56e 100644 --- a/postgresql/resource_postgresql_script_test.go +++ b/postgresql/resource_postgresql_script_test.go @@ -133,6 +133,29 @@ func TestAccPostgresqlScript_reapply(t *testing.T) { }) } +func TestAccPostgresqlScript_invalid(t *testing.T) { + config := ` + resource "postgresql_script" "invalid" { + commands = [ + "" + ] + tries = 2 + backoff_delay = 2 + } + ` + + resource.Test(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + Steps: []resource.TestStep{ + { + Config: config, + ExpectError: regexp.MustCompile("element is not a string"), + }, + }, + }) +} + func TestAccPostgresqlScript_fail(t *testing.T) { config := ` resource "postgresql_script" "invalid" {