diff --git a/advanced_alchemy/extensions/litestar/plugins/__init__.py b/advanced_alchemy/extensions/litestar/plugins/__init__.py index 93ade430..d8e7bcd7 100644 --- a/advanced_alchemy/extensions/litestar/plugins/__init__.py +++ b/advanced_alchemy/extensions/litestar/plugins/__init__.py @@ -1,8 +1,12 @@ -from collections.abc import Sequence -from typing import Union +from collections.abc import AsyncGenerator, Generator, Sequence +from contextlib import asynccontextmanager, contextmanager +from typing import Any, Callable, Optional, Union, cast from litestar.config.app import AppConfig from litestar.plugins import InitPluginProtocol +from sqlalchemy import Engine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from sqlalchemy.orm import Session from advanced_alchemy.extensions.litestar.plugins import _slots_base from advanced_alchemy.extensions.litestar.plugins.init import ( @@ -48,6 +52,92 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig: app_config.plugins.extend([SQLAlchemyInitPlugin(config=self._config), SQLAlchemySerializationPlugin()]) return app_config + def _get_config(self, key: Optional[str] = None) -> Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]: + """Get a configuration by key. + + Args: + key: Optional key to identify the configuration. If not provided, uses the first config. + + Raises: + ValueError: If no configuration is found. + + Returns: + The SQLAlchemy configuration. + """ + if key is None: + return self._config[0] + for config in self._config: + if getattr(config, "key", None) == key: + return config + msg = f"No configuration found with key {key}" + raise ValueError(msg) + + def get_session( + self, + key: Optional[str] = None, + ) -> Union[AsyncGenerator[AsyncSession, None], Generator[Session, None, None]]: + """Get a SQLAlchemy session. + + Args: + key: Optional key to identify the configuration. If not provided, uses the first config. + + Returns: + A SQLAlchemy session. + """ + config = self._get_config(key) + + if isinstance(config, SQLAlchemyAsyncConfig): + + @asynccontextmanager + async def async_gen() -> AsyncGenerator[AsyncSession, None]: + async with config.get_session() as session: + yield session + + return cast("AsyncGenerator[AsyncSession, None]", async_gen()) + + @contextmanager + def sync_gen() -> Generator[Session, None, None]: + with config.get_session() as session: + yield session + + return cast("Generator[Session, None, None]", sync_gen()) + + def provide_session( + self, + key: Optional[str] = None, + ) -> Callable[..., Union[AsyncGenerator[AsyncSession, None], Generator[Session, None, None]]]: + """Get a session provider for dependency injection. + + Args: + key: Optional key to identify the configuration. If not provided, uses the first config. + + Returns: + A callable that returns a session provider. + """ + + def provider( + *args: Any, # noqa: ARG001 + **kwargs: Any, # noqa: ARG001 + ) -> Union[AsyncGenerator[AsyncSession, None], Generator[Session, None, None]]: + return self.get_session(key) + + return provider + + def get_engine( + self, + key: Optional[str] = None, + ) -> Union[AsyncEngine, Engine]: + """Get the SQLAlchemy engine. + + Args: + key: Optional key to identify the configuration. If not provided, uses the first config. + + Returns: + The SQLAlchemy engine. + """ + config = self._get_config(key) + return config.get_engine() + __all__ = ( "EngineConfig", diff --git a/advanced_alchemy/extensions/litestar/plugins/init/config/sync.py b/advanced_alchemy/extensions/litestar/plugins/init/config/sync.py index 0ee0f0f5..d493fbbc 100644 --- a/advanced_alchemy/extensions/litestar/plugins/init/config/sync.py +++ b/advanced_alchemy/extensions/litestar/plugins/init/config/sync.py @@ -54,9 +54,6 @@ def handler(message: "Message", scope: "Scope") -> None: Args: message: ASGI-``Message`` scope: An ASGI-``Scope`` - - Returns: - None """ session = cast("Optional[Session]", get_aa_scope_state(scope, session_scope_key)) if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS: diff --git a/examples/litestar/litestar_service.py b/examples/litestar/litestar_service.py index 7e249be0..e69d808e 100644 --- a/examples/litestar/litestar_service.py +++ b/examples/litestar/litestar_service.py @@ -101,10 +101,13 @@ async def create_author(self, authors_service: AuthorService, data: AuthorCreate async def get_author( self, authors_service: AuthorService, - author_id: UUID = Parameter( - title="Author ID", - description="The author to retrieve.", - ), + author_id: Annotated[ + UUID, + Parameter( + title="Author ID", + description="The author to retrieve.", + ), + ], ) -> Author: """Get an existing author.""" obj = await authors_service.get(author_id) @@ -115,10 +118,13 @@ async def update_author( self, authors_service: AuthorService, data: AuthorUpdate, - author_id: UUID = Parameter( - title="Author ID", - description="The author to update.", - ), + author_id: Annotated[ + UUID, + Parameter( + title="Author ID", + description="The author to update.", + ), + ], ) -> Author: """Update an author.""" obj = await authors_service.update(data, item_id=author_id, auto_commit=True) @@ -128,10 +134,13 @@ async def update_author( async def delete_author( self, authors_service: AuthorService, - author_id: UUID = Parameter( - title="Author ID", - description="The author to delete.", - ), + author_id: Annotated[ + UUID, + Parameter( + title="Author ID", + description="The author to delete.", + ), + ], ) -> None: """Delete a author from the system.""" _ = await authors_service.delete(author_id)