Skip to content

Commit 4049217

Browse files
feat(db): back up database before running migrations
Just in case.
1 parent 59b4a23 commit 4049217

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import sqlite3
2+
from contextlib import closing
3+
from datetime import datetime
24
from pathlib import Path
35
from typing import Optional
46

@@ -32,6 +34,7 @@ def __init__(self, db: SqliteDatabase) -> None:
3234
self._db = db
3335
self._logger = db.logger
3436
self._migration_set = MigrationSet()
37+
self._backup_path: Optional[Path] = None
3538

3639
def register_migration(self, migration: Migration) -> None:
3740
"""Registers a migration."""
@@ -55,6 +58,18 @@ def run_migrations(self) -> bool:
5558
return False
5659

5760
self._logger.info("Database update needed")
61+
62+
# Make a backup of the db if it needs to be updated and is a file db
63+
if self._db.db_path is not None:
64+
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
65+
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
66+
self._logger.info(f"Backing up database to {str(self._backup_path)}")
67+
# Use SQLite to do the backup
68+
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
69+
self._db.conn.backup(backup_conn)
70+
else:
71+
self._logger.info("Using in-memory database, no backup needed")
72+
5873
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
5974
while next_migration is not None:
6075
self._run_migration(next_migration)

tests/test_sqlite_migrator.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,32 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
250250
db.conn.close()
251251

252252

253+
def test_migrator_backs_up_db(logger: Logger) -> None:
254+
with TemporaryDirectory() as tempdir:
255+
original_db_path = Path(tempdir) / "invokeai.db"
256+
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
257+
# Write some data to the db to test for successful backup
258+
temp_cursor = db.conn.cursor()
259+
temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
260+
db.conn.commit()
261+
# Set up the migrator
262+
migrator = SqliteMigrator(db=db)
263+
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
264+
for migration in migrations:
265+
migrator.register_migration(migration)
266+
migrator.run_migrations()
267+
# Must manually close else we get an error on Windows
268+
db.conn.close()
269+
assert original_db_path.exists()
270+
# We should have a backup file when we migrated a file db
271+
assert migrator._backup_path
272+
# Check that the test table exists as a proxy for successful backup
273+
with closing(sqlite3.connect(migrator._backup_path)) as backup_db_conn:
274+
backup_db_cursor = backup_db_conn.cursor()
275+
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
276+
assert backup_db_cursor.fetchone() is not None
277+
278+
253279
def test_migrator_makes_no_changes_on_failed_migration(
254280
migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
255281
) -> None:

0 commit comments

Comments
 (0)