11"""Configuration classes for SQLSpec FastAPI integration."""
22
3- import contextlib
4- from dataclasses import dataclass , field
5- from typing import TYPE_CHECKING , Callable , Literal , Optional , Union
3+ from typing import TYPE_CHECKING
64
7- from sqlspec .exceptions import ImproperConfigurationError
5+ from sqlspec .extensions .starlette .config import (
6+ DEFAULT_COMMIT_MODE ,
7+ DEFAULT_CONNECTION_KEY ,
8+ DEFAULT_POOL_KEY ,
9+ DEFAULT_SESSION_KEY ,
10+ )
11+ from sqlspec .extensions .starlette .config import AsyncDatabaseConfig as StarletteAsyncConfig
12+ from sqlspec .extensions .starlette .config import DatabaseConfig as StarletteConfig
13+ from sqlspec .extensions .starlette .config import SyncDatabaseConfig as StarletteSyncConfig
814
915if TYPE_CHECKING :
10- from collections .abc import AsyncGenerator , Awaitable
11-
1216 from fastapi import FastAPI
1317
14- from sqlspec .config import AsyncConfigT , DriverT , SyncConfigT
15- from sqlspec .typing import ConnectionT , PoolT
16-
17-
18- CommitMode = Literal ["manual" , "autocommit" , "autocommit_include_redirect" ]
19- DEFAULT_COMMIT_MODE : CommitMode = "manual"
20- DEFAULT_CONNECTION_KEY = "db_connection"
21- DEFAULT_POOL_KEY = "db_pool"
22- DEFAULT_SESSION_KEY = "db_session"
23-
2418__all__ = (
2519 "DEFAULT_COMMIT_MODE" ,
2620 "DEFAULT_CONNECTION_KEY" ,
2721 "DEFAULT_POOL_KEY" ,
2822 "DEFAULT_SESSION_KEY" ,
23+ "AsyncDatabaseConfig" ,
2924 "CommitMode" ,
3025 "DatabaseConfig" ,
26+ "SyncDatabaseConfig" ,
3127)
3228
29+ # Re-export Starlette types with FastAPI-compatible typing
30+ from sqlspec .extensions .starlette .config import CommitMode
3331
34- @dataclass
35- class DatabaseConfig :
36- """Configuration for SQLSpec database integration with FastAPI applications."""
37-
38- config : "Union[SyncConfigT, AsyncConfigT]" = field () # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues]
39- connection_key : str = field (default = DEFAULT_CONNECTION_KEY )
40- pool_key : str = field (default = DEFAULT_POOL_KEY )
41- session_key : str = field (default = DEFAULT_SESSION_KEY )
42- commit_mode : "CommitMode" = field (default = DEFAULT_COMMIT_MODE )
43- extra_commit_statuses : "Optional[set[int]]" = field (default = None )
44- extra_rollback_statuses : "Optional[set[int]]" = field (default = None )
45- enable_middleware : bool = field (default = True )
46-
47- # Generated providers and dependencies
48- connection_provider : "Optional[Callable[[], AsyncGenerator[ConnectionT, None]]]" = field (
49- init = False , repr = False , hash = False , default = None
50- )
51- pool_provider : "Optional[Callable[[], Awaitable[PoolT]]]" = field (init = False , repr = False , hash = False , default = None )
52- session_provider : "Optional[Callable[[ConnectionT], AsyncGenerator[DriverT, None]]]" = field (
53- init = False , repr = False , hash = False , default = None
54- )
55-
56- def __post_init__ (self ) -> None :
57- """Initialize providers after object creation."""
58- if not self .config .supports_connection_pooling and self .pool_key == DEFAULT_POOL_KEY : # type: ignore[union-attr,unused-ignore]
59- self .pool_key = f"_{ self .pool_key } _{ id (self .config )} "
60-
61- # Validate commit mode
62- if self .commit_mode not in {"manual" , "autocommit" , "autocommit_include_redirect" }:
63- msg = f"Invalid commit mode: { self .commit_mode } "
64- raise ImproperConfigurationError (detail = msg )
65-
66- # Validate status code sets
67- if (
68- self .extra_commit_statuses
69- and self .extra_rollback_statuses
70- and self .extra_commit_statuses & self .extra_rollback_statuses
71- ):
72- msg = "Extra rollback statuses and commit statuses must not share any status codes"
73- raise ImproperConfigurationError (msg )
74-
75- def init_app (self , app : "FastAPI" ) -> None :
32+
33+ class DatabaseConfig (StarletteConfig ):
34+ """Configuration for SQLSpec database integration with FastAPI applications.
35+
36+ FastAPI is built on Starlette, so this configuration inherits all functionality
37+ from the Starlette configuration. The only differences are type hints for FastAPI
38+ Request objects and middleware imports.
39+ """
40+
41+ def init_app (self , app : "FastAPI" ) -> None : # pyright: ignore
7642 """Initialize SQLSpec configuration for FastAPI application.
7743
7844 Args:
7945 app: The FastAPI application instance.
8046 """
8147 from sqlspec .extensions .fastapi ._middleware import SessionMiddleware
82- from sqlspec .extensions .fastapi ._providers import (
48+ from sqlspec .extensions .starlette ._providers import (
8349 create_connection_provider ,
8450 create_pool_provider ,
8551 create_session_provider ,
8652 )
8753
88- # Create providers
54+ # Create providers using Starlette providers (FastAPI is compatible)
8955 self .pool_provider = create_pool_provider (self .config , self .pool_key )
9056 self .connection_provider = create_connection_provider (self .config , self .pool_key , self .connection_key )
9157 self .session_provider = create_session_provider (self .config , self .connection_key )
@@ -100,43 +66,30 @@ def init_app(self, app: "FastAPI") -> None:
10066 extra_rollback_statuses = self .extra_rollback_statuses ,
10167 )
10268
103- # Add event handlers
104- app . add_event_handler ( "startup" , self . _startup_handler (app ))
105- app . add_event_handler ( "shutdown" , self . _shutdown_handler ( app ))
69+ # Add event handlers - delegate to parent logic but cast FastAPI to Starlette
70+ super (). init_app (app ) # type: ignore[arg-type]
71+
10672
107- def _startup_handler (self , app : "FastAPI" ) -> "Callable[[], Awaitable[None]]" :
108- """Create startup handler for database pool initialization.
73+ # Add typed subclasses for better developer experience
74+ class SyncDatabaseConfig (StarletteSyncConfig ):
75+ """Sync-specific DatabaseConfig with FastAPI-compatible type hints."""
76+
77+ def init_app (self , app : "FastAPI" ) -> None : # pyright: ignore
78+ """Initialize SQLSpec configuration for FastAPI application.
10979
11080 Args:
11181 app: The FastAPI application instance.
112-
113- Returns:
114- Startup handler function.
11582 """
83+ DatabaseConfig .init_app (self , app ) # pyright: ignore
11684
117- async def startup () -> None :
118- from sqlspec .utils .sync_tools import ensure_async_
11985
120- db_pool = await ensure_async_ ( self . config . create_pool )()
121- app . state . __dict__ [ self . pool_key ] = db_pool
86+ class AsyncDatabaseConfig ( StarletteAsyncConfig ):
87+ """Async-specific DatabaseConfig with FastAPI-compatible type hints."""
12288
123- return startup
124-
125- def _shutdown_handler (self , app : "FastAPI" ) -> "Callable[[], Awaitable[None]]" :
126- """Create shutdown handler for database pool cleanup.
89+ def init_app (self , app : "FastAPI" ) -> None : # pyright: ignore
90+ """Initialize SQLSpec configuration for FastAPI application.
12791
12892 Args:
12993 app: The FastAPI application instance.
130-
131- Returns:
132- Shutdown handler function.
13394 """
134-
135- async def shutdown () -> None :
136- from sqlspec .utils .sync_tools import ensure_async_
137-
138- app .state .__dict__ .pop (self .pool_key , None )
139- with contextlib .suppress (Exception ):
140- await ensure_async_ (self .config .close_pool )()
141-
142- return shutdown
95+ DatabaseConfig .init_app (self , app ) # pyright: ignore
0 commit comments