Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions app/__main__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import granian
from granian.constants import Interfaces, Loops

from app import ioc
from app.settings import settings


if __name__ == "__main__":
settings = ioc.IOCContainer.settings.sync_resolve()
granian.Granian(
target="app.application:application",
address="0.0.0.0", # noqa: S104
Expand Down
25 changes: 10 additions & 15 deletions app/api/decks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,20 @@

import fastapi
from advanced_alchemy.exceptions import NotFoundError
from modern_di_fastapi import FromDI
from sqlalchemy import orm
from starlette import status
from that_depends.providers import container_context

from app import ioc, models, schemas
from app.repositories import CardsService, DecksService


async def init_di_context() -> typing.AsyncIterator[None]:
async with container_context():
yield


ROUTER: typing.Final = fastapi.APIRouter(dependencies=[fastapi.Depends(init_di_context)])
ROUTER: typing.Final = fastapi.APIRouter()


@ROUTER.get("/decks/")
async def list_decks(
decks_service: DecksService = fastapi.Depends(ioc.IOCContainer.decks_service),
decks_service: DecksService = FromDI(ioc.IOCContainer.decks_service),
) -> schemas.Decks:
objects = await decks_service.list()
return typing.cast(schemas.Decks, {"items": objects})
Expand All @@ -29,7 +24,7 @@ async def list_decks(
@ROUTER.get("/decks/{deck_id}/")
async def get_deck(
deck_id: int,
decks_service: DecksService = fastapi.Depends(ioc.IOCContainer.decks_service),
decks_service: DecksService = FromDI(ioc.IOCContainer.decks_service),
) -> schemas.Deck:
instance = await decks_service.get_one_or_none(
models.Deck.id == deck_id,
Expand All @@ -45,7 +40,7 @@ async def get_deck(
async def update_deck(
deck_id: int,
data: schemas.DeckCreate,
decks_service: DecksService = fastapi.Depends(ioc.IOCContainer.decks_service),
decks_service: DecksService = FromDI(ioc.IOCContainer.decks_service),
) -> schemas.Deck:
try:
instance = await decks_service.update(data=data.model_dump(), item_id=deck_id)
Expand All @@ -58,7 +53,7 @@ async def update_deck(
@ROUTER.post("/decks/")
async def create_deck(
data: schemas.DeckCreate,
decks_service: DecksService = fastapi.Depends(ioc.IOCContainer.decks_service),
decks_service: DecksService = FromDI(ioc.IOCContainer.decks_service),
) -> schemas.Deck:
instance = await decks_service.create(data)
return typing.cast(schemas.Deck, instance)
Expand All @@ -67,7 +62,7 @@ async def create_deck(
@ROUTER.get("/decks/{deck_id}/cards/")
async def list_cards(
deck_id: int,
cards_service: CardsService = fastapi.Depends(ioc.IOCContainer.cards_service),
cards_service: CardsService = FromDI(ioc.IOCContainer.cards_service),
) -> schemas.Cards:
objects = await cards_service.list(models.Card.deck_id == deck_id)
return typing.cast(schemas.Cards, {"items": objects})
Expand All @@ -76,7 +71,7 @@ async def list_cards(
@ROUTER.get("/cards/{card_id}/")
async def get_card(
card_id: int,
cards_service: CardsService = fastapi.Depends(ioc.IOCContainer.cards_service),
cards_service: CardsService = FromDI(ioc.IOCContainer.cards_service),
) -> schemas.Card:
instance = await cards_service.get_one_or_none(models.Card.id == card_id)
if not instance:
Expand All @@ -88,7 +83,7 @@ async def get_card(
async def create_cards(
deck_id: int,
data: list[schemas.CardCreate],
cards_service: CardsService = fastapi.Depends(ioc.IOCContainer.cards_service),
cards_service: CardsService = FromDI(ioc.IOCContainer.cards_service),
) -> schemas.Cards:
objects = await cards_service.create_many(
data=[models.Card(**card.model_dump(), deck_id=deck_id) for card in data],
Expand All @@ -100,7 +95,7 @@ async def create_cards(
async def update_cards(
deck_id: int,
data: list[schemas.Card],
cards_service: CardsService = fastapi.Depends(ioc.IOCContainer.cards_service),
cards_service: CardsService = FromDI(ioc.IOCContainer.cards_service),
) -> schemas.Cards:
objects = await cards_service.upsert_many(
data=[models.Card(**card.model_dump(exclude={"deck_id"}), deck_id=deck_id) for card in data],
Expand Down
23 changes: 15 additions & 8 deletions app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import typing

import fastapi
from advanced_alchemy.exceptions import ForeignKeyError
import modern_di
import modern_di_fastapi
from advanced_alchemy.exceptions import DuplicateKeyError, ForeignKeyError

from app import exceptions, ioc
from app.api.decks import ROUTER
from app.settings import settings


def include_routers(app: fastapi.FastAPI) -> None:
Expand All @@ -14,25 +17,29 @@ def include_routers(app: fastapi.FastAPI) -> None:

class AppBuilder:
def __init__(self) -> None:
self.settings = ioc.IOCContainer.settings.sync_resolve()
self.app: fastapi.FastAPI = fastapi.FastAPI(
title=self.settings.service_name,
debug=self.settings.debug,
title=settings.service_name,
debug=settings.debug,
lifespan=self.lifespan_manager,
dependencies=[fastapi.Depends(modern_di_fastapi.enter_di_request_scope)],
)
self.di_container = modern_di.Container(scope=modern_di.Scope.APP)
modern_di_fastapi.save_di_container(self.app, self.di_container)
include_routers(self.app)
self.app.add_exception_handler(
ForeignKeyError,
exceptions.foreign_key_error_handler, # type: ignore[arg-type]
)
self.app.add_exception_handler(
DuplicateKeyError,
exceptions.foreign_key_error_handler, # type: ignore[arg-type]
)

@contextlib.asynccontextmanager
async def lifespan_manager(self, _: fastapi.FastAPI) -> typing.AsyncIterator[dict[str, typing.Any]]:
try:
await ioc.IOCContainer.init_resources()
async with self.di_container:
await ioc.IOCContainer.async_resolve_creators(self.di_container)
yield {}
finally:
await ioc.IOCContainer.tear_down()


application = AppBuilder().app
4 changes: 2 additions & 2 deletions app/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from advanced_alchemy.exceptions import ForeignKeyError
from advanced_alchemy.exceptions import DuplicateKeyError, ForeignKeyError
from fastapi.responses import JSONResponse
from starlette import status
from starlette.requests import Request


async def foreign_key_error_handler(_: Request, exc: ForeignKeyError) -> JSONResponse:
async def foreign_key_error_handler(_: Request, exc: ForeignKeyError | DuplicateKeyError) -> JSONResponse:
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={"detail": exc.detail},
Expand Down
15 changes: 6 additions & 9 deletions app/ioc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from that_depends import BaseContainer, providers
from modern_di import BaseGraph, Scope, providers

from app import repositories
from app.resources.db import create_sa_engine, create_session
from app.settings import Settings


class IOCContainer(BaseContainer):
settings = providers.Singleton(Settings)
class IOCContainer(BaseGraph):
database_engine = providers.Resource(Scope.APP, create_sa_engine)
session = providers.Resource(Scope.REQUEST, create_session, engine=database_engine.cast)

database_engine = providers.Resource(create_sa_engine, settings=settings.cast)
session = providers.ContextResource(create_session, engine=database_engine.cast)

decks_service = providers.Factory(repositories.DecksService, session=session)
cards_service = providers.Factory(repositories.CardsService, session=session)
decks_service = providers.Factory(Scope.REQUEST, repositories.DecksService, session=session.cast)
cards_service = providers.Factory(Scope.REQUEST, repositories.CardsService, session=session.cast)
14 changes: 11 additions & 3 deletions app/resources/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from sqlalchemy.ext import asyncio as sa

from app.settings import Settings
from app.settings import settings


logger = logging.getLogger(__name__)


async def create_sa_engine(settings: Settings) -> typing.AsyncIterator[sa.AsyncEngine]:
async def create_sa_engine() -> typing.AsyncIterator[sa.AsyncEngine]:
logger.debug("Initializing SQLAlchemy engine")
engine = sa.create_async_engine(
url=settings.db_dsn,
Expand All @@ -27,6 +27,14 @@ async def create_sa_engine(settings: Settings) -> typing.AsyncIterator[sa.AsyncE
logger.debug("SQLAlchemy engine has been cleaned up")


class CustomAsyncSession(sa.AsyncSession):
async def close(self) -> None:
if isinstance(self.bind, sa.AsyncConnection):
return self.expunge_all()

return await super().close()


async def create_session(engine: sa.AsyncEngine) -> typing.AsyncIterator[sa.AsyncSession]:
async with sa.AsyncSession(engine, expire_on_commit=False, autoflush=False) as session:
async with CustomAsyncSession(engine, expire_on_commit=False, autoflush=False) as session:
yield session
3 changes: 3 additions & 0 deletions app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ def db_dsn(self) -> URL:
self.db_port,
self.db_database,
)


settings = Settings()
3 changes: 1 addition & 2 deletions migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from alembic import context
from sqlalchemy import URL, create_engine

from app import ioc
from app.models import METADATA
from app.settings import settings


def get_dsn() -> URL:
settings = ioc.IOCContainer.settings.sync_resolve()
db_dsn = settings.db_dsn
return db_dsn.set(drivername="postgresql")

Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ dependencies = [
"advanced-alchemy",
"pydantic-settings",
"granian",
"that-depends",
"modern-di-fastapi",
# database
"alembic",
"psycopg2",
"sqlalchemy",
"asyncpg",
]

[tool.uv]
dev-dependencies = [
[dependency-groups]
dev = [
"polyfactory",
"httpx",
"pytest",
Expand All @@ -38,7 +38,7 @@ dev-dependencies = [
fix = true
unsafe-fixes = true
line-length = 120
target-version = "py311"
target-version = "py312"
extend-exclude = ["bin"]

[tool.ruff.lint]
Expand Down
39 changes: 21 additions & 18 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing

import modern_di
import modern_di_fastapi
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -13,26 +15,27 @@ async def client() -> typing.AsyncIterator[AsyncClient]:
async with AsyncClient(
transport=ASGITransport(app=application),
base_url="http://test",
timeout=0,
) as client:
yield client


@pytest.fixture(autouse=True)
async def _prepare_ioc_container() -> typing.AsyncIterator[None]:
engine = await ioc.IOCContainer.database_engine()
connection = await engine.connect()
transaction = await connection.begin()
await connection.begin_nested()
session = AsyncSession(connection, expire_on_commit=False, autoflush=False)
ioc.IOCContainer.session.override(session)

try:
yield
finally:
if connection.in_transaction():
await transaction.rollback()
await connection.close()

ioc.IOCContainer.reset_override()
await ioc.IOCContainer.tear_down()
async def di_container() -> modern_di.Container:
return modern_di_fastapi.fetch_di_container(application)


@pytest.fixture(autouse=True)
async def db_session(di_container: modern_di.Container) -> typing.AsyncIterator[AsyncSession]:
async with di_container:
engine = await ioc.IOCContainer.database_engine.async_resolve(di_container)
connection = await engine.connect()
transaction = await connection.begin()
await connection.begin_nested()
ioc.IOCContainer.database_engine.override(connection, di_container)

try:
yield AsyncSession(connection, expire_on_commit=False, autoflush=False)
finally:
if connection.in_transaction():
await transaction.rollback()
await connection.close()
Loading
Loading