Skip to content

Commit fe9268a

Browse files
committed
feat: fastapi, starlette, flask, and sanic extensions
1 parent 1d51c0e commit fe9268a

File tree

35 files changed

+4954
-6
lines changed

35 files changed

+4954
-6
lines changed

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ psycopg = ["psycopg[binary,pool]"]
5151
pydantic = ["pydantic", "pydantic-extra-types"]
5252
pymssql = ["pymssql"]
5353
pymysql = ["pymysql"]
54+
sanic = ["sanic", "sanic[ext]>=24.6.0"]
5455
spanner = ["google-cloud-spanner"]
5556
uuid = ["uuid-utils"]
5657

@@ -90,6 +91,10 @@ extras = [
9091
"adbc_driver_postgresql",
9192
"adbc_driver_flightsql",
9293
"adbc_driver_bigquery",
94+
"sanic-testing",
95+
"dishka ; python_version >= \"3.10\"",
96+
"pydantic-extra-types",
97+
"fsspec[s3]",
9398
]
9499
lint = [
95100
"mypy>=1.13.0",
@@ -465,6 +470,7 @@ split-on-trailing-comma = false
465470
"docs/**/*.*" = ["S", "B", "DTZ", "A", "TC", "ERA", "D", "RET", "PLW0127"]
466471
"docs/examples/**" = ["T201"]
467472
"sqlspec/builder/mixins/**/*.*" = ["SLF001"]
473+
"sqlspec/extensions/fastapi/providers.py" = ["B008"]
468474
"tests/**/*.*" = [
469475
"A",
470476
"ARG",
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from sqlspec.extensions.fastapi.config import DatabaseConfig
2+
from sqlspec.extensions.fastapi.extension import SQLSpec
3+
4+
__all__ = ("DatabaseConfig", "SQLSpec")
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Middleware for SQLSpec FastAPI integration."""
2+
3+
import contextlib
4+
from typing import TYPE_CHECKING, Any, Optional
5+
6+
from starlette.middleware.base import BaseHTTPMiddleware
7+
8+
from sqlspec.utils.sync_tools import ensure_async_
9+
10+
if TYPE_CHECKING:
11+
from collections.abc import Awaitable, Callable
12+
13+
from starlette.requests import Request
14+
from starlette.responses import Response
15+
16+
from sqlspec.extensions.fastapi.config import CommitMode, DatabaseConfig
17+
18+
19+
__all__ = ("SessionMiddleware",)
20+
21+
22+
class SessionMiddleware(BaseHTTPMiddleware):
23+
"""Middleware for managing database sessions and transactions in FastAPI."""
24+
25+
def __init__(
26+
self,
27+
app: Any,
28+
config: "DatabaseConfig",
29+
commit_mode: "CommitMode" = "manual",
30+
extra_commit_statuses: "Optional[set[int]]" = None,
31+
extra_rollback_statuses: "Optional[set[int]]" = None,
32+
) -> None:
33+
"""Initialize session middleware.
34+
35+
Args:
36+
app: The ASGI application.
37+
config: Database configuration instance.
38+
commit_mode: Transaction commit behavior.
39+
extra_commit_statuses: Additional status codes that trigger commits.
40+
extra_rollback_statuses: Additional status codes that trigger rollbacks.
41+
"""
42+
super().__init__(app)
43+
self.config = config
44+
self.commit_mode = commit_mode
45+
self.extra_commit_statuses = extra_commit_statuses or set()
46+
self.extra_rollback_statuses = extra_rollback_statuses or set()
47+
48+
async def dispatch(self, request: "Request", call_next: "Callable[[Request], Awaitable[Response]]") -> "Response":
49+
"""Handle request with session management.
50+
51+
Args:
52+
request: The incoming request.
53+
call_next: The next middleware or endpoint.
54+
55+
Returns:
56+
The response from the application.
57+
"""
58+
if not self.config.connection_provider:
59+
return await call_next(request)
60+
61+
# Get connection from provider
62+
connection_gen = self.config.connection_provider()
63+
connection = await connection_gen.__anext__()
64+
65+
# Store connection in request state
66+
request.state.__dict__[self.config.connection_key] = connection
67+
68+
try:
69+
response = await call_next(request)
70+
71+
# Handle transaction based on commit mode and response status
72+
if self.commit_mode != "manual":
73+
await self._handle_transaction(connection, response.status_code)
74+
75+
except Exception:
76+
# Rollback on exception
77+
if hasattr(connection, "rollback") and callable(connection.rollback):
78+
await ensure_async_(connection.rollback)()
79+
raise
80+
else:
81+
return response
82+
finally:
83+
# Clean up connection
84+
with contextlib.suppress(StopAsyncIteration):
85+
await connection_gen.__anext__()
86+
if hasattr(connection, "close") and callable(connection.close):
87+
await ensure_async_(connection.close)()
88+
89+
async def _handle_transaction(self, connection: Any, status_code: int) -> None:
90+
"""Handle transaction commit/rollback based on status code.
91+
92+
Args:
93+
connection: The database connection.
94+
status_code: HTTP response status code.
95+
"""
96+
http_ok = 200
97+
http_multiple_choices = 300
98+
http_bad_request = 400
99+
100+
should_commit = False
101+
102+
if self.commit_mode == "autocommit":
103+
should_commit = http_ok <= status_code < http_multiple_choices
104+
elif self.commit_mode == "autocommit_include_redirect":
105+
should_commit = http_ok <= status_code < http_bad_request
106+
107+
# Apply extra status overrides
108+
if status_code in self.extra_commit_statuses:
109+
should_commit = True
110+
elif status_code in self.extra_rollback_statuses:
111+
should_commit = False
112+
113+
# Execute transaction action
114+
if should_commit and hasattr(connection, "commit") and callable(connection.commit):
115+
await ensure_async_(connection.commit)()
116+
elif not should_commit and hasattr(connection, "rollback") and callable(connection.rollback):
117+
await ensure_async_(connection.rollback)()
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""Provider functions for SQLSpec FastAPI integration."""
2+
3+
import contextlib
4+
from typing import TYPE_CHECKING, Any, cast
5+
6+
from sqlspec.utils.sync_tools import ensure_async_
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import AsyncGenerator, Awaitable, Callable
10+
11+
from sqlspec.config import DatabaseConfigProtocol, DriverT
12+
from sqlspec.typing import ConnectionT, PoolT
13+
14+
15+
__all__ = ("create_connection_provider", "create_pool_provider", "create_session_provider")
16+
17+
18+
def create_pool_provider(
19+
config: "DatabaseConfigProtocol[Any, Any, Any]", pool_key: str
20+
) -> "Callable[[], Awaitable[PoolT]]":
21+
"""Create provider for database pool access.
22+
23+
Args:
24+
config: The database configuration object.
25+
pool_key: The key used to store the connection pool.
26+
27+
Returns:
28+
The pool provider function.
29+
"""
30+
31+
async def provide_pool() -> "PoolT":
32+
"""Provide the database pool.
33+
34+
Returns:
35+
The database connection pool.
36+
"""
37+
db_pool = await ensure_async_(config.create_pool)()
38+
return cast("PoolT", db_pool)
39+
40+
return provide_pool
41+
42+
43+
def create_connection_provider(
44+
config: "DatabaseConfigProtocol[Any, Any, Any]", pool_key: str, connection_key: str
45+
) -> "Callable[[], AsyncGenerator[ConnectionT, None]]":
46+
"""Create provider for database connections.
47+
48+
Args:
49+
config: The database configuration object.
50+
pool_key: The key used to store the connection pool.
51+
connection_key: The key used to store the connection.
52+
53+
Returns:
54+
The connection provider function.
55+
"""
56+
57+
async def provide_connection() -> "AsyncGenerator[ConnectionT, None]":
58+
"""Provide a database connection.
59+
60+
Yields:
61+
Database connection instance.
62+
"""
63+
db_pool = await ensure_async_(config.create_pool)()
64+
65+
try:
66+
connection_cm = config.provide_connection(db_pool)
67+
68+
# Handle both context managers and direct connections
69+
if hasattr(connection_cm, "__aenter__"):
70+
async with connection_cm as conn:
71+
yield cast("ConnectionT", conn)
72+
else:
73+
conn = await connection_cm if hasattr(connection_cm, "__await__") else connection_cm
74+
yield cast("ConnectionT", conn)
75+
finally:
76+
with contextlib.suppress(Exception):
77+
await ensure_async_(config.close_pool)()
78+
79+
return provide_connection
80+
81+
82+
def create_session_provider(
83+
config: "DatabaseConfigProtocol[Any, Any, Any]", connection_key: str
84+
) -> "Callable[[ConnectionT], AsyncGenerator[DriverT, None]]":
85+
"""Create provider for database sessions/drivers.
86+
87+
Args:
88+
config: The database configuration object.
89+
connection_key: The key used to access the connection.
90+
91+
Returns:
92+
The session provider function.
93+
"""
94+
95+
async def provide_session(connection: "ConnectionT") -> "AsyncGenerator[DriverT, None]":
96+
"""Provide a database session/driver.
97+
98+
Args:
99+
connection: The database connection.
100+
101+
Yields:
102+
Database driver/session instance.
103+
"""
104+
yield cast("DriverT", config.driver_type(connection=connection))
105+
106+
return provide_session

sqlspec/extensions/fastapi/cli.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""FastAPI CLI integration for SQLSpec migrations."""
2+
3+
from contextlib import suppress
4+
from typing import TYPE_CHECKING, cast
5+
6+
from sqlspec.cli import add_migration_commands
7+
8+
try:
9+
import rich_click as click
10+
except ImportError:
11+
import click # type: ignore[no-redef]
12+
13+
if TYPE_CHECKING:
14+
from fastapi import FastAPI
15+
16+
from sqlspec.extensions.fastapi.extension import SQLSpec
17+
18+
__all__ = ("get_database_migration_plugin", "register_database_commands")
19+
20+
21+
def get_database_migration_plugin(app: "FastAPI") -> "SQLSpec":
22+
"""Retrieve the SQLSpec plugin from the FastAPI application.
23+
24+
Args:
25+
app: The FastAPI application
26+
27+
Returns:
28+
The SQLSpec plugin
29+
30+
Raises:
31+
ImproperConfigurationError: If the SQLSpec plugin is not found
32+
"""
33+
from sqlspec.exceptions import ImproperConfigurationError
34+
35+
# FastAPI doesn't have a built-in plugin system like Litestar
36+
# Check if SQLSpec was stored in app.state
37+
with suppress(AttributeError):
38+
if hasattr(app.state, "sqlspec"):
39+
return cast("SQLSpec", app.state.sqlspec)
40+
41+
msg = "Failed to initialize database migrations. The required SQLSpec plugin is missing."
42+
raise ImproperConfigurationError(msg)
43+
44+
45+
def register_database_commands(app: "FastAPI") -> click.Group:
46+
"""Register database commands with a FastAPI application.
47+
48+
Args:
49+
app: The FastAPI application instance
50+
51+
Returns:
52+
Click group with database commands
53+
"""
54+
55+
@click.group(name="db")
56+
def database_group() -> None:
57+
"""Manage SQLSpec database components."""
58+
59+
# Add migration commands to the group
60+
add_migration_commands(database_group)
61+
62+
return database_group

0 commit comments

Comments
 (0)