|
1 | 1 | package main |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "database/sql" |
| 4 | + "context" |
5 | 5 | "fmt" |
6 | | - "net/url" |
7 | 6 |
|
8 | | - _ "github.com/jackc/pgx/v5/stdlib" |
| 7 | + "github.com/jackc/pgx/v5" |
| 8 | + "github.com/jackc/pgx/v5/pgxpool" |
9 | 9 | "github.com/spf13/cobra" |
10 | 10 | ) |
11 | 11 |
|
@@ -39,42 +39,43 @@ func runReset(cmd *cobra.Command, args []string) error { |
39 | 39 | return err |
40 | 40 | } |
41 | 41 |
|
42 | | - // Parse the URL to get database name |
43 | | - parsedURL, err := url.Parse(dbURL) |
44 | | - if err != nil { |
45 | | - return fmt.Errorf("invalid database URL: %w", err) |
46 | | - } |
| 42 | + ctx := context.Background() |
47 | 43 |
|
48 | | - // Extract database name from path |
49 | | - dbName := parsedURL.Path |
50 | | - if len(dbName) > 0 && dbName[0] == '/' { |
51 | | - dbName = dbName[1:] // Remove leading slash |
| 44 | + // Parse the connection string using pgxpool.ParseConfig which handles both URL and key-value formats |
| 45 | + config, err := pgxpool.ParseConfig(dbURL) |
| 46 | + if err != nil { |
| 47 | + return fmt.Errorf("failed to parse database URL: %w", err) |
52 | 48 | } |
53 | 49 |
|
| 50 | + // Get the database name from the config |
| 51 | + dbName := config.ConnConfig.Database |
54 | 52 | if dbName == "" { |
55 | | - return fmt.Errorf("database name is required in URL") |
| 53 | + return fmt.Errorf("database name not found in connection string") |
56 | 54 | } |
57 | 55 |
|
58 | | - // Connect to postgres database to drop and recreate the system database |
59 | | - parsedURL.Path = "/postgres" |
60 | | - postgresURL := parsedURL.String() |
| 56 | + // Create a connection configuration pointing to the postgres database |
| 57 | + postgresConfig := config.ConnConfig.Copy() |
| 58 | + postgresConfig.Database = "postgres" |
61 | 59 |
|
62 | | - db, err := sql.Open("pgx", postgresURL) |
| 60 | + // Connect to the postgres database |
| 61 | + conn, err := pgx.ConnectConfig(ctx, postgresConfig) |
63 | 62 | if err != nil { |
64 | | - return fmt.Errorf("failed to connect to postgres database: %w", err) |
| 63 | + return fmt.Errorf("failed to connect to PostgreSQL server: %w", err) |
65 | 64 | } |
66 | | - defer db.Close() |
| 65 | + defer conn.Close(ctx) |
67 | 66 |
|
68 | 67 | // Drop the system database if it exists |
69 | 68 | logger.Info("Resetting system database", "database", dbName) |
70 | | - dropQuery := fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName) |
71 | | - if _, err := db.Exec(dropQuery); err != nil { |
| 69 | + dropSQL := fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", pgx.Identifier{dbName}.Sanitize()) |
| 70 | + _, err = conn.Exec(ctx, dropSQL) |
| 71 | + if err != nil { |
72 | 72 | return fmt.Errorf("failed to drop system database: %w", err) |
73 | 73 | } |
74 | 74 |
|
75 | 75 | // Create the database |
76 | | - createQuery := fmt.Sprintf("CREATE DATABASE %s", dbName) |
77 | | - if _, err := db.Exec(createQuery); err != nil { |
| 76 | + createSQL := fmt.Sprintf("CREATE DATABASE %s", pgx.Identifier{dbName}.Sanitize()) |
| 77 | + _, err = conn.Exec(ctx, createSQL) |
| 78 | + if err != nil { |
78 | 79 | return fmt.Errorf("failed to create system database: %w", err) |
79 | 80 | } |
80 | 81 |
|
|
0 commit comments