Skip to content

Commit bbb0d4a

Browse files
committed
refactor and add tests for database restores
1 parent 48c66b8 commit bbb0d4a

File tree

4 files changed

+512
-6
lines changed

4 files changed

+512
-6
lines changed

sde_collections/management/commands/database_backup.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Usage:
55
docker-compose -f local.yml run --rm django python manage.py database_backup
66
docker-compose -f local.yml run --rm django python manage.py database_backup --no-compress
7+
docker-compose -f local.yml run --rm django python manage.py database_backup --output /path/to/output.sql
78
docker-compose -f production.yml run --rm django python manage.py database_backup
89
"""
910

@@ -54,19 +55,41 @@ def add_arguments(self, parser):
5455
action="store_true",
5556
help="Disable backup file compression (enabled by default)",
5657
)
58+
parser.add_argument(
59+
"--output",
60+
type=str,
61+
help="Output file path (default: auto-generated based on server name and date)",
62+
)
5763

58-
def get_backup_filename(self, server: Server, compress: bool) -> tuple[str, str]:
64+
def get_backup_filename(self, server: Server, compress: bool, custom_output: str = None) -> tuple[str, str]:
5965
"""Generate backup filename and actual dump path.
6066
67+
Args:
68+
server: Server enum indicating the environment
69+
compress: Whether the output should be compressed
70+
custom_output: Optional custom output path
71+
6172
Returns:
6273
tuple[str, str]: A tuple containing (final_filename, temp_filename)
6374
- final_filename: The name of the final backup file (with .gz if compressed)
6475
- temp_filename: The name of the temporary dump file (always without .gz)
6576
"""
66-
date_str = datetime.now().strftime("%Y%m%d")
67-
temp_filename = f"{server.value.lower()}_backup_{date_str}.sql"
68-
final_filename = f"{temp_filename}.gz" if compress else temp_filename
69-
return final_filename, temp_filename
77+
if custom_output:
78+
# Ensure the output directory exists
79+
output_dir = os.path.dirname(custom_output)
80+
if output_dir:
81+
os.makedirs(output_dir, exist_ok=True)
82+
83+
if compress:
84+
return custom_output + (".gz" if not custom_output.endswith(".gz") else ""), custom_output.removesuffix(
85+
".gz"
86+
)
87+
return custom_output, custom_output
88+
else:
89+
date_str = datetime.now().strftime("%Y%m%d")
90+
temp_filename = f"{server.value.lower()}_backup_{date_str}.sql"
91+
final_filename = f"{temp_filename}.gz" if compress else temp_filename
92+
return final_filename, temp_filename
7093

7194
def run_pg_dump(self, output_file: str, env: dict) -> None:
7295
"""Execute pg_dump with given parameters."""
@@ -95,7 +118,7 @@ def compress_file(self, input_file: str, output_file: str) -> None:
95118
def handle(self, *args, **options):
96119
server = detect_server()
97120
compress = not options["no_compress"]
98-
backup_file, dump_file = self.get_backup_filename(server, compress)
121+
backup_file, dump_file = self.get_backup_filename(server, compress, options.get("output"))
99122

100123
env = os.environ.copy()
101124
env["PGPASSWORD"] = settings.DATABASES["default"]["PASSWORD"]

sde_collections/management/commands/database_restore.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from django.conf import settings
1818
from django.core.management.base import BaseCommand, CommandError
19+
from django.db import connections
1920

2021

2122
class Server(enum.Enum):
@@ -65,9 +66,32 @@ def run_psql_command(self, command: str, db_name: str = "postgres", env: dict =
6566
cmd = ["psql", "-h", db["host"], "-U", db["user"], "-d", db_name, "-c", command]
6667
subprocess.run(cmd, env=env, check=True)
6768

69+
def terminate_database_connections(self, env: dict) -> None:
70+
"""Terminate all connections to the database."""
71+
db = self.get_db_settings()
72+
# Close Django's connection first
73+
connections.close_all()
74+
75+
# Terminate any remaining PostgreSQL connections
76+
terminate_conn_sql = f"""
77+
SELECT pg_terminate_backend(pid)
78+
FROM pg_stat_activity
79+
WHERE datname = '{db["name"]}'
80+
AND pid <> pg_backend_pid();
81+
"""
82+
try:
83+
self.run_psql_command(terminate_conn_sql, env=env)
84+
except subprocess.CalledProcessError:
85+
# If this fails, it's usually because there are no connections to terminate
86+
pass
87+
6888
def reset_database(self, env: dict) -> None:
6989
"""Drop and recreate the database."""
7090
db = self.get_db_settings()
91+
92+
self.stdout.write(f"Terminating connections to {db['name']}...")
93+
self.terminate_database_connections(env)
94+
7195
self.stdout.write(f"Dropping database {db['name']}...")
7296
self.run_psql_command(f"DROP DATABASE IF EXISTS {db['name']}", env=env)
7397

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# docker-compose -f local.yml run --rm django pytest sde_collections/tests/test_database_backup.py
2+
import gzip
3+
import os
4+
import subprocess
5+
from datetime import datetime
6+
from unittest.mock import Mock, patch
7+
8+
import pytest
9+
from django.core.management import call_command
10+
11+
from sde_collections.management.commands import database_backup
12+
from sde_collections.management.commands.database_backup import (
13+
Server,
14+
temp_file_handler,
15+
)
16+
17+
18+
@pytest.fixture
19+
def mock_subprocess():
20+
with patch("subprocess.run") as mock_run:
21+
mock_run.return_value.returncode = 0
22+
yield mock_run
23+
24+
25+
@pytest.fixture
26+
def mock_date():
27+
with patch("sde_collections.management.commands.database_backup.datetime") as mock_dt:
28+
mock_dt.now.return_value = datetime(2024, 1, 15)
29+
yield mock_dt
30+
31+
32+
@pytest.fixture
33+
def mock_settings(settings):
34+
"""Configure test database settings."""
35+
settings.DATABASES = {
36+
"default": {
37+
"HOST": "test-db-host",
38+
"NAME": "test_db",
39+
"USER": "test_user",
40+
"PASSWORD": "test_password",
41+
}
42+
}
43+
return settings
44+
45+
46+
@pytest.fixture
47+
def command():
48+
return database_backup.Command()
49+
50+
51+
class TestBackupCommand:
52+
def test_get_backup_filename_compressed(self, command, mock_date):
53+
"""Test backup filename generation with compression."""
54+
backup_file, dump_file = command.get_backup_filename(Server.STAGING, compress=True)
55+
assert backup_file == "staging_backup_20240115.sql.gz"
56+
assert dump_file == "staging_backup_20240115.sql"
57+
58+
def test_get_backup_filename_uncompressed(self, command, mock_date):
59+
"""Test backup filename generation without compression."""
60+
backup_file, dump_file = command.get_backup_filename(Server.PRODUCTION, compress=False)
61+
assert backup_file == "production_backup_20240115.sql"
62+
assert dump_file == backup_file
63+
64+
def test_run_pg_dump(self, command, mock_subprocess, mock_settings):
65+
"""Test pg_dump command execution."""
66+
env = {"PGPASSWORD": "test_password"}
67+
command.run_pg_dump("test_output.sql", env)
68+
69+
mock_subprocess.assert_called_once()
70+
cmd_args = mock_subprocess.call_args[0][0]
71+
assert cmd_args == [
72+
"pg_dump",
73+
"-h",
74+
"test-db-host",
75+
"-U",
76+
"test_user",
77+
"-d",
78+
"test_db",
79+
"--no-owner",
80+
"--no-privileges",
81+
"-f",
82+
"test_output.sql",
83+
]
84+
85+
def test_compress_file(self, command, tmp_path):
86+
"""Test file compression."""
87+
input_file = tmp_path / "test.sql"
88+
output_file = tmp_path / "test.sql.gz"
89+
test_content = b"Test database content"
90+
91+
# Create test input file
92+
input_file.write_bytes(test_content)
93+
94+
# Compress the file
95+
command.compress_file(str(input_file), str(output_file))
96+
97+
# Verify compression
98+
assert output_file.exists()
99+
with gzip.open(output_file, "rb") as f:
100+
assert f.read() == test_content
101+
102+
def test_temp_file_handler_cleanup(self, tmp_path):
103+
"""Test temporary file cleanup."""
104+
test_file = tmp_path / "temp.sql"
105+
test_file.touch()
106+
107+
with temp_file_handler(str(test_file)):
108+
assert test_file.exists()
109+
assert not test_file.exists()
110+
111+
def test_temp_file_handler_cleanup_on_error(self, tmp_path):
112+
"""Test temporary file cleanup when an error occurs."""
113+
test_file = tmp_path / "temp.sql"
114+
test_file.touch()
115+
116+
with pytest.raises(ValueError):
117+
with temp_file_handler(str(test_file)):
118+
assert test_file.exists()
119+
raise ValueError("Test error")
120+
assert not test_file.exists()
121+
122+
@patch("socket.gethostname")
123+
def test_server_detection(self, mock_hostname):
124+
"""Test server environment detection."""
125+
test_cases = [
126+
("PRODUCTION-SERVER", Server.PRODUCTION),
127+
("STAGING-DB", Server.STAGING),
128+
("DEV-HOST", Server.UNKNOWN),
129+
]
130+
131+
for hostname, expected_server in test_cases:
132+
mock_hostname.return_value = hostname
133+
with patch("sde_collections.management.commands.database_backup.detect_server") as mock_detect:
134+
mock_detect.return_value = expected_server
135+
server = database_backup.detect_server()
136+
assert server == expected_server
137+
138+
@pytest.mark.parametrize(
139+
"compress,hostname",
140+
[
141+
(True, "PRODUCTION-SERVER"),
142+
(False, "STAGING-SERVER"),
143+
(True, "UNKNOWN-SERVER"),
144+
],
145+
)
146+
def test_handle_integration(self, compress, hostname, mock_subprocess, mock_date, mock_settings):
147+
"""Test full backup process integration."""
148+
with patch("socket.gethostname", return_value=hostname):
149+
call_command("database_backup", no_compress=not compress)
150+
151+
# Verify correct command execution
152+
mock_subprocess.assert_called_once()
153+
154+
# Verify correct filename used
155+
cmd_args = mock_subprocess.call_args[0][0]
156+
date_str = "20240115"
157+
server_type = hostname.split("-")[0].lower()
158+
expected_base = f"{server_type}_backup_{date_str}.sql"
159+
160+
if compress:
161+
assert cmd_args[-1] == expected_base # Temporary file
162+
# Verify cleanup attempted
163+
assert not os.path.exists(expected_base)
164+
else:
165+
assert cmd_args[-1] == expected_base
166+
167+
def test_handle_pg_dump_error(self, mock_subprocess, mock_date):
168+
"""Test error handling when pg_dump fails."""
169+
mock_subprocess.side_effect = subprocess.CalledProcessError(1, "pg_dump")
170+
171+
with patch("socket.gethostname", return_value="STAGING-SERVER"):
172+
call_command("database_backup")
173+
174+
# Verify error handling and cleanup
175+
date_str = "20240115"
176+
temp_file = f"staging_backup_{date_str}.sql"
177+
assert not os.path.exists(temp_file)
178+
179+
def test_handle_compression_error(self, mock_subprocess, mock_date, command):
180+
"""Test error handling during compression."""
181+
# Mock compression to fail
182+
command.compress_file = Mock(side_effect=Exception("Compression failed"))
183+
184+
with patch("socket.gethostname", return_value="STAGING-SERVER"):
185+
call_command("database_backup")
186+
187+
# Verify cleanup
188+
date_str = "20240115"
189+
temp_file = f"staging_backup_{date_str}.sql"
190+
assert not os.path.exists(temp_file)

0 commit comments

Comments
 (0)