|
6 | 6 | from httpx import ASGITransport, AsyncClient |
7 | 7 | from sqlalchemy import text |
8 | 8 | from sqlalchemy.exc import ProgrammingError |
| 9 | +from sqlalchemy.ext.asyncio import AsyncSession |
9 | 10 |
|
10 | 11 | from app.database import engine, test_engine, get_test_db, get_db |
11 | 12 | from app.main import app |
@@ -64,16 +65,43 @@ async def start_db(): |
64 | 65 | await test_engine.dispose() |
65 | 66 |
|
66 | 67 |
|
67 | | -@pytest.fixture(scope="session") |
68 | | -async def client(start_db) -> AsyncGenerator[AsyncClient, Any]: # noqa: ARG001 |
69 | | - transport = ASGITransport( |
70 | | - app=app, |
71 | | - ) |
| 68 | +@pytest.fixture(scope="function") |
| 69 | +async def db_session(start_db) -> AsyncGenerator[AsyncSession, Any]: |
| 70 | + """ |
| 71 | + Provide a transactional database session for each test function. |
| 72 | + Rolls back changes after the test. |
| 73 | + """ |
| 74 | + connection = await test_engine.connect() |
| 75 | + transaction = await connection.begin() |
| 76 | + session = AsyncSession(bind=connection) |
| 77 | + |
| 78 | + yield session |
| 79 | + |
| 80 | + await session.close() |
| 81 | + await transaction.rollback() |
| 82 | + await connection.close() |
| 83 | + |
| 84 | + |
| 85 | +@pytest.fixture(scope="function") |
| 86 | +async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, Any]: |
| 87 | + """ |
| 88 | + Provide a test client for making API requests. |
| 89 | + Uses the function-scoped db_session for test isolation. |
| 90 | + """ |
| 91 | + |
| 92 | + def get_test_db_override(): |
| 93 | + yield db_session |
| 94 | + |
| 95 | + app.dependency_overrides[get_db] = get_test_db_override |
| 96 | + app.redis = await get_redis() |
| 97 | + |
| 98 | + transport = ASGITransport(app=app) |
72 | 99 | async with AsyncClient( |
73 | 100 | base_url="http://testserver/v1", |
74 | 101 | headers={"Content-Type": "application/json"}, |
75 | 102 | transport=transport, |
76 | 103 | ) as test_client: |
77 | | - app.dependency_overrides[get_db] = get_test_db |
78 | | - app.redis = await get_redis() |
79 | 104 | yield test_client |
| 105 | + |
| 106 | + # Clean up dependency overrides |
| 107 | + del app.dependency_overrides[get_db] |
0 commit comments