-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathreset_db.go
More file actions
162 lines (141 loc) · 4.19 KB
/
reset_db.go
File metadata and controls
162 lines (141 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package client
import (
	"context"
	"fmt"
	"log/slog"
	"net/url"
	"strings"

	"github.com/jackc/pgx/v5"
)
// ServicePortProvider is implemented by anything that can report the mapped
// port of a service.
//
// GetServicePort receives the service name and the service's internal port,
// and returns the mapped port as a string, or an error.
type ServicePortProvider interface {
	GetServicePort(ctx context.Context, serviceName, internalPort string) (string, error)
}
// DatabaseConfig carries the settings needed to open a database connection:
// the credentials (User, Password), the database name, and the server
// address (Host, Port). Every value, including the port, is a string.
type DatabaseConfig struct {
	User, Password string
	Database       string
	Host, Port     string
}
// DefaultDatabaseConfig returns the configuration used when the caller does
// not supply one: user, password and database name are all "infisical", and
// the server address is localhost:5432.
func DefaultDatabaseConfig() DatabaseConfig {
	const name = "infisical"

	var cfg DatabaseConfig
	cfg.User = name
	cfg.Password = name
	cfg.Database = name
	cfg.Host = "localhost"
	cfg.Port = "5432"
	return cfg
}
// ResetDBOptions bundles the settings that control a database reset.
type ResetDBOptions struct {
	// SkipTables is the set of tables to leave alone when truncating
	// (e.g., migrations).
	SkipTables map[string]struct{}

	// DBConfig is the connection configuration of the database to reset.
	DBConfig DatabaseConfig
}
// DefaultResetDBOptions returns the options used when the caller supplies
// none: the two migration bookkeeping tables are skipped, and the connection
// settings come from DefaultDatabaseConfig. Each call builds a fresh map.
func DefaultResetDBOptions() ResetDBOptions {
	skip := make(map[string]struct{}, 2)
	for _, name := range []string{
		"public.infisical_migrations",
		"public.infisical_migrations_lock",
	} {
		skip[name] = struct{}{}
	}

	var opts ResetDBOptions
	opts.SkipTables = skip
	opts.DBConfig = DefaultDatabaseConfig()
	return opts
}
// ResetDB resets the PostgreSQL database.
//
// It starts from DefaultResetDBOptions and applies every functional option in
// opts, in the order given, before running the reset. Nil options are ignored.
//
// On failure the underlying error is returned wrapped with the prefix
// "failed to reset PostgreSQL database", so it stays reachable through
// errors.Is and errors.As.
func ResetDB(ctx context.Context, opts ...func(*ResetDBOptions)) error {
	options := DefaultResetDBOptions()
	for _, opt := range opts {
		// Calling a nil option would panic; treat it as a no-op instead.
		if opt == nil {
			continue
		}
		opt(&options)
	}

	// Reset PostgreSQL database
	if err := resetPostgresDB(ctx, options); err != nil {
		return fmt.Errorf("failed to reset PostgreSQL database: %w", err)
	}
	return nil
}
// resetPostgresDB resets the PostgreSQL database by truncating all tables (except skipped ones)
// and inserting a default super_admin record.
//
// It connects using opts.DBConfig, lists every base table outside the
// pg_catalog and information_schema schemas, and truncates them together in a
// single TRUNCATE ... RESTART IDENTITY CASCADE statement. A table is left out
// when its plain "schema.table" name is a key of opts.SkipTables. Finally the
// default row is inserted into public.super_admin.
//
// Every failure is logged through slog and the error is returned unchanged.
func resetPostgresDB(ctx context.Context, opts ResetDBOptions) error {
	// Assemble the URL with net/url so that a user name, password or database
	// name containing reserved characters (such as '@', ':', '/' or '?') is
	// percent-encoded instead of corrupting the connection string.
	connURL := url.URL{
		Scheme: "postgresql",
		User:   url.UserPassword(opts.DBConfig.User, opts.DBConfig.Password),
		Host:   opts.DBConfig.Host + ":" + opts.DBConfig.Port,
		Path:   "/" + opts.DBConfig.Database,
	}

	conn, err := pgx.Connect(ctx, connURL.String())
	if err != nil {
		slog.Error("Unable to connect to database", "err", err)
		return err
	}
	defer conn.Close(ctx)

	// Get all tables
	query := `
		SELECT table_schema, table_name
		FROM information_schema.tables
		WHERE table_type = 'BASE TABLE'
		AND table_schema NOT IN ('pg_catalog', 'information_schema')
		ORDER BY table_schema, table_name;
	`
	rows, err := conn.Query(ctx, query)
	if err != nil {
		slog.Error("Unable to execute query", "query", query, "err", err)
		return err
	}
	defer rows.Close()

	// Collect the quoted names of the tables to truncate.
	var targets []string
	for rows.Next() {
		var schema, table string
		if err := rows.Scan(&schema, &table); err != nil {
			slog.Error("Scan failed", "error", err)
			return err
		}
		// SkipTables is keyed by the plain, unquoted "schema.table" name.
		if _, skip := opts.SkipTables[schema+"."+table]; skip {
			continue
		}
		// Quote the identifier so that names containing upper-case letters,
		// spaces or reserved words remain valid inside the TRUNCATE statement.
		targets = append(targets, pgx.Identifier{schema, table}.Sanitize())
	}
	if err := rows.Err(); err != nil {
		slog.Error("Row iteration error", "error", err)
		return err
	}

	// Nothing to truncate when every table was skipped or none exist.
	if len(targets) > 0 {
		// A single statement truncates all tables together.
		truncateQuery := "TRUNCATE TABLE " + strings.Join(targets, ", ") + " RESTART IDENTITY CASCADE"
		if _, err := conn.Exec(ctx, truncateQuery); err != nil {
			slog.Error("Truncate failed", "error", err)
			return err
		}
		slog.Info("Truncate all tables successfully")
	}

	// Insert default super_admin record
	_, err = conn.Exec(ctx,
		`INSERT INTO public.super_admin ("id", "fipsEnabled", "initialized", "allowSignUp") VALUES ($1, $2, $3, $4)`,
		"00000000-0000-0000-0000-000000000000", true, false, true)
	if err != nil {
		slog.Error("Failed to insert super_admin", "error", err)
		return err
	}
	return nil
}
// WithSkipTables returns an option that replaces ResetDBOptions.SkipTables
// with tables. The map is stored as given; it is not copied.
func WithSkipTables(tables map[string]struct{}) func(*ResetDBOptions) {
	apply := func(target *ResetDBOptions) {
		target.SkipTables = tables
	}
	return apply
}
// WithDatabaseConfig returns an option that replaces ResetDBOptions.DBConfig
// with config.
func WithDatabaseConfig(config DatabaseConfig) func(*ResetDBOptions) {
	apply := func(target *ResetDBOptions) {
		target.DBConfig = config
	}
	return apply
}