|
1 | 1 | from dataclasses import dataclass, field |
2 | | -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union |
| 2 | +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast |
3 | 3 |
|
4 | 4 | from sqlspec.exceptions import ImproperConfigurationError |
| 5 | +from sqlspec.extensions.litestar._utils import get_sqlspec_scope_state, set_sqlspec_scope_state |
5 | 6 | from sqlspec.extensions.litestar.handlers import ( |
6 | 7 | autocommit_handler_maker, |
7 | 8 | connection_provider_maker, |
|
13 | 14 |
|
14 | 15 | if TYPE_CHECKING: |
15 | 16 | from collections.abc import AsyncGenerator, Awaitable |
16 | | - from contextlib import AbstractAsyncContextManager |
| 17 | + from contextlib import AbstractAsyncContextManager, AbstractContextManager |
17 | 18 |
|
18 | 19 | from litestar import Litestar |
19 | 20 | from litestar.datastructures.state import State |
20 | 21 | from litestar.types import BeforeMessageSendHookHandler, Scope |
21 | 22 |
|
22 | 23 | from sqlspec.config import AsyncConfigT, DriverT, SyncConfigT |
| 24 | + from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase |
23 | 25 | from sqlspec.typing import ConnectionT, PoolT |
24 | 26 |
|
25 | 27 |
|
|
34 | 36 | "DEFAULT_CONNECTION_KEY", |
35 | 37 | "DEFAULT_POOL_KEY", |
36 | 38 | "DEFAULT_SESSION_KEY", |
| 39 | + "AsyncDatabaseConfig", |
37 | 40 | "CommitMode", |
38 | 41 | "DatabaseConfig", |
| 42 | + "SyncDatabaseConfig", |
39 | 43 | ) |
40 | 44 |
|
41 | 45 |
|
@@ -90,3 +94,183 @@ def __post_init__(self) -> None: |
90 | 94 | self.session_provider = session_provider_maker( |
91 | 95 | config=self.config, connection_dependency_key=self.connection_key |
92 | 96 | ) |
| 97 | + |
| 98 | + def get_request_session( |
| 99 | + self, state: "State", scope: "Scope" |
| 100 | + ) -> "Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]": |
| 101 | + """Get a session instance from the current request. |
| 102 | +
|
| 103 | + This method provides access to the database session that has been added to the request |
| 104 | + scope, similar to Advanced Alchemy's provide_session method. It first looks for an |
| 105 | + existing session in the request scope state, and if not found, creates a new one using |
| 106 | + the connection from the scope. |
| 107 | +
|
| 108 | + Args: |
| 109 | + state: The Litestar application State object. |
| 110 | + scope: The ASGI scope containing the request context. |
| 111 | +
|
| 112 | + Returns: |
| 113 | + A driver session instance. |
| 114 | +
|
| 115 | + Raises: |
| 116 | + ImproperConfigurationError: If no connection is available in the scope. |
| 117 | + """ |
| 118 | + # Create a unique scope key for sessions to avoid conflicts |
| 119 | + session_scope_key = f"{self.session_key}_instance" |
| 120 | + |
| 121 | + # Try to get existing session from scope |
| 122 | + session = get_sqlspec_scope_state(scope, session_scope_key) |
| 123 | + if session is not None: |
| 124 | + return cast("Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]", session) |
| 125 | + |
| 126 | + # Get connection from scope state |
| 127 | + connection = get_sqlspec_scope_state(scope, self.connection_key) |
| 128 | + if connection is None: |
| 129 | + msg = f"No database connection found in scope for key '{self.connection_key}'. " |
| 130 | + msg += "Ensure the connection dependency is properly configured and available." |
| 131 | + raise ImproperConfigurationError(detail=msg) |
| 132 | + |
| 133 | + # Create new session using the connection |
| 134 | + # Access driver_type which is available on all config types |
| 135 | + session = self.config.driver_type(connection=connection) # type: ignore[union-attr] |
| 136 | + |
| 137 | + # Store session in scope for future use |
| 138 | + set_sqlspec_scope_state(scope, session_scope_key, session) |
| 139 | + |
| 140 | + return cast("Union[SyncDriverAdapterBase, AsyncDriverAdapterBase]", session) |
| 141 | + |
| 142 | + def get_request_connection(self, state: "State", scope: "Scope") -> "Any": |
| 143 | + """Get a connection instance from the current request. |
| 144 | +
|
| 145 | + This method provides access to the database connection that has been added to the request |
| 146 | + scope. This is useful in guards, middleware, or other contexts where you need direct |
| 147 | + access to the connection that's been established for the current request. |
| 148 | +
|
| 149 | + Args: |
| 150 | + state: The Litestar application State object. |
| 151 | + scope: The ASGI scope containing the request context. |
| 152 | +
|
| 153 | + Returns: |
| 154 | + A database connection instance. |
| 155 | +
|
| 156 | + Raises: |
| 157 | + ImproperConfigurationError: If no connection is available in the scope. |
| 158 | + """ |
| 159 | + connection = get_sqlspec_scope_state(scope, self.connection_key) |
| 160 | + if connection is None: |
| 161 | + msg = f"No database connection found in scope for key '{self.connection_key}'. " |
| 162 | + msg += "Ensure the connection dependency is properly configured and available." |
| 163 | + raise ImproperConfigurationError(detail=msg) |
| 164 | + |
| 165 | + return cast("Any", connection) |
| 166 | + |
| 167 | + |
| 168 | +# Add passthrough methods to both specialized classes for convenience |
| 169 | +class SyncDatabaseConfig(DatabaseConfig): |
| 170 | + """Sync-specific DatabaseConfig with better typing for get_request_session.""" |
| 171 | + |
| 172 | + def get_request_session(self, state: "State", scope: "Scope") -> "SyncDriverAdapterBase": |
| 173 | + """Get a sync session instance from the current request. |
| 174 | +
|
| 175 | + This method provides access to the database session that has been added to the request |
| 176 | + scope, similar to Advanced Alchemy's provide_session method. It first looks for an |
| 177 | + existing session in the request scope state, and if not found, creates a new one using |
| 178 | + the connection from the scope. |
| 179 | +
|
| 180 | + Args: |
| 181 | + state: The Litestar application State object. |
| 182 | + scope: The ASGI scope containing the request context. |
| 183 | +
|
| 184 | + Returns: |
| 185 | + A sync driver session instance. |
| 186 | + """ |
| 187 | + session = super().get_request_session(state, scope) |
| 188 | + return cast("SyncDriverAdapterBase", session) |
| 189 | + |
| 190 | + def provide_session(self) -> "AbstractContextManager[SyncDriverAdapterBase]": |
| 191 | + """Provide a database session context manager. |
| 192 | +
|
| 193 | + This is a passthrough to the underlying config's provide_session method |
| 194 | + for convenient access to database sessions. |
| 195 | +
|
| 196 | + Returns: |
| 197 | + Context manager that yields a sync driver session. |
| 198 | + """ |
| 199 | + return self.config.provide_session() # type: ignore[union-attr,no-any-return] |
| 200 | + |
| 201 | + def provide_connection(self) -> "AbstractContextManager[Any]": |
| 202 | + """Provide a database connection context manager. |
| 203 | +
|
| 204 | + This is a passthrough to the underlying config's provide_connection method |
| 205 | + for convenient access to database connections. |
| 206 | +
|
| 207 | + Returns: |
| 208 | + Context manager that yields a sync database connection. |
| 209 | + """ |
| 210 | + return self.config.provide_connection() # type: ignore[union-attr,no-any-return] |
| 211 | + |
| 212 | + def create_connection(self) -> "Any": |
| 213 | + """Create and return a new database connection. |
| 214 | +
|
| 215 | + This is a passthrough to the underlying config's create_connection method |
| 216 | + for direct connection creation without context management. |
| 217 | +
|
| 218 | + Returns: |
| 219 | + A new sync database connection. |
| 220 | + """ |
| 221 | + return self.config.create_connection() # type: ignore[union-attr] |
| 222 | + |
| 223 | + |
| 224 | +class AsyncDatabaseConfig(DatabaseConfig): |
| 225 | + """Async-specific DatabaseConfig with better typing for get_request_session.""" |
| 226 | + |
| 227 | + def get_request_session(self, state: "State", scope: "Scope") -> "AsyncDriverAdapterBase": |
| 228 | + """Get an async session instance from the current request. |
| 229 | +
|
| 230 | + This method provides access to the database session that has been added to the request |
| 231 | + scope, similar to Advanced Alchemy's provide_session method. It first looks for an |
| 232 | + existing session in the request scope state, and if not found, creates a new one using |
| 233 | + the connection from the scope. |
| 234 | +
|
| 235 | + Args: |
| 236 | + state: The Litestar application State object. |
| 237 | + scope: The ASGI scope containing the request context. |
| 238 | +
|
| 239 | + Returns: |
| 240 | + An async driver session instance. |
| 241 | + """ |
| 242 | + session = super().get_request_session(state, scope) |
| 243 | + return cast("AsyncDriverAdapterBase", session) |
| 244 | + |
| 245 | + def provide_session(self) -> "AbstractAsyncContextManager[AsyncDriverAdapterBase]": |
| 246 | + """Provide a database session context manager. |
| 247 | +
|
| 248 | + This is a passthrough to the underlying config's provide_session method |
| 249 | + for convenient access to database sessions. |
| 250 | +
|
| 251 | + Returns: |
| 252 | + Context manager that yields an async driver session. |
| 253 | + """ |
| 254 | + return self.config.provide_session() # type: ignore[union-attr,no-any-return] |
| 255 | + |
| 256 | + def provide_connection(self) -> "AbstractAsyncContextManager[Any]": |
| 257 | + """Provide a database connection context manager. |
| 258 | +
|
| 259 | + This is a passthrough to the underlying config's provide_connection method |
| 260 | + for convenient access to database connections. |
| 261 | +
|
| 262 | + Returns: |
| 263 | + Context manager that yields an async database connection. |
| 264 | + """ |
| 265 | + return self.config.provide_connection() # type: ignore[union-attr,no-any-return] |
| 266 | + |
| 267 | + async def create_connection(self) -> "Any": |
| 268 | + """Create and return a new database connection. |
| 269 | +
|
| 270 | + This is a passthrough to the underlying config's create_connection method |
| 271 | + for direct connection creation without context management. |
| 272 | +
|
| 273 | + Returns: |
| 274 | + A new async database connection. |
| 275 | + """ |
| 276 | + return await self.config.create_connection() # type: ignore[union-attr] |
0 commit comments