Skip to content

Commit 74d2a97

Browse files
authored
migrate to modern-di (#31)
1 parent 453a5fb commit 74d2a97

File tree

14 files changed

+312
-288
lines changed

14 files changed

+312
-288
lines changed

app/__main__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import granian
22
from granian.constants import Interfaces, Loops
33

4-
from app import ioc
4+
from app.settings import settings
55

66

77
if __name__ == "__main__":
8-
settings = ioc.IOCContainer.settings.sync_resolve()
98
granian.Granian(
109
target="app.application:application",
1110
address="0.0.0.0", # noqa: S104

app/api/decks.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,20 @@
22

33
import fastapi
44
from advanced_alchemy.exceptions import NotFoundError
5+
from modern_di_fastapi import FromDI
56
from sqlalchemy import orm
67
from starlette import status
7-
from that_depends.providers import container_context
88

99
from app import ioc, models, schemas
1010
from app.repositories import CardsService, DecksService
1111

1212

13-
async def init_di_context() -> typing.AsyncIterator[None]:
14-
async with container_context():
15-
yield
16-
17-
18-
ROUTER: typing.Final = fastapi.APIRouter(dependencies=[fastapi.Depends(init_di_context)])
13+
ROUTER: typing.Final = fastapi.APIRouter()
1914

2015

2116
@ROUTER.get("/decks/")
2217
async def list_decks(
23-
decks_service: DecksService = fastapi.Depends(ioc.IOCContainer.decks_service),
18+
decks_service: DecksService = FromDI(ioc.IOCContainer.decks_service),
2419
) -> schemas.Decks:
2520
objects = await decks_service.list()
2621
return typing.cast(schemas.Decks, {"items": objects})
@@ -29,7 +24,7 @@ async def list_decks(
2924
@ROUTER.get("/decks/{deck_id}/")
3025
async def get_deck(
3126
deck_id: int,
32-
decks_service: DecksService = fastapi.Depends(ioc.IOCContainer.decks_service),
27+
decks_service: DecksService = FromDI(ioc.IOCContainer.decks_service),
3328
) -> schemas.Deck:
3429
instance = await decks_service.get_one_or_none(
3530
models.Deck.id == deck_id,
@@ -45,7 +40,7 @@ async def get_deck(
4540
async def update_deck(
4641
deck_id: int,
4742
data: schemas.DeckCreate,
48-
decks_service: DecksService = fastapi.Depends(ioc.IOCContainer.decks_service),
43+
decks_service: DecksService = FromDI(ioc.IOCContainer.decks_service),
4944
) -> schemas.Deck:
5045
try:
5146
instance = await decks_service.update(data=data.model_dump(), item_id=deck_id)
@@ -58,7 +53,7 @@ async def update_deck(
5853
@ROUTER.post("/decks/")
5954
async def create_deck(
6055
data: schemas.DeckCreate,
61-
decks_service: DecksService = fastapi.Depends(ioc.IOCContainer.decks_service),
56+
decks_service: DecksService = FromDI(ioc.IOCContainer.decks_service),
6257
) -> schemas.Deck:
6358
instance = await decks_service.create(data)
6459
return typing.cast(schemas.Deck, instance)
@@ -67,7 +62,7 @@ async def create_deck(
6762
@ROUTER.get("/decks/{deck_id}/cards/")
6863
async def list_cards(
6964
deck_id: int,
70-
cards_service: CardsService = fastapi.Depends(ioc.IOCContainer.cards_service),
65+
cards_service: CardsService = FromDI(ioc.IOCContainer.cards_service),
7166
) -> schemas.Cards:
7267
objects = await cards_service.list(models.Card.deck_id == deck_id)
7368
return typing.cast(schemas.Cards, {"items": objects})
@@ -76,7 +71,7 @@ async def list_cards(
7671
@ROUTER.get("/cards/{card_id}/")
7772
async def get_card(
7873
card_id: int,
79-
cards_service: CardsService = fastapi.Depends(ioc.IOCContainer.cards_service),
74+
cards_service: CardsService = FromDI(ioc.IOCContainer.cards_service),
8075
) -> schemas.Card:
8176
instance = await cards_service.get_one_or_none(models.Card.id == card_id)
8277
if not instance:
@@ -88,7 +83,7 @@ async def get_card(
8883
async def create_cards(
8984
deck_id: int,
9085
data: list[schemas.CardCreate],
91-
cards_service: CardsService = fastapi.Depends(ioc.IOCContainer.cards_service),
86+
cards_service: CardsService = FromDI(ioc.IOCContainer.cards_service),
9287
) -> schemas.Cards:
9388
objects = await cards_service.create_many(
9489
data=[models.Card(**card.model_dump(), deck_id=deck_id) for card in data],
@@ -100,7 +95,7 @@ async def create_cards(
10095
async def update_cards(
10196
deck_id: int,
10297
data: list[schemas.Card],
103-
cards_service: CardsService = fastapi.Depends(ioc.IOCContainer.cards_service),
98+
cards_service: CardsService = FromDI(ioc.IOCContainer.cards_service),
10499
) -> schemas.Cards:
105100
objects = await cards_service.upsert_many(
106101
data=[models.Card(**card.model_dump(exclude={"deck_id"}), deck_id=deck_id) for card in data],

app/application.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
import typing
33

44
import fastapi
5-
from advanced_alchemy.exceptions import ForeignKeyError
5+
import modern_di
6+
import modern_di_fastapi
7+
from advanced_alchemy.exceptions import DuplicateKeyError, ForeignKeyError
68

79
from app import exceptions, ioc
810
from app.api.decks import ROUTER
11+
from app.settings import settings
912

1013

1114
def include_routers(app: fastapi.FastAPI) -> None:
@@ -14,25 +17,29 @@ def include_routers(app: fastapi.FastAPI) -> None:
1417

1518
class AppBuilder:
1619
def __init__(self) -> None:
17-
self.settings = ioc.IOCContainer.settings.sync_resolve()
1820
self.app: fastapi.FastAPI = fastapi.FastAPI(
19-
title=self.settings.service_name,
20-
debug=self.settings.debug,
21+
title=settings.service_name,
22+
debug=settings.debug,
2123
lifespan=self.lifespan_manager,
24+
dependencies=[fastapi.Depends(modern_di_fastapi.enter_di_request_scope)],
2225
)
26+
self.di_container = modern_di.Container(scope=modern_di.Scope.APP)
27+
modern_di_fastapi.save_di_container(self.app, self.di_container)
2328
include_routers(self.app)
2429
self.app.add_exception_handler(
2530
ForeignKeyError,
2631
exceptions.foreign_key_error_handler, # type: ignore[arg-type]
2732
)
33+
self.app.add_exception_handler(
34+
DuplicateKeyError,
35+
exceptions.foreign_key_error_handler, # type: ignore[arg-type]
36+
)
2837

2938
@contextlib.asynccontextmanager
3039
async def lifespan_manager(self, _: fastapi.FastAPI) -> typing.AsyncIterator[dict[str, typing.Any]]:
31-
try:
32-
await ioc.IOCContainer.init_resources()
40+
async with self.di_container:
41+
await ioc.IOCContainer.async_resolve_creators(self.di_container)
3342
yield {}
34-
finally:
35-
await ioc.IOCContainer.tear_down()
3643

3744

3845
application = AppBuilder().app

app/exceptions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from advanced_alchemy.exceptions import ForeignKeyError
1+
from advanced_alchemy.exceptions import DuplicateKeyError, ForeignKeyError
22
from fastapi.responses import JSONResponse
33
from starlette import status
44
from starlette.requests import Request
55

66

7-
async def foreign_key_error_handler(_: Request, exc: ForeignKeyError) -> JSONResponse:
7+
async def foreign_key_error_handler(_: Request, exc: ForeignKeyError | DuplicateKeyError) -> JSONResponse:
88
return JSONResponse(
99
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
1010
content={"detail": exc.detail},

app/ioc.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
from that_depends import BaseContainer, providers
1+
from modern_di import BaseGraph, Scope, providers
22

33
from app import repositories
44
from app.resources.db import create_sa_engine, create_session
5-
from app.settings import Settings
65

76

8-
class IOCContainer(BaseContainer):
9-
settings = providers.Singleton(Settings)
7+
class IOCContainer(BaseGraph):
8+
database_engine = providers.Resource(Scope.APP, create_sa_engine)
9+
session = providers.Resource(Scope.REQUEST, create_session, engine=database_engine.cast)
1010

11-
database_engine = providers.Resource(create_sa_engine, settings=settings.cast)
12-
session = providers.ContextResource(create_session, engine=database_engine.cast)
13-
14-
decks_service = providers.Factory(repositories.DecksService, session=session)
15-
cards_service = providers.Factory(repositories.CardsService, session=session)
11+
decks_service = providers.Factory(Scope.REQUEST, repositories.DecksService, session=session.cast)
12+
cards_service = providers.Factory(Scope.REQUEST, repositories.CardsService, session=session.cast)

app/resources/db.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
from sqlalchemy.ext import asyncio as sa
55

6-
from app.settings import Settings
6+
from app.settings import settings
77

88

99
logger = logging.getLogger(__name__)
1010

1111

12-
async def create_sa_engine(settings: Settings) -> typing.AsyncIterator[sa.AsyncEngine]:
12+
async def create_sa_engine() -> typing.AsyncIterator[sa.AsyncEngine]:
1313
logger.debug("Initializing SQLAlchemy engine")
1414
engine = sa.create_async_engine(
1515
url=settings.db_dsn,
@@ -27,6 +27,14 @@ async def create_sa_engine(settings: Settings) -> typing.AsyncIterator[sa.AsyncE
2727
logger.debug("SQLAlchemy engine has been cleaned up")
2828

2929

30+
class CustomAsyncSession(sa.AsyncSession):
31+
async def close(self) -> None:
32+
if isinstance(self.bind, sa.AsyncConnection):
33+
return self.expunge_all()
34+
35+
return await super().close()
36+
37+
3038
async def create_session(engine: sa.AsyncEngine) -> typing.AsyncIterator[sa.AsyncSession]:
31-
async with sa.AsyncSession(engine, expire_on_commit=False, autoflush=False) as session:
39+
async with CustomAsyncSession(engine, expire_on_commit=False, autoflush=False) as session:
3240
yield session

app/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,6 @@ def db_dsn(self) -> URL:
3232
self.db_port,
3333
self.db_database,
3434
)
35+
36+
37+
settings = Settings()

migrations/env.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
from alembic import context
44
from sqlalchemy import URL, create_engine
55

6-
from app import ioc
76
from app.models import METADATA
7+
from app.settings import settings
88

99

1010
def get_dsn() -> URL:
11-
settings = ioc.IOCContainer.settings.sync_resolve()
1211
db_dsn = settings.db_dsn
1312
return db_dsn.set(drivername="postgresql")
1413

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ dependencies = [
1414
"advanced-alchemy",
1515
"pydantic-settings",
1616
"granian",
17-
"that-depends",
17+
"modern-di-fastapi",
1818
# database
1919
"alembic",
2020
"psycopg2",
2121
"sqlalchemy",
2222
"asyncpg",
2323
]
2424

25-
[tool.uv]
26-
dev-dependencies = [
25+
[dependency-groups]
26+
dev = [
2727
"polyfactory",
2828
"httpx",
2929
"pytest",
@@ -38,7 +38,7 @@ dev-dependencies = [
3838
fix = true
3939
unsafe-fixes = true
4040
line-length = 120
41-
target-version = "py311"
41+
target-version = "py312"
4242
extend-exclude = ["bin"]
4343

4444
[tool.ruff.lint]

tests/conftest.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import typing
22

3+
import modern_di
4+
import modern_di_fastapi
35
import pytest
46
from httpx import ASGITransport, AsyncClient
57
from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,26 +15,27 @@ async def client() -> typing.AsyncIterator[AsyncClient]:
1315
async with AsyncClient(
1416
transport=ASGITransport(app=application),
1517
base_url="http://test",
16-
timeout=0,
1718
) as client:
1819
yield client
1920

2021

2122
@pytest.fixture(autouse=True)
22-
async def _prepare_ioc_container() -> typing.AsyncIterator[None]:
23-
engine = await ioc.IOCContainer.database_engine()
24-
connection = await engine.connect()
25-
transaction = await connection.begin()
26-
await connection.begin_nested()
27-
session = AsyncSession(connection, expire_on_commit=False, autoflush=False)
28-
ioc.IOCContainer.session.override(session)
29-
30-
try:
31-
yield
32-
finally:
33-
if connection.in_transaction():
34-
await transaction.rollback()
35-
await connection.close()
36-
37-
ioc.IOCContainer.reset_override()
38-
await ioc.IOCContainer.tear_down()
23+
async def di_container() -> modern_di.Container:
24+
return modern_di_fastapi.fetch_di_container(application)
25+
26+
27+
@pytest.fixture(autouse=True)
28+
async def db_session(di_container: modern_di.Container) -> typing.AsyncIterator[AsyncSession]:
29+
async with di_container:
30+
engine = await ioc.IOCContainer.database_engine.async_resolve(di_container)
31+
connection = await engine.connect()
32+
transaction = await connection.begin()
33+
await connection.begin_nested()
34+
ioc.IOCContainer.database_engine.override(connection, di_container)
35+
36+
try:
37+
yield AsyncSession(connection, expire_on_commit=False, autoflush=False)
38+
finally:
39+
if connection.in_transaction():
40+
await transaction.rollback()
41+
await connection.close()

0 commit comments

Comments
 (0)