22from typing import Any
33
44import pytest
5+ from fastapi .exceptions import ResponseValidationError
56from httpx import ASGITransport , AsyncClient
67from sqlalchemy import text
7- from sqlalchemy .exc import ProgrammingError
8+ from sqlalchemy .exc import ProgrammingError , SQLAlchemyError
89
9- from app .database import engine , get_db , get_test_db , test_engine
10+ from app .database import engine , get_db , test_engine , TestAsyncSessionFactory
1011from app .main import app
1112from app .models .base import Base
1213from app .redis import get_redis
@@ -43,7 +44,7 @@ def _create_db_schema(conn) -> None:
4344 pass
4445
4546
46- @pytest .fixture (scope = "session" )
47+ @pytest .fixture (scope = "session" , autouse = True )
4748async def start_db ():
4849 # The `engine` is configured for the default 'postgres' database.
4950 # We connect to it and create the test database.
@@ -63,16 +64,37 @@ async def start_db():
6364 await test_engine .dispose ()
6465
6566
66- @pytest .fixture (scope = "session" )
67- async def client (start_db ) -> AsyncGenerator [AsyncClient , Any ]: # noqa: ARG001
67+ @pytest .fixture ()
68+ async def db_session ():
69+ connection = await test_engine .connect ()
70+ transaction = await connection .begin ()
71+ session = TestAsyncSessionFactory (bind = connection )
72+
73+ try :
74+ yield session
75+ finally :
76+ # Rollback the overall transaction, restoring the state before the test ran.
77+ await session .close ()
78+ if transaction .is_active :
79+ await transaction .rollback ()
80+ await connection .close ()
81+
82+
83+ @pytest .fixture (scope = "function" )
84+ async def client (db_session ) -> AsyncGenerator [AsyncClient , Any ]: # noqa: ARG001
6885 transport = ASGITransport (
6986 app = app ,
7087 )
88+
89+ async def override_get_db ():
90+ yield db_session
91+ await db_session .commit ()
92+
7193 async with AsyncClient (
7294 base_url = "http://testserver/v1" ,
7395 headers = {"Content-Type" : "application/json" },
7496 transport = transport ,
7597 ) as test_client :
76- app .dependency_overrides [get_db ] = get_test_db
98+ app .dependency_overrides [get_db ] = override_get_db
7799 app .redis = await get_redis ()
78100 yield test_client
0 commit comments