Skip to content

Commit 957addc

Browse files
authored
feat(litestar): add methods to get active request session (#68)
Introduce methods to access the active database session and connection from the current request context, enhancing the integration with Litestar's request handling. This improves session management and provides better typing for synchronous and asynchronous database configurations.
1 parent abca00d commit 957addc

File tree

5 files changed

+907
-5
lines changed

5 files changed

+907
-5
lines changed

sqlspec/extensions/litestar/config.py

Lines changed: 186 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from dataclasses import dataclass, field
2-
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
2+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
33

44
from sqlspec.exceptions import ImproperConfigurationError
5+
from sqlspec.extensions.litestar._utils import get_sqlspec_scope_state, set_sqlspec_scope_state
56
from sqlspec.extensions.litestar.handlers import (
67
autocommit_handler_maker,
78
connection_provider_maker,
@@ -13,13 +14,14 @@
1314

1415
if TYPE_CHECKING:
1516
from collections.abc import AsyncGenerator, Awaitable
16-
from contextlib import AbstractAsyncContextManager
17+
from contextlib import AbstractAsyncContextManager, AbstractContextManager
1718

1819
from litestar import Litestar
1920
from litestar.datastructures.state import State
2021
from litestar.types import BeforeMessageSendHookHandler, Scope
2122

2223
from sqlspec.config import AsyncConfigT, DriverT, SyncConfigT
24+
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
2325
from sqlspec.typing import ConnectionT, PoolT
2426

2527

@@ -34,8 +36,10 @@
3436
"DEFAULT_CONNECTION_KEY",
3537
"DEFAULT_POOL_KEY",
3638
"DEFAULT_SESSION_KEY",
39+
"AsyncDatabaseConfig",
3740
"CommitMode",
3841
"DatabaseConfig",
42+
"SyncDatabaseConfig",
3943
)
4044

4145

@@ -90,3 +94,183 @@ def __post_init__(self) -> None:
9094
self.session_provider = session_provider_maker(
9195
config=self.config, connection_dependency_key=self.connection_key
9296
)
97+
98+
def get_request_session(
99+
self, state: "State", scope: "Scope"
100+
) -> "Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]":
101+
"""Get a session instance from the current request.
102+
103+
This method provides access to the database session that has been added to the request
104+
scope, similar to Advanced Alchemy's provide_session method. It first looks for an
105+
existing session in the request scope state, and if not found, creates a new one using
106+
the connection from the scope.
107+
108+
Args:
109+
state: The Litestar application State object.
110+
scope: The ASGI scope containing the request context.
111+
112+
Returns:
113+
A driver session instance.
114+
115+
Raises:
116+
ImproperConfigurationError: If no connection is available in the scope.
117+
"""
118+
# Create a unique scope key for sessions to avoid conflicts
119+
session_scope_key = f"{self.session_key}_instance"
120+
121+
# Try to get existing session from scope
122+
session = get_sqlspec_scope_state(scope, session_scope_key)
123+
if session is not None:
124+
return cast("Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]", session)
125+
126+
# Get connection from scope state
127+
connection = get_sqlspec_scope_state(scope, self.connection_key)
128+
if connection is None:
129+
msg = f"No database connection found in scope for key '{self.connection_key}'. "
130+
msg += "Ensure the connection dependency is properly configured and available."
131+
raise ImproperConfigurationError(detail=msg)
132+
133+
# Create new session using the connection
134+
# Access driver_type which is available on all config types
135+
session = self.config.driver_type(connection=connection) # type: ignore[union-attr]
136+
137+
# Store session in scope for future use
138+
set_sqlspec_scope_state(scope, session_scope_key, session)
139+
140+
return cast("Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]", session)
141+
142+
def get_request_connection(self, state: "State", scope: "Scope") -> "Any":
143+
"""Get a connection instance from the current request.
144+
145+
This method provides access to the database connection that has been added to the request
146+
scope. This is useful in guards, middleware, or other contexts where you need direct
147+
access to the connection that's been established for the current request.
148+
149+
Args:
150+
state: The Litestar application State object.
151+
scope: The ASGI scope containing the request context.
152+
153+
Returns:
154+
A database connection instance.
155+
156+
Raises:
157+
ImproperConfigurationError: If no connection is available in the scope.
158+
"""
159+
connection = get_sqlspec_scope_state(scope, self.connection_key)
160+
if connection is None:
161+
msg = f"No database connection found in scope for key '{self.connection_key}'. "
162+
msg += "Ensure the connection dependency is properly configured and available."
163+
raise ImproperConfigurationError(detail=msg)
164+
165+
return cast("Any", connection)
166+
167+
168+
# Add passthrough methods to both specialized classes for convenience
169+
class SyncDatabaseConfig(DatabaseConfig):
170+
"""Sync-specific DatabaseConfig with better typing for get_request_session."""
171+
172+
def get_request_session(self, state: "State", scope: "Scope") -> "SyncDriverAdapterBase":
173+
"""Get a sync session instance from the current request.
174+
175+
This method provides access to the database session that has been added to the request
176+
scope, similar to Advanced Alchemy's provide_session method. It first looks for an
177+
existing session in the request scope state, and if not found, creates a new one using
178+
the connection from the scope.
179+
180+
Args:
181+
state: The Litestar application State object.
182+
scope: The ASGI scope containing the request context.
183+
184+
Returns:
185+
A sync driver session instance.
186+
"""
187+
session = super().get_request_session(state, scope)
188+
return cast("SyncDriverAdapterBase", session)
189+
190+
def provide_session(self) -> "AbstractContextManager[SyncDriverAdapterBase]":
191+
"""Provide a database session context manager.
192+
193+
This is a passthrough to the underlying config's provide_session method
194+
for convenient access to database sessions.
195+
196+
Returns:
197+
Context manager that yields a sync driver session.
198+
"""
199+
return self.config.provide_session() # type: ignore[union-attr,no-any-return]
200+
201+
def provide_connection(self) -> "AbstractContextManager[Any]":
202+
"""Provide a database connection context manager.
203+
204+
This is a passthrough to the underlying config's provide_connection method
205+
for convenient access to database connections.
206+
207+
Returns:
208+
Context manager that yields a sync database connection.
209+
"""
210+
return self.config.provide_connection() # type: ignore[union-attr,no-any-return]
211+
212+
def create_connection(self) -> "Any":
213+
"""Create and return a new database connection.
214+
215+
This is a passthrough to the underlying config's create_connection method
216+
for direct connection creation without context management.
217+
218+
Returns:
219+
A new sync database connection.
220+
"""
221+
return self.config.create_connection() # type: ignore[union-attr]
222+
223+
224+
class AsyncDatabaseConfig(DatabaseConfig):
225+
"""Async-specific DatabaseConfig with better typing for get_request_session."""
226+
227+
def get_request_session(self, state: "State", scope: "Scope") -> "AsyncDriverAdapterBase":
228+
"""Get an async session instance from the current request.
229+
230+
This method provides access to the database session that has been added to the request
231+
scope, similar to Advanced Alchemy's provide_session method. It first looks for an
232+
existing session in the request scope state, and if not found, creates a new one using
233+
the connection from the scope.
234+
235+
Args:
236+
state: The Litestar application State object.
237+
scope: The ASGI scope containing the request context.
238+
239+
Returns:
240+
An async driver session instance.
241+
"""
242+
session = super().get_request_session(state, scope)
243+
return cast("AsyncDriverAdapterBase", session)
244+
245+
def provide_session(self) -> "AbstractAsyncContextManager[AsyncDriverAdapterBase]":
246+
"""Provide a database session context manager.
247+
248+
This is a passthrough to the underlying config's provide_session method
249+
for convenient access to database sessions.
250+
251+
Returns:
252+
Context manager that yields an async driver session.
253+
"""
254+
return self.config.provide_session() # type: ignore[union-attr,no-any-return]
255+
256+
def provide_connection(self) -> "AbstractAsyncContextManager[Any]":
257+
"""Provide a database connection context manager.
258+
259+
This is a passthrough to the underlying config's provide_connection method
260+
for convenient access to database connections.
261+
262+
Returns:
263+
Context manager that yields an async database connection.
264+
"""
265+
return self.config.provide_connection() # type: ignore[union-attr,no-any-return]
266+
267+
async def create_connection(self) -> "Any":
268+
"""Create and return a new database connection.
269+
270+
This is a passthrough to the underlying config's create_connection method
271+
for direct connection creation without context management.
272+
273+
Returns:
274+
A new async database connection.
275+
"""
276+
return await self.config.create_connection() # type: ignore[union-attr]

0 commit comments

Comments
 (0)