Skip to content

Commit 84d2aaf

Browse files
phernandezclaude
andauthored
fix: eliminate redundant database migration initialization (#146)
Co-authored-by: Claude <[email protected]>
1 parent 7789864 commit 84d2aaf

File tree

4 files changed

+254
-28
lines changed

4 files changed

+254
-28
lines changed

src/basic_memory/db.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# Module level state
2424
_engine: Optional[AsyncEngine] = None
2525
_session_maker: Optional[async_sessionmaker[AsyncSession]] = None
26+
_migrations_completed: bool = False
2627

2728

2829
class DatabaseType(Enum):
@@ -72,18 +73,35 @@ async def scoped_session(
7273
await factory.remove()
7374

7475

76+
def _create_engine_and_session(
77+
db_path: Path, db_type: DatabaseType = DatabaseType.FILESYSTEM
78+
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
79+
"""Internal helper to create engine and session maker."""
80+
db_url = DatabaseType.get_db_url(db_path, db_type)
81+
logger.debug(f"Creating engine for db_url: {db_url}")
82+
engine = create_async_engine(db_url, connect_args={"check_same_thread": False})
83+
session_maker = async_sessionmaker(engine, expire_on_commit=False)
84+
return engine, session_maker
85+
86+
7587
async def get_or_create_db(
7688
db_path: Path,
7789
db_type: DatabaseType = DatabaseType.FILESYSTEM,
90+
ensure_migrations: bool = True,
91+
app_config: Optional["BasicMemoryConfig"] = None,
7892
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: # pragma: no cover
7993
"""Get or create database engine and session maker."""
8094
global _engine, _session_maker
8195

8296
if _engine is None:
83-
db_url = DatabaseType.get_db_url(db_path, db_type)
84-
logger.debug(f"Creating engine for db_url: {db_url}")
85-
_engine = create_async_engine(db_url, connect_args={"check_same_thread": False})
86-
_session_maker = async_sessionmaker(_engine, expire_on_commit=False)
97+
_engine, _session_maker = _create_engine_and_session(db_path, db_type)
98+
99+
# Run migrations automatically unless explicitly disabled
100+
if ensure_migrations:
101+
if app_config is None:
102+
from basic_memory.config import app_config as global_app_config
103+
app_config = global_app_config
104+
await run_migrations(app_config, db_type)
87105

88106
# These checks should never fail since we just created the engine and session maker
89107
# if they were None, but we'll check anyway for the type checker
@@ -100,12 +118,13 @@ async def get_or_create_db(
100118

101119
async def shutdown_db() -> None: # pragma: no cover
102120
"""Clean up database connections."""
103-
global _engine, _session_maker
121+
global _engine, _session_maker, _migrations_completed
104122

105123
if _engine:
106124
await _engine.dispose()
107125
_engine = None
108126
_session_maker = None
127+
_migrations_completed = False
109128

110129

111130
@asynccontextmanager
@@ -119,7 +138,7 @@ async def engine_session_factory(
119138
for each test. For production use, use get_or_create_db() instead.
120139
"""
121140

122-
global _engine, _session_maker
141+
global _engine, _session_maker, _migrations_completed
123142

124143
db_url = DatabaseType.get_db_url(db_path, db_type)
125144
logger.debug(f"Creating engine for db_url: {db_url}")
@@ -143,12 +162,20 @@ async def engine_session_factory(
143162
await _engine.dispose()
144163
_engine = None
145164
_session_maker = None
165+
_migrations_completed = False
146166

147167

148168
async def run_migrations(
149-
app_config: BasicMemoryConfig, database_type=DatabaseType.FILESYSTEM
169+
app_config: BasicMemoryConfig, database_type=DatabaseType.FILESYSTEM, force: bool = False
150170
): # pragma: no cover
151171
"""Run any pending alembic migrations."""
172+
global _migrations_completed
173+
174+
# Skip if migrations already completed unless forced
175+
if _migrations_completed and not force:
176+
logger.debug("Migrations already completed in this session, skipping")
177+
return
178+
152179
logger.info("Running database migrations...")
153180
try:
154181
# Get the absolute path to the alembic directory relative to this file
@@ -170,11 +197,18 @@ async def run_migrations(
170197
command.upgrade(config, "head")
171198
logger.info("Migrations completed successfully")
172199

173-
_, session_maker = await get_or_create_db(app_config.database_path, database_type)
200+
# Get session maker - ensure we don't trigger recursive migration calls
201+
if _session_maker is None:
202+
_, session_maker = _create_engine_and_session(app_config.database_path, database_type)
203+
else:
204+
session_maker = _session_maker
174205

175206
# initialize the search Index schema
176207
# the project_id is not used for init_search_index, so we pass a dummy value
177208
await SearchRepository(session_maker, 1).init_search_index()
209+
210+
# Mark migrations as completed
211+
_migrations_completed = True
178212
except Exception as e: # pragma: no cover
179213
logger.error(f"Error running migrations: {e}")
180214
raise

src/basic_memory/services/initialization.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,21 @@
1717

1818

1919
async def initialize_database(app_config: BasicMemoryConfig) -> None:
20-
"""Run database migrations to ensure schema is up to date.
20+
"""Initialize database with migrations handled automatically by get_or_create_db.
2121
2222
Args:
2323
app_config: The Basic Memory project configuration
24+
25+
Note:
26+
Database migrations are now handled automatically when the database
27+
connection is first established via get_or_create_db().
2428
"""
29+
# Trigger database initialization and migrations by getting the database connection
2530
try:
26-
logger.info("Running database migrations...")
27-
await db.run_migrations(app_config)
28-
logger.info("Migrations completed successfully")
31+
await db.get_or_create_db(app_config.database_path)
32+
logger.info("Database initialization completed")
2933
except Exception as e:
30-
logger.error(f"Error running migrations: {e}")
34+
logger.error(f"Error initializing database: {e}")
3135
# Allow application to continue - it might still work
3236
# depending on what the error was, and will fail with a
3337
# more specific error if the database is actually unusable
@@ -44,9 +48,9 @@ async def reconcile_projects_with_config(app_config: BasicMemoryConfig):
4448
"""
4549
logger.info("Reconciling projects from config with database...")
4650

47-
# Get database session
51+
# Get database session - migrations handled centrally
4852
_, session_maker = await db.get_or_create_db(
49-
db_path=app_config.database_path, db_type=db.DatabaseType.FILESYSTEM
53+
db_path=app_config.database_path, db_type=db.DatabaseType.FILESYSTEM, ensure_migrations=False
5054
)
5155
project_repository = ProjectRepository(session_maker)
5256

@@ -65,9 +69,9 @@ async def reconcile_projects_with_config(app_config: BasicMemoryConfig):
6569

6670

6771
async def migrate_legacy_projects(app_config: BasicMemoryConfig):
68-
# Get database session
72+
# Get database session - migrations handled centrally
6973
_, session_maker = await db.get_or_create_db(
70-
db_path=app_config.database_path, db_type=db.DatabaseType.FILESYSTEM
74+
db_path=app_config.database_path, db_type=db.DatabaseType.FILESYSTEM, ensure_migrations=False
7175
)
7276
logger.info("Migrating legacy projects...")
7377
project_repository = ProjectRepository(session_maker)
@@ -134,9 +138,9 @@ async def initialize_file_sync(
134138
# delay import
135139
from basic_memory.sync import WatchService
136140

137-
# Load app configuration
141+
# Load app configuration - migrations handled centrally
138142
_, session_maker = await db.get_or_create_db(
139-
db_path=app_config.database_path, db_type=db.DatabaseType.FILESYSTEM
143+
db_path=app_config.database_path, db_type=db.DatabaseType.FILESYSTEM, ensure_migrations=False
140144
)
141145
project_repository = ProjectRepository(session_maker)
142146

tests/services/test_initialization.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,21 @@
1717

1818

1919
@pytest.mark.asyncio
20-
@patch("basic_memory.services.initialization.db.run_migrations")
21-
async def test_initialize_database(mock_run_migrations, project_config):
20+
@patch("basic_memory.services.initialization.db.get_or_create_db")
21+
async def test_initialize_database(mock_get_or_create_db, app_config):
2222
"""Test initializing the database."""
23-
await initialize_database(project_config)
24-
mock_run_migrations.assert_called_once_with(project_config)
23+
mock_get_or_create_db.return_value = (MagicMock(), MagicMock())
24+
await initialize_database(app_config)
25+
mock_get_or_create_db.assert_called_once_with(app_config.database_path)
2526

2627

2728
@pytest.mark.asyncio
28-
@patch("basic_memory.services.initialization.db.run_migrations")
29-
async def test_initialize_database_error(mock_run_migrations, project_config):
29+
@patch("basic_memory.services.initialization.db.get_or_create_db")
30+
async def test_initialize_database_error(mock_get_or_create_db, app_config):
3031
"""Test handling errors during database initialization."""
31-
mock_run_migrations.side_effect = Exception("Test error")
32-
await initialize_database(project_config)
33-
mock_run_migrations.assert_called_once_with(project_config)
32+
mock_get_or_create_db.side_effect = Exception("Test error")
33+
await initialize_database(app_config)
34+
mock_get_or_create_db.assert_called_once_with(app_config.database_path)
3435

3536

3637
@pytest.mark.asyncio
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
"""Tests for database migration deduplication functionality."""
2+
3+
import pytest
4+
from unittest.mock import patch, AsyncMock, MagicMock
5+
6+
from basic_memory import db
7+
8+
9+
@pytest.fixture
10+
def mock_alembic_config():
11+
"""Mock Alembic config to avoid actual migration runs."""
12+
with patch("basic_memory.db.Config") as mock_config_class:
13+
mock_config = MagicMock()
14+
mock_config_class.return_value = mock_config
15+
yield mock_config
16+
17+
18+
@pytest.fixture
19+
def mock_alembic_command():
20+
"""Mock Alembic command to avoid actual migration runs."""
21+
with patch("basic_memory.db.command") as mock_command:
22+
yield mock_command
23+
24+
25+
@pytest.fixture
26+
def mock_search_repository():
27+
"""Mock SearchRepository to avoid database dependencies."""
28+
with patch("basic_memory.db.SearchRepository") as mock_repo_class:
29+
mock_repo = AsyncMock()
30+
mock_repo_class.return_value = mock_repo
31+
yield mock_repo
32+
33+
34+
# Use the app_config fixture from conftest.py
35+
36+
37+
@pytest.mark.asyncio
38+
async def test_migration_deduplication_single_call(
39+
app_config, mock_alembic_config, mock_alembic_command, mock_search_repository
40+
):
41+
"""Test that migrations are only run once when called multiple times."""
42+
# Reset module state
43+
db._migrations_completed = False
44+
db._engine = None
45+
db._session_maker = None
46+
47+
# First call should run migrations
48+
await db.run_migrations(app_config)
49+
50+
# Verify migrations were called
51+
mock_alembic_command.upgrade.assert_called_once_with(mock_alembic_config, "head")
52+
mock_search_repository.init_search_index.assert_called_once()
53+
54+
# Reset mocks for second call
55+
mock_alembic_command.reset_mock()
56+
mock_search_repository.reset_mock()
57+
58+
# Second call should skip migrations
59+
await db.run_migrations(app_config)
60+
61+
# Verify migrations were NOT called again
62+
mock_alembic_command.upgrade.assert_not_called()
63+
mock_search_repository.init_search_index.assert_not_called()
64+
65+
66+
@pytest.mark.asyncio
67+
async def test_migration_force_parameter(
68+
app_config, mock_alembic_config, mock_alembic_command, mock_search_repository
69+
):
70+
"""Test that migrations can be forced to run even if already completed."""
71+
# Reset module state
72+
db._migrations_completed = False
73+
db._engine = None
74+
db._session_maker = None
75+
76+
# First call should run migrations
77+
await db.run_migrations(app_config)
78+
79+
# Verify migrations were called
80+
mock_alembic_command.upgrade.assert_called_once_with(mock_alembic_config, "head")
81+
mock_search_repository.init_search_index.assert_called_once()
82+
83+
# Reset mocks for forced call
84+
mock_alembic_command.reset_mock()
85+
mock_search_repository.reset_mock()
86+
87+
# Forced call should run migrations again
88+
await db.run_migrations(app_config, force=True)
89+
90+
# Verify migrations were called again
91+
mock_alembic_command.upgrade.assert_called_once_with(mock_alembic_config, "head")
92+
mock_search_repository.init_search_index.assert_called_once()
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_migration_state_reset_on_shutdown():
97+
"""Test that migration state is reset when database is shut down."""
98+
# Set up completed state
99+
db._migrations_completed = True
100+
db._engine = AsyncMock()
101+
db._session_maker = AsyncMock()
102+
103+
# Shutdown should reset state
104+
await db.shutdown_db()
105+
106+
# Verify state was reset
107+
assert db._migrations_completed is False
108+
assert db._engine is None
109+
assert db._session_maker is None
110+
111+
112+
@pytest.mark.asyncio
113+
async def test_get_or_create_db_runs_migrations_automatically(
114+
app_config, mock_alembic_config, mock_alembic_command, mock_search_repository
115+
):
116+
"""Test that get_or_create_db runs migrations automatically."""
117+
# Reset module state
118+
db._migrations_completed = False
119+
db._engine = None
120+
db._session_maker = None
121+
122+
# First call should create engine and run migrations
123+
engine, session_maker = await db.get_or_create_db(
124+
app_config.database_path, app_config=app_config
125+
)
126+
127+
# Verify we got valid objects
128+
assert engine is not None
129+
assert session_maker is not None
130+
131+
# Verify migrations were called
132+
mock_alembic_command.upgrade.assert_called_once_with(mock_alembic_config, "head")
133+
mock_search_repository.init_search_index.assert_called_once()
134+
135+
136+
@pytest.mark.asyncio
137+
async def test_get_or_create_db_skips_migrations_when_disabled(
138+
app_config, mock_alembic_config, mock_alembic_command, mock_search_repository
139+
):
140+
"""Test that get_or_create_db can skip migrations when ensure_migrations=False."""
141+
# Reset module state
142+
db._migrations_completed = False
143+
db._engine = None
144+
db._session_maker = None
145+
146+
# Call with ensure_migrations=False should skip migrations
147+
engine, session_maker = await db.get_or_create_db(
148+
app_config.database_path, ensure_migrations=False
149+
)
150+
151+
# Verify we got valid objects
152+
assert engine is not None
153+
assert session_maker is not None
154+
155+
# Verify migrations were NOT called
156+
mock_alembic_command.upgrade.assert_not_called()
157+
mock_search_repository.init_search_index.assert_not_called()
158+
159+
160+
@pytest.mark.asyncio
161+
async def test_multiple_get_or_create_db_calls_deduplicated(
162+
app_config, mock_alembic_config, mock_alembic_command, mock_search_repository
163+
):
164+
"""Test that multiple get_or_create_db calls only run migrations once."""
165+
# Reset module state
166+
db._migrations_completed = False
167+
db._engine = None
168+
db._session_maker = None
169+
170+
# First call should create engine and run migrations
171+
await db.get_or_create_db(app_config.database_path, app_config=app_config)
172+
173+
# Verify migrations were called
174+
mock_alembic_command.upgrade.assert_called_once_with(mock_alembic_config, "head")
175+
mock_search_repository.init_search_index.assert_called_once()
176+
177+
# Reset mocks for subsequent calls
178+
mock_alembic_command.reset_mock()
179+
mock_search_repository.reset_mock()
180+
181+
# Subsequent calls should not run migrations again
182+
await db.get_or_create_db(app_config.database_path, app_config=app_config)
183+
await db.get_or_create_db(app_config.database_path, app_config=app_config)
184+
185+
# Verify migrations were NOT called again
186+
mock_alembic_command.upgrade.assert_not_called()
187+
mock_search_repository.init_search_index.assert_not_called()

0 commit comments

Comments
 (0)