55 "fmt"
66 "log"
77 "os"
8+ "path/filepath"
89 "sync"
910 "time"
1011
@@ -139,9 +140,11 @@ func executeDumpCommand() {
139140func executeRestoreCommand () {
140141 restoreCmd := flag .NewFlagSet ("restore" , flag .ExitOnError )
141142 // Flags to connect to the DB Server and load the dump file
143+ // `dumpFile` here can have two values, if the `s3Dump` flag is true (default) the value is the S3 bucket URI
144+ // otherwise is the local system path.
142145 var pgHost , pgUser , pgPassword , dbName , dbType , dumpFile string
143146 var pgPort , numCPUCores int
144- var localDump bool
147+ var s3Dump bool
145148
146149 restoreCmd .StringVar (& pgHost , "h" , "127.0.0.1" , "PostgreSQL host" )
147150
@@ -153,13 +156,13 @@ func executeRestoreCommand() {
153156
154157 restoreCmd .StringVar (& dbType , "t" , "postgres" , "The type of database to restore (postgres,mysql,etc...)" )
155158
156- restoreCmd .StringVar (& dumpFile , "f" , "" , "The absolute path of the dump file to restore the DB from" )
159+ restoreCmd .StringVar (& dumpFile , "f" , "" , "The absolute path of the dump file to restore the DB from or the S3 URI " )
157160
158161 restoreCmd .IntVar (& pgPort , "p" , 5432 , "PostgreSQL port" )
159162
160163 restoreCmd .IntVar (& numCPUCores , "n" , 2 , "Number of parallel processes (1 per CPU Core) to use" )
161164
162- restoreCmd .BoolVar (& localDump , "s" , true , "Download the dump from AWS S3" )
165+ restoreCmd .BoolVar (& s3Dump , "s" , true , "Download the dump from AWS S3" )
163166
164167 err := restoreCmd .Parse (os .Args [2 :])
165168 if err != nil {
@@ -172,22 +175,28 @@ func executeRestoreCommand() {
172175 os .Exit (1 )
173176 }
174177
178+ downloader , err := dump .NewS3Downloader (s3Dump )
179+ if err != nil {
180+ fmt .Println ("Error initializing S3 downloader:" , err )
181+ os .Exit (1 )
182+ }
183+
175184 var dbrestorer dump.Restorer = dump .NewPostgresRestorer (pgHost , pgPort , numCPUCores , pgUser , pgPassword , dumpFile , dbName )
176185
177186 var wg sync.WaitGroup
178187
179188 fmt .Println ("\n Initalizing restore process in background..." )
180189 fmt .Printf ("-> Spawning restore task for '%s'.\n " , * & dbName )
181190 wg .Add (1 )
182- go processDatabaseRestore (& wg , * & dbName , dbrestorer , * & dumpFile )
191+ go processDatabaseRestore (& wg , * & dbName , dbrestorer , dumpFile , s3Dump , downloader )
183192 fmt .Println ("\n All restore processes have been started. Waiting for them to complete." )
184193 wg .Wait ()
185194 fmt .Println ("Check the 'restore_log_*.log' file for progress and results." )
186195}
187196
188197// processDatabaseRestore is used to restore a DB using the dump file passed as parameter.
189198// It can also take the number of processes to execute in parallel based on the number of available CPUs core
190- func processDatabaseRestore (wg * sync.WaitGroup , dbName string , dbrestorer dump.Restorer , dumpFileName string ) {
199+ func processDatabaseRestore (wg * sync.WaitGroup , dbName string , dbrestorer dump.Restorer , dumpFileName string , s3Download bool , downloader dump. Downloader ) {
191200 defer wg .Done ()
192201
193202 logFilename := fmt .Sprintf ("restore_log_%s_%s.log" , dbName , time .Now ().UTC ().Format ("20060102_150405" ))
@@ -207,7 +216,26 @@ func processDatabaseRestore(wg *sync.WaitGroup, dbName string, dbrestorer dump.R
207216 logger := log .New (logFile , "" , log .LstdFlags )
208217 logger .Printf ("Starting restore process for '%s' using '%s' dump..." , dbName , dumpFileName )
209218
210- logger .Println ("1. Restoring database..." )
219+ // We need to manage the restore based on the remote S3 bucket or the local file path
220+ var local_dump_file string
221+ if s3Download {
222+ local_dump_file = filepath .Base (dumpFileName )
223+ logger .Println (" Downloading dump from S3..." )
224+ s3Uri , err := downloader .Download (dumpFileName , local_dump_file )
225+ if err != nil {
226+ logger .Printf ("Error downloading dump from S3: %v" , err )
227+ return
228+ }
229+ logger .Println ("Download successful." )
230+ logger .Printf ("File downloaded to: %s" , s3Uri )
231+ }
232+
233+ logger .Println ("Restoring database..." )
234+ if s3Download {
235+ dumpFileName = local_dump_file
236+ } else {
237+ dumpFileName = * & dumpFileName
238+ }
211239 if err := dbrestorer .Restore (dbName , dumpFileName ); err != nil {
212240 logger .Printf ("Error during restoring the DB '%s': %v" , dbName , err )
213241 return
0 commit comments