Skip to content

Commit 1188734

Browse files
committed
refactor: update test fixtures for improved database session management and isolation
1 parent 1fe0faa commit 1188734

File tree

1 file changed

+35
-7
lines changed

1 file changed

+35
-7
lines changed

tests/conftest.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from httpx import ASGITransport, AsyncClient
77
from sqlalchemy import text
88
from sqlalchemy.exc import ProgrammingError
9+
from sqlalchemy.ext.asyncio import AsyncSession
910

1011
from app.database import engine, test_engine, get_test_db, get_db
1112
from app.main import app
@@ -64,16 +65,43 @@ async def start_db():
6465
await test_engine.dispose()
6566

6667

67-
@pytest.fixture(scope="session")
68-
async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001
69-
transport = ASGITransport(
70-
app=app,
71-
)
68+
@pytest.fixture(scope="function")
69+
async def db_session(start_db) -> AsyncGenerator[AsyncSession, Any]:
70+
"""
71+
Provide a transactional database session for each test function.
72+
Rolls back changes after the test.
73+
"""
74+
connection = await test_engine.connect()
75+
transaction = await connection.begin()
76+
session = AsyncSession(bind=connection)
77+
78+
yield session
79+
80+
await session.close()
81+
await transaction.rollback()
82+
await connection.close()
83+
84+
85+
@pytest.fixture(scope="function")
86+
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, Any]:
87+
"""
88+
Provide a test client for making API requests.
89+
Uses the function-scoped db_session for test isolation.
90+
"""
91+
92+
def get_test_db_override():
93+
yield db_session
94+
95+
app.dependency_overrides[get_db] = get_test_db_override
96+
app.redis = await get_redis()
97+
98+
transport = ASGITransport(app=app)
7299
async with AsyncClient(
73100
base_url="http://testserver/v1",
74101
headers={"Content-Type": "application/json"},
75102
transport=transport,
76103
) as test_client:
77-
app.dependency_overrides[get_db] = get_test_db
78-
app.redis = await get_redis()
79104
yield test_client
105+
106+
# Clean up dependency overrides
107+
del app.dependency_overrides[get_db]

0 commit comments

Comments
 (0)