Skip to content

Commit ab2f4fb

Browse files
feat: Allow to set a database to connect to in a script (#43)
## Add database parameter to postgresql_script resource ### Summary Adds an optional `database` parameter to `postgresql_script` to specify which database to execute commands in. Defaults to the provider's configured database for backwards compatibility. ### Changes - Added `database` schema field (optional, computed) - Updated connection logic to use `getDatabase()` helper and create new connection if needed - Refactored to call `ReadImpl()` at end of Create/Update (matches pattern used by `postgresql_schema`) ### Usage ```hcl resource "postgresql_script" "app_schema" { database = "myapp" commands = ["CREATE TABLE users (id INT)"] } ``` ### Breaking Changes None - fully backwards compatible. ### FOR REVIEWERS I did not close explicitely the connection because tests was failling due to ` Error: Commands execution failed .... sql: database is closed` It seems that for a DB the provider need to be able to reuse a connection pool that is cached for further operation. So if we close the connection we have issue for further connections
1 parent 0bda91c commit ab2f4fb

File tree

2 files changed

+172
-3
lines changed

2 files changed

+172
-3
lines changed

postgresql/resource_postgresql_script.go

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
const (
1616
scriptCommandsAttr = "commands"
17+
scriptDatabaseAttr = "database"
1718
scriptTriesAttr = "tries"
1819
scriptBackoffDelayAttr = "backoff_delay"
1920
scriptTimeoutAttr = "timeout"
@@ -28,6 +29,13 @@ func resourcePostgreSQLScript() *schema.Resource {
2829
Delete: PGResourceFunc(resourcePostgreSQLScriptDelete),
2930

3031
Schema: map[string]*schema.Schema{
32+
scriptDatabaseAttr: {
33+
Type: schema.TypeString,
34+
Optional: true,
35+
Computed: true,
36+
ForceNew: true,
37+
Description: "The database to execute commands in (defaults to provider's configured database)",
38+
},
3139
scriptCommandsAttr: {
3240
Type: schema.TypeList,
3341
Required: true,
@@ -77,28 +85,64 @@ func resourcePostgreSQLScriptCreateOrUpdate(ctx context.Context, db *DBConnectio
7785
}}
7886
}
7987

80-
sum := shasumCommands(commands)
88+
// Get the target database connection
89+
database := getDatabaseAttrOrDefault(d, db.client.databaseName)
90+
91+
client := db.client.config.NewClient(database)
92+
newDB, err := client.Connect()
93+
if err != nil {
94+
return diag.Diagnostics{diag.Diagnostic{
95+
Severity: diag.Error,
96+
Summary: "Failed to connect to database",
97+
Detail: err.Error(),
98+
}}
99+
}
81100

82-
if err := executeCommands(ctx, db, commands, tries, backoffDelay, timeout); err != nil {
101+
if err := executeCommands(ctx, newDB, commands, tries, backoffDelay, timeout); err != nil {
83102
return diag.Diagnostics{diag.Diagnostic{
84103
Severity: diag.Error,
85104
Summary: "Commands execution failed",
86105
Detail: err.Error(),
87106
}}
88107
}
89108

90-
d.Set(scriptShasumAttr, sum)
109+
sum := shasumCommands(commands)
91110
d.SetId(sum)
111+
112+
if err := resourcePostgreSQLScriptReadImpl(db, d); err != nil {
113+
return diag.Diagnostics{diag.Diagnostic{
114+
Severity: diag.Error,
115+
Summary: "Failed to read script state",
116+
Detail: err.Error(),
117+
}}
118+
}
119+
92120
return nil
93121
}
94122

123+
func getDatabaseAttrOrDefault(d *schema.ResourceData, databaseName string) string {
124+
if v, ok := d.GetOk(scriptDatabaseAttr); ok {
125+
databaseName = v.(string)
126+
}
127+
128+
return databaseName
129+
}
130+
95131
func resourcePostgreSQLScriptRead(db *DBConnection, d *schema.ResourceData) error {
132+
return resourcePostgreSQLScriptReadImpl(db, d)
133+
}
134+
135+
func resourcePostgreSQLScriptReadImpl(db *DBConnection, d *schema.ResourceData) error {
96136
commands, err := toStringArray(d.Get(scriptCommandsAttr).([]any))
97137
if err != nil {
98138
return err
99139
}
100140
newSum := shasumCommands(commands)
141+
142+
database := getDatabaseAttrOrDefault(d, db.client.databaseName)
143+
101144
d.Set(scriptShasumAttr, newSum)
145+
d.Set(scriptDatabaseAttr, database)
102146

103147
return nil
104148
}

postgresql/resource_postgresql_script_test.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package postgresql
22

33
import (
4+
"fmt"
45
"regexp"
56
"testing"
67

78
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
9+
"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
810
)
911

1012
func TestAccPostgresqlScript_basic(t *testing.T) {
@@ -227,3 +229,126 @@ func TestAccPostgresqlScript_timeout(t *testing.T) {
227229
},
228230
})
229231
}
232+
233+
func TestAccPostgresqlScript_withDatabase(t *testing.T) {
234+
config := `
235+
resource "postgresql_database" "test_db" {
236+
name = "test_script_db"
237+
}
238+
239+
resource "postgresql_script" "test" {
240+
database = postgresql_database.test_db.name
241+
commands = [
242+
"CREATE TABLE test_table (id INT);",
243+
"INSERT INTO test_table VALUES (1);"
244+
]
245+
depends_on = [postgresql_database.test_db]
246+
}
247+
248+
resource "postgresql_script" "test_default" {
249+
commands = [
250+
"CREATE TABLE default_db_table (id INT);",
251+
"INSERT INTO default_db_table VALUES (1);",
252+
"INSERT INTO default_db_table VALUES (2);"
253+
]
254+
depends_on = [postgresql_database.test_db]
255+
}
256+
`
257+
258+
resource.Test(t, resource.TestCase{
259+
PreCheck: func() { testAccPreCheck(t) },
260+
Providers: testAccProviders,
261+
CheckDestroy: testAccCheckScriptTablesDestroyed,
262+
Steps: []resource.TestStep{
263+
{
264+
Config: config,
265+
Check: resource.ComposeTestCheckFunc(
266+
resource.TestCheckResourceAttr("postgresql_script.test", "database", "test_script_db"),
267+
resource.TestCheckResourceAttr("postgresql_script.test", "commands.0", "CREATE TABLE test_table (id INT);"),
268+
resource.TestCheckResourceAttr("postgresql_script.test", "commands.1", "INSERT INTO test_table VALUES (1);"),
269+
resource.TestCheckResourceAttr("postgresql_script.test_default", "database", "postgres"),
270+
resource.TestCheckResourceAttr("postgresql_script.test_default", "commands.0", "CREATE TABLE default_db_table (id INT);"),
271+
resource.TestCheckResourceAttr("postgresql_script.test_default", "commands.1", "INSERT INTO default_db_table VALUES (1);"),
272+
resource.TestCheckResourceAttr("postgresql_script.test_default", "commands.2", "INSERT INTO default_db_table VALUES (2);"),
273+
testAccCheckTableExistsInDatabase("test_script_db", "test_table"),
274+
testAccCheckTableHasRecords("test_script_db", "test_table", 1),
275+
testAccCheckTableExistsInDatabase("postgres", "default_db_table"),
276+
testAccCheckTableHasRecords("postgres", "default_db_table", 2),
277+
),
278+
},
279+
},
280+
})
281+
}
282+
283+
func testAccCheckScriptTablesDestroyed(s *terraform.State) error {
284+
return testAccDropTables(map[string][]string{
285+
"test_script_db": {"test_table"},
286+
"postgres": {"default_db_table"},
287+
})
288+
}
289+
290+
func testAccDropTables(tablesToDrop map[string][]string) error {
291+
client := testAccProvider.Meta().(*Client)
292+
293+
for dbName, tables := range tablesToDrop {
294+
dbClient := client.config.NewClient(dbName)
295+
db, err := dbClient.Connect()
296+
if err != nil {
297+
continue // Skip if we can't connect to the database
298+
}
299+
300+
for _, tableName := range tables {
301+
_, _ = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName))
302+
}
303+
}
304+
305+
return nil
306+
}
307+
308+
func testAccCheckTableExistsInDatabase(dbName, tableName string) resource.TestCheckFunc {
309+
return func(s *terraform.State) error {
310+
client := testAccProvider.Meta().(*Client)
311+
dbClient := client.config.NewClient(dbName)
312+
db, err := dbClient.Connect()
313+
if err != nil {
314+
return fmt.Errorf("Error connecting to database %s: %s", dbName, err)
315+
}
316+
317+
var exists bool
318+
query := "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1)"
319+
err = db.QueryRow(query, tableName).Scan(&exists)
320+
if err != nil {
321+
return fmt.Errorf("Error checking if table %s exists: %s", tableName, err)
322+
}
323+
324+
if !exists {
325+
return fmt.Errorf("Table %s does not exist in database %s", tableName, dbName)
326+
}
327+
328+
return nil
329+
}
330+
}
331+
332+
func testAccCheckTableHasRecords(dbName, tableName string, expectedCount int) resource.TestCheckFunc {
333+
return func(s *terraform.State) error {
334+
client := testAccProvider.Meta().(*Client)
335+
dbClient := client.config.NewClient(dbName)
336+
db, err := dbClient.Connect()
337+
if err != nil {
338+
return fmt.Errorf("Error connecting to database %s: %s", dbName, err)
339+
}
340+
341+
var count int
342+
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName)
343+
err = db.QueryRow(query).Scan(&count)
344+
if err != nil {
345+
return fmt.Errorf("Error counting records in table %s: %s", tableName, err)
346+
}
347+
348+
if count != expectedCount {
349+
return fmt.Errorf("Expected %d records but got %d in table %s", expectedCount, count, tableName)
350+
}
351+
352+
return nil
353+
}
354+
}

0 commit comments

Comments
 (0)