Skip to content

Commit 936f7a3

Browse files
feat: improve control over Env Vars required by the cli for AWS
1 parent f46840d commit 936f7a3

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

cmd/dbdump/main.go

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"log"
77
"os"
8+
"path/filepath"
89
"sync"
910
"time"
1011

@@ -139,9 +140,11 @@ func executeDumpCommand() {
139140
func executeRestoreCommand() {
140141
restoreCmd := flag.NewFlagSet("restore", flag.ExitOnError)
141142
// Flags to connect to the DB Server and load the dump file
143+
// `dumpFile` here can have two values, if the `s3Dump` flag is true (default) the value is the S3 bucket URI
144+
// otherwise is the local system path.
142145
var pgHost, pgUser, pgPassword, dbName, dbType, dumpFile string
143146
var pgPort, numCPUCores int
144-
var localDump bool
147+
var s3Dump bool
145148

146149
restoreCmd.StringVar(&pgHost, "h", "127.0.0.1", "PostgreSQL host")
147150

@@ -153,13 +156,13 @@ func executeRestoreCommand() {
153156

154157
restoreCmd.StringVar(&dbType, "t", "postgres", "The type of database to restore (postgres,mysql,etc...)")
155158

156-
restoreCmd.StringVar(&dumpFile, "f", "", "The absolute path of the dump file to restore the DB from")
159+
restoreCmd.StringVar(&dumpFile, "f", "", "The absolute path of the dump file to restore the DB from or the S3 URI")
157160

158161
restoreCmd.IntVar(&pgPort, "p", 5432, "PostgreSQL port")
159162

160163
restoreCmd.IntVar(&numCPUCores, "n", 2, "Number of parallel processes (1 per CPU Core) to use")
161164

162-
restoreCmd.BoolVar(&localDump, "s", true, "Download the dump from AWS S3")
165+
restoreCmd.BoolVar(&s3Dump, "s", true, "Download the dump from AWS S3")
163166

164167
err := restoreCmd.Parse(os.Args[2:])
165168
if err != nil {
@@ -172,22 +175,28 @@ func executeRestoreCommand() {
172175
os.Exit(1)
173176
}
174177

178+
downloader, err := dump.NewS3Downloader(s3Dump)
179+
if err != nil {
180+
fmt.Println("Error initializing S3 downloader:", err)
181+
os.Exit(1)
182+
}
183+
175184
var dbrestorer dump.Restorer = dump.NewPostgresRestorer(pgHost, pgPort, numCPUCores, pgUser, pgPassword, dumpFile, dbName)
176185

177186
var wg sync.WaitGroup
178187

179188
fmt.Println("\nInitalizing restore process in background...")
180189
fmt.Printf("-> Spawning restore task for '%s'.\n", *&dbName)
181190
wg.Add(1)
182-
go processDatabaseRestore(&wg, *&dbName, dbrestorer, *&dumpFile)
191+
go processDatabaseRestore(&wg, *&dbName, dbrestorer, dumpFile, s3Dump, downloader)
183192
fmt.Println("\nAll restore processes have been started. Waiting for them to complete.")
184193
wg.Wait()
185194
fmt.Println("Check the 'restore_log_*.log' file for progress and results.")
186195
}
187196

188197
// processDatabaseRestore is used to restore a DB using the dump file passed as parameter.
189198
// It can also take the number of processes to execute in parallel based on the number of available CPUs core
190-
func processDatabaseRestore(wg *sync.WaitGroup, dbName string, dbrestorer dump.Restorer, dumpFileName string) {
199+
func processDatabaseRestore(wg *sync.WaitGroup, dbName string, dbrestorer dump.Restorer, dumpFileName string, s3Download bool, downloader dump.Downloader) {
191200
defer wg.Done()
192201

193202
logFilename := fmt.Sprintf("restore_log_%s_%s.log", dbName, time.Now().UTC().Format("20060102_150405"))
@@ -207,7 +216,26 @@ func processDatabaseRestore(wg *sync.WaitGroup, dbName string, dbrestorer dump.R
207216
logger := log.New(logFile, "", log.LstdFlags)
208217
logger.Printf("Starting restore process for '%s' using '%s' dump...", dbName, dumpFileName)
209218

210-
logger.Println("1. Restoring database...")
219+
// We need to manage the restore based on the remote S3 bucket or the local file path
220+
var local_dump_file string
221+
if s3Download {
222+
local_dump_file = filepath.Base(dumpFileName)
223+
logger.Println(" Downloading dump from S3...")
224+
s3Uri, err := downloader.Download(dumpFileName, local_dump_file)
225+
if err != nil {
226+
logger.Printf("Error downloading dump from S3: %v", err)
227+
return
228+
}
229+
logger.Println("Download successful.")
230+
logger.Printf("File downloaded to: %s", s3Uri)
231+
}
232+
233+
logger.Println("Restoring database...")
234+
if s3Download {
235+
dumpFileName = local_dump_file
236+
} else {
237+
dumpFileName = *&dumpFileName
238+
}
211239
if err := dbrestorer.Restore(dbName, dumpFileName); err != nil {
212240
logger.Printf("Error during restoring the DB '%s': %v", dbName, err)
213241
return

0 commit comments

Comments
 (0)