Skip to content

Commit e59366a

Browse files
committed
refactor database_backup and include compression
1 parent 97f756a commit e59366a

File tree

2 files changed

+138
-40
lines changed

2 files changed

+138
-40
lines changed

sde_collections/management/commands/database_backup.py

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
44
Usage:
55
docker-compose -f local.yml run --rm django python manage.py database_backup
6+
docker-compose -f local.yml run --rm django python manage.py database_backup --no-compress
67
docker-compose -f production.yml run --rm django python manage.py database_backup
78
"""
89

910
import enum
11+
import gzip
1012
import os
13+
import shutil
1114
import socket
1215
import subprocess
16+
from contextlib import contextmanager
1317
from datetime import datetime
1418

1519
from django.conf import settings
@@ -24,34 +28,85 @@ class Server(enum.Enum):
2428

2529
def detect_server() -> Server:
2630
hostname = socket.gethostname().upper()
27-
2831
if "PRODUCTION" in hostname:
2932
return Server.PRODUCTION
3033
elif "STAGING" in hostname:
3134
return Server.STAGING
3235
return Server.UNKNOWN
3336

3437

38+
@contextmanager
39+
def temp_file_handler(filename: str):
40+
"""Context manager to handle temporary files, ensuring cleanup."""
41+
try:
42+
yield filename
43+
finally:
44+
if os.path.exists(filename):
45+
os.remove(filename)
46+
47+
3548
class Command(BaseCommand):
3649
help = "Creates a PostgreSQL backup using pg_dump"
3750

38-
def handle(self, *args, **options):
39-
server = detect_server()
51+
def add_arguments(self, parser):
52+
parser.add_argument(
53+
"--no-compress",
54+
action="store_true",
55+
help="Disable backup file compression (enabled by default)",
56+
)
57+
58+
def get_backup_filename(self, server: Server, compress: bool) -> tuple[str, str]:
59+
"""Generate backup filename and actual dump path."""
4060
date_str = datetime.now().strftime("%Y%m%d")
41-
backup_file = f"{server.value.lower()}_backup_{date_str}.sql"
61+
base_name = f"{server.value.lower()}_backup_{date_str}.sql"
62+
return f"{base_name}.gz" if compress else base_name, base_name
4263

64+
def run_pg_dump(self, output_file: str, env: dict) -> None:
65+
"""Execute pg_dump with given parameters."""
4366
db_settings = settings.DATABASES["default"]
44-
host = db_settings["HOST"]
45-
name = db_settings["NAME"]
46-
user = db_settings["USER"]
47-
password = db_settings["PASSWORD"]
67+
cmd = [
68+
"pg_dump",
69+
"-h",
70+
db_settings["HOST"],
71+
"-U",
72+
db_settings["USER"],
73+
"-d",
74+
db_settings["NAME"],
75+
"--no-owner",
76+
"--no-privileges",
77+
"-f",
78+
output_file,
79+
]
80+
subprocess.run(cmd, env=env, check=True)
4881

49-
cmd = ["pg_dump", "-h", host, "-U", user, "-d", name, "--no-owner", "--no-privileges", "-f", backup_file]
82+
def compress_file(self, input_file: str, output_file: str) -> None:
83+
"""Compress input file to output file using gzip."""
84+
with open(input_file, "rb") as f_in:
85+
with gzip.open(output_file, "wb") as f_out:
86+
shutil.copyfileobj(f_in, f_out)
87+
88+
def handle(self, *args, **options):
89+
server = detect_server()
90+
compress = not options["no_compress"]
91+
backup_file, dump_file = self.get_backup_filename(server, compress)
92+
93+
env = os.environ.copy()
94+
env["PGPASSWORD"] = settings.DATABASES["default"]["PASSWORD"]
5095

5196
try:
52-
env = os.environ.copy()
53-
env["PGPASSWORD"] = password
54-
subprocess.run(cmd, env=env, check=True)
55-
self.stdout.write(self.style.SUCCESS(f"Successfully created backup for {server.value}: {backup_file}"))
97+
if compress:
98+
with temp_file_handler(dump_file):
99+
self.run_pg_dump(dump_file, env)
100+
self.compress_file(dump_file, backup_file)
101+
else:
102+
self.run_pg_dump(backup_file, env)
103+
104+
self.stdout.write(
105+
self.style.SUCCESS(
106+
f"Successfully created {'compressed ' if compress else ''}backup for {server.value}: {backup_file}"
107+
)
108+
)
56109
except subprocess.CalledProcessError as e:
57110
self.stdout.write(self.style.ERROR(f"Backup failed on {server.value}: {str(e)}"))
111+
except Exception as e:
112+
self.stdout.write(self.style.ERROR(f"Error during backup process: {str(e)}"))
Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
"""
2-
Management command to restore PostgreSQL database.
2+
Management command to restore PostgreSQL database from backup.
33
44
Usage:
5-
docker-compose -f local.yml run --rm django python manage.py database_restore path/to/backup.sql
6-
docker-compose -f production.yml run --rm django python manage.py database_restore path/to/backup.sql
5+
docker-compose -f local.yml run --rm django python manage.py database_restore path/to/backup.sql[.gz]
6+
docker-compose -f production.yml run --rm django python manage.py database_restore path/to/backup.sql[.gz]
77
"""
88

99
import enum
10+
import gzip
1011
import os
12+
import shutil
1113
import socket
1214
import subprocess
15+
from contextlib import contextmanager
1316

1417
from django.conf import settings
1518
from django.core.management.base import BaseCommand, CommandError
@@ -30,46 +33,86 @@ def detect_server() -> Server:
3033
return Server.UNKNOWN
3134

3235

36+
@contextmanager
37+
def temp_file_handler(filename: str):
38+
"""Context manager to handle temporary files, ensuring cleanup."""
39+
try:
40+
yield filename
41+
finally:
42+
if os.path.exists(filename):
43+
os.remove(filename)
44+
45+
3346
class Command(BaseCommand):
34-
help = "Restores PostgreSQL database from backup file"
47+
help = "Restores PostgreSQL database from backup file (compressed or uncompressed)"
3548

3649
def add_arguments(self, parser):
37-
parser.add_argument("backup_path", type=str, help="Path to the backup file")
50+
parser.add_argument("backup_path", type=str, help="Path to the backup file (.sql or .sql.gz)")
51+
52+
def get_db_settings(self):
53+
"""Get database connection settings."""
54+
db = settings.DATABASES["default"]
55+
return {
56+
"host": db["HOST"],
57+
"name": db["NAME"],
58+
"user": db["USER"],
59+
"password": db["PASSWORD"],
60+
}
61+
62+
def run_psql_command(self, command: str, db_name: str = "postgres", env: dict = None) -> None:
63+
"""Execute a psql command."""
64+
db = self.get_db_settings()
65+
cmd = ["psql", "-h", db["host"], "-U", db["user"], "-d", db_name, "-c", command]
66+
subprocess.run(cmd, env=env, check=True)
67+
68+
def reset_database(self, env: dict) -> None:
69+
"""Drop and recreate the database."""
70+
db = self.get_db_settings()
71+
self.stdout.write(f"Dropping database {db['name']}...")
72+
self.run_psql_command(f"DROP DATABASE IF EXISTS {db['name']}", env=env)
73+
74+
self.stdout.write(f"Creating database {db['name']}...")
75+
self.run_psql_command(f"CREATE DATABASE {db['name']}", env=env)
76+
77+
def restore_backup(self, backup_file: str, env: dict) -> None:
78+
"""Restore database from backup file."""
79+
db = self.get_db_settings()
80+
cmd = ["psql", "-h", db["host"], "-U", db["user"], "-d", db["name"], "-f", backup_file]
81+
self.stdout.write("Restoring from backup...")
82+
subprocess.run(cmd, env=env, check=True)
83+
84+
def decompress_file(self, input_file: str, output_file: str) -> None:
85+
"""Decompress gzipped file to output file."""
86+
with gzip.open(input_file, "rb") as f_in:
87+
with open(output_file, "wb") as f_out:
88+
shutil.copyfileobj(f_in, f_out)
3889

3990
def handle(self, *args, **options):
4091
server = detect_server()
4192
backup_path = options["backup_path"]
93+
is_compressed = backup_path.endswith(".gz")
4294

4395
if not os.path.exists(backup_path):
4496
raise CommandError(f"Backup file not found: {backup_path}")
4597

46-
db_settings = settings.DATABASES["default"]
47-
host = db_settings["HOST"]
48-
name = db_settings["NAME"]
49-
user = db_settings["USER"]
50-
password = db_settings["PASSWORD"]
51-
52-
# Drop and recreate database
53-
drop_cmd = ["psql", "-h", host, "-U", user, "-d", "postgres", "-c", f"DROP DATABASE IF EXISTS {name}"]
54-
create_cmd = ["psql", "-h", host, "-U", user, "-d", "postgres", "-c", f"CREATE DATABASE {name}"]
55-
56-
# Restore command
57-
restore_cmd = ["psql", "-h", host, "-U", user, "-d", name, "-f", backup_path]
98+
env = os.environ.copy()
99+
env["PGPASSWORD"] = self.get_db_settings()["password"]
58100

59101
try:
60-
env = os.environ.copy()
61-
env["PGPASSWORD"] = password
62-
63-
self.stdout.write(f"Dropping database {name}...")
64-
subprocess.run(drop_cmd, env=env, check=True)
65-
66-
self.stdout.write(f"Creating database {name}...")
67-
subprocess.run(create_cmd, env=env, check=True)
102+
# Reset the database first
103+
self.reset_database(env)
68104

69-
self.stdout.write("Restoring from backup...")
70-
subprocess.run(restore_cmd, env=env, check=True)
105+
# Handle backup restoration
106+
if is_compressed:
107+
with temp_file_handler(backup_path[:-3]) as temp_file:
108+
self.decompress_file(backup_path, temp_file)
109+
self.restore_backup(temp_file, env)
110+
else:
111+
self.restore_backup(backup_path, env)
71112

72113
self.stdout.write(self.style.SUCCESS(f"Successfully restored {server.value} database from {backup_path}"))
73114

74115
except subprocess.CalledProcessError as e:
75116
self.stdout.write(self.style.ERROR(f"Restore failed on {server.value}: {str(e)}"))
117+
except Exception as e:
118+
self.stdout.write(self.style.ERROR(f"Error during restore process: {str(e)}"))

0 commit comments

Comments
 (0)