-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.go
More file actions
215 lines (168 loc) · 5.46 KB
/
main.go
File metadata and controls
215 lines (168 loc) · 5.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
package main
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"flag"
"fmt"
"github.com/lib/pq"
"log"
"os"
"path/filepath"
"strconv"
)
type (
	// migration is one migration script found on disk: the file name (which is
	// also the value recorded in the tracking table) and the file's SQL text.
	migration struct {
		name string
		sql  string
	}

	// config holds the command-line settings used to reach the database and to
	// locate and record migrations.
	config struct {
		host string
		port int

		user, password, database string

		// driver selects the SQL driver; getDBConnector only accepts "postgres".
		driver string
		// directory is scanned for migration files.
		directory string
		// table names the migration tracking table in the public schema.
		table string
	}
)
// Main function to run pending migrations
func main() {
config := parseFlags()
log.Println("Phase 1: Connecting to database")
db, err := connectDB(config)
if err != nil {
log.Fatal(err)
}
log.Println("Phase 2: Getting new migrations")
migrations, err := getNewMigrations(db, config.table, config.directory)
if err != nil {
log.Fatal(err)
}
log.Println("Phase 2: Found", len(migrations), "new migration(s)")
log.Println("Phase 3: Applying migrations")
err = runMigrations(db, config.table, migrations)
if err != nil {
log.Fatal(err)
}
log.Println("Phase 3: Migrations complete")
}
// parseFlags registers the command-line flags on the default flag set, parses
// the process arguments, and returns the resulting settings. Each flag writes
// straight into its config field, so no intermediate pointers are kept.
func parseFlags() *config {
	cfg := new(config)
	flag.StringVar(&cfg.host, "host", "localhost", "Database host")
	flag.IntVar(&cfg.port, "port", 5432, "Database port")
	flag.StringVar(&cfg.user, "user", "gomi", "Database username")
	flag.StringVar(&cfg.password, "password", "gomi", "Database password")
	flag.StringVar(&cfg.database, "database", "gomi", "Database name")
	flag.StringVar(&cfg.driver, "driver", "postgres", "Database SQL driver")
	flag.StringVar(&cfg.directory, "directory", "./migrations", "Directory containing migration files")
	flag.StringVar(&cfg.table, "table", "_migration", "Table to store migration history")
	flag.Parse()
	return cfg
}
// getDBConnector builds a driver.Connector for the configured SQL driver.
// Only "postgres" (github.com/lib/pq) is supported; any other driver name
// yields an error.
//
// Every DSN value is single-quoted and escaped, so settings containing spaces,
// quotes or backslashes (typically passwords) cannot break the key/value
// connection string or smuggle in extra connection parameters.
func getDBConnector(config *config) (driver.Connector, error) {
	if config.driver != "postgres" {
		return nil, fmt.Errorf("Error database driver not supported: '%s'", config.driver)
	}
	dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
		quoteDSNValue(config.host),
		quoteDSNValue(strconv.Itoa(config.port)),
		quoteDSNValue(config.user),
		quoteDSNValue(config.password),
		quoteDSNValue(config.database))
	return pq.NewConnector(dsn)
}

// quoteDSNValue wraps v in single quotes for a libpq-style key/value DSN and
// backslash-escapes any single quote or backslash inside it. Working byte by
// byte is safe for UTF-8 input because both escaped characters are ASCII and
// never occur inside a multi-byte sequence.
func quoteDSNValue(v string) string {
	out := make([]byte, 0, len(v)+2)
	out = append(out, '\'')
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\\' || c == '\'' {
			out = append(out, '\\')
		}
		out = append(out, v[i])
	}
	return string(append(out, '\''))
}
// connectDB opens a *sql.DB for config through getDBConnector and verifies the
// connection with a ping. If the ping fails, the handle is closed before
// returning, because the caller never receives it and could not release its
// connection pool otherwise.
func connectDB(config *config) (*sql.DB, error) {
	connector, err := getDBConnector(config)
	if err != nil {
		return nil, err
	}
	db := sql.OpenDB(connector)
	if err := db.Ping(); err != nil {
		// A Close failure is reported alongside, not instead of, the ping error;
		// errors.Join(err, nil) yields exactly err's message.
		return nil, fmt.Errorf("Error pinging database: %w", errors.Join(err, db.Close()))
	}
	return db, nil
}
// runMigrations applies migrations in order inside a single transaction and
// records each applied file name in public.<table>. Either every migration is
// applied and recorded, or, on the first failure, the transaction is rolled
// back and nothing is committed. An empty slice is a no-op that does not touch
// the database.
//
// The table name is interpolated into the SQL text unquoted (as
// getAppliedMigrations also does), so it must come from a trusted source such
// as the command line.
func runMigrations(db *sql.DB, table string, migrations []migration) error {
	if len(migrations) == 0 {
		return nil
	}
	tx, err := db.BeginTx(context.TODO(), nil)
	if err != nil {
		return fmt.Errorf("Error starting transaction: %w", err)
	}
	// The tracking insert is identical for every migration; build it once.
	insert := fmt.Sprintf("INSERT INTO public.%s (name) VALUES ($1);", table)
	for _, m := range migrations {
		log.Println("Applying migration:", m.name)
		if _, err := tx.Exec(m.sql); err != nil {
			// Surface a rollback failure too, rather than dropping it.
			return fmt.Errorf("Error executing migration: %w", errors.Join(err, tx.Rollback()))
		}
		if _, err := tx.Exec(insert, m.name); err != nil {
			return fmt.Errorf("Error inserting migration record: %w", errors.Join(err, tx.Rollback()))
		}
	}
	// A failed Commit ends the transaction, so no Rollback is needed here.
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("Error committing transaction: %w", err)
	}
	return nil
}
// getAppliedMigrations makes sure the tracking table public.<table> exists,
// creating it on first use, and returns the set of migration names already
// recorded there as a map whose values are always true.
func getAppliedMigrations(db *sql.DB, table string) (map[string]bool, error) {
	// TODO: Is there a better place to create the tracking table?
	const createTemplate = `
CREATE TABLE IF NOT EXISTS public.%s (
id SERIAL PRIMARY KEY,
name VARCHAR(255) NOT NULL,
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL);
`
	if _, execErr := db.Exec(fmt.Sprintf(createTemplate, table)); execErr != nil {
		return nil, fmt.Errorf("Error creating migration tracking table: %v", execErr)
	}

	rows, queryErr := db.Query(fmt.Sprintf("SELECT name FROM public.%s;", table))
	if queryErr != nil {
		return nil, fmt.Errorf("Error getting applied migrations: %v", queryErr)
	}
	defer rows.Close()

	seen := map[string]bool{}
	for rows.Next() {
		var recorded string
		if scanErr := rows.Scan(&recorded); scanErr != nil {
			return nil, fmt.Errorf("Error scanning applied migration: %v", scanErr)
		}
		seen[recorded] = true
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("Error reading applied migrations: %v", iterErr)
	}
	return seen, nil
}
// getNewMigrations reads the non-directory entries of directory and returns,
// as migrations, those whose file name is not yet recorded in the tracking
// table (via getAppliedMigrations). os.ReadDir returns entries sorted by file
// name, so migrations come back in lexical file-name order.
//
// Subdirectories and hidden files (names starting with "."), such as
// .DS_Store or .gitkeep, are skipped: they are not migrations, and executing
// one as SQL would abort the whole run.
func getNewMigrations(db *sql.DB, table string, directory string) ([]migration, error) {
	entries, err := os.ReadDir(directory)
	if err != nil {
		return nil, fmt.Errorf("Error reading directory: %w", err)
	}
	appliedMigrations, err := getAppliedMigrations(db, table)
	if err != nil {
		return nil, err
	}
	var migrations []migration
	for _, entry := range entries {
		// Directory entry names are never empty, so indexing name[0] is safe.
		name := entry.Name()
		if entry.IsDir() || name[0] == '.' || appliedMigrations[name] {
			continue // Skip directories, hidden files and previously applied migrations
		}
		filePath := filepath.Join(directory, name)
		content, err := os.ReadFile(filePath)
		if err != nil {
			return nil, fmt.Errorf("Error reading file: '%s': %w", filePath, err)
		}
		migrations = append(migrations, migration{name: name, sql: string(content)})
	}
	return migrations, nil
}