Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ psycopg = ["psycopg[binary,pool]"]
pydantic = ["pydantic", "pydantic-extra-types"]
pymssql = ["pymssql"]
pymysql = ["pymysql"]
sanic = ["sanic", "sanic[ext]>=24.6.0"]
spanner = ["google-cloud-spanner"]
uuid = ["uuid-utils"]

Expand Down Expand Up @@ -90,6 +91,10 @@ extras = [
"adbc_driver_postgresql",
"adbc_driver_flightsql",
"adbc_driver_bigquery",
"sanic-testing",
"dishka ; python_version >= \"3.10\"",
"pydantic-extra-types",
"fsspec[s3]",
]
lint = [
"mypy>=1.13.0",
Expand Down Expand Up @@ -466,6 +471,7 @@ split-on-trailing-comma = false
"docs/**/*.*" = ["S", "B", "DTZ", "A", "TC", "ERA", "D", "RET", "PLW0127"]
"docs/examples/**" = ["T201"]
"sqlspec/builder/mixins/**/*.*" = ["SLF001"]
"sqlspec/extensions/fastapi/providers.py" = ["B008"]
"tests/**/*.*" = [
"A",
"ARG",
Expand Down
13 changes: 13 additions & 0 deletions sqlspec/extensions/fastapi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from sqlspec.extensions.fastapi._middleware import SessionMiddleware
from sqlspec.extensions.fastapi.config import DatabaseConfig
from sqlspec.extensions.fastapi.extension import SQLSpec
from sqlspec.extensions.fastapi.providers import FilterConfig, create_filter_dependencies, provide_filters

__all__ = (
"DatabaseConfig",
"FilterConfig",
"SQLSpec",
"SessionMiddleware",
"create_filter_dependencies",
"provide_filters",
)
117 changes: 117 additions & 0 deletions sqlspec/extensions/fastapi/_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Middleware for SQLSpec FastAPI integration."""

import contextlib
from typing import TYPE_CHECKING, Any, Optional

from starlette.middleware.base import BaseHTTPMiddleware

from sqlspec.utils.sync_tools import ensure_async_

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable

from starlette.requests import Request
from starlette.responses import Response

from sqlspec.extensions.fastapi.config import CommitMode, DatabaseConfig


__all__ = ("SessionMiddleware",)


class SessionMiddleware(BaseHTTPMiddleware):
"""Middleware for managing database sessions and transactions in FastAPI."""

def __init__(
self,
app: Any,
config: "DatabaseConfig",
commit_mode: "CommitMode" = "manual",
extra_commit_statuses: "Optional[set[int]]" = None,
extra_rollback_statuses: "Optional[set[int]]" = None,
) -> None:
"""Initialize session middleware.

Args:
app: The ASGI application.
config: Database configuration instance.
commit_mode: Transaction commit behavior.
extra_commit_statuses: Additional status codes that trigger commits.
extra_rollback_statuses: Additional status codes that trigger rollbacks.
"""
super().__init__(app)
self.config = config
self.commit_mode = commit_mode
self.extra_commit_statuses = extra_commit_statuses or set()
self.extra_rollback_statuses = extra_rollback_statuses or set()

async def dispatch(self, request: "Request", call_next: "Callable[[Request], Awaitable[Response]]") -> "Response":
"""Handle request with session management.

Args:
request: The incoming request.
call_next: The next middleware or endpoint.

Returns:
The response from the application.
"""
if not self.config.connection_provider:
return await call_next(request)

# Get connection from provider
connection_gen = self.config.connection_provider()
connection = await connection_gen.__anext__()

# Store connection in request state
request.state.__dict__[self.config.connection_key] = connection

try:
response = await call_next(request)

# Handle transaction based on commit mode and response status
if self.commit_mode != "manual":
await self._handle_transaction(connection, response.status_code)

except Exception:
# Rollback on exception
if hasattr(connection, "rollback") and callable(connection.rollback):
await ensure_async_(connection.rollback)()
raise
else:
return response
finally:
# Clean up connection
with contextlib.suppress(StopAsyncIteration):
await connection_gen.__anext__()
if hasattr(connection, "close") and callable(connection.close):
await ensure_async_(connection.close)()

async def _handle_transaction(self, connection: Any, status_code: int) -> None:
"""Handle transaction commit/rollback based on status code.

Args:
connection: The database connection.
status_code: HTTP response status code.
"""
http_ok = 200
http_multiple_choices = 300
http_bad_request = 400

should_commit = False

if self.commit_mode == "autocommit":
should_commit = http_ok <= status_code < http_multiple_choices
elif self.commit_mode == "autocommit_include_redirect":
should_commit = http_ok <= status_code < http_bad_request

# Apply extra status overrides
if status_code in self.extra_commit_statuses:
should_commit = True
elif status_code in self.extra_rollback_statuses:
should_commit = False

# Execute transaction action
if should_commit and hasattr(connection, "commit") and callable(connection.commit):
await ensure_async_(connection.commit)()
elif not should_commit and hasattr(connection, "rollback") and callable(connection.rollback):
await ensure_async_(connection.rollback)()
106 changes: 106 additions & 0 deletions sqlspec/extensions/fastapi/_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Provider functions for SQLSpec FastAPI integration."""

import contextlib
from typing import TYPE_CHECKING, Any, cast

from sqlspec.utils.sync_tools import ensure_async_

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Awaitable, Callable

from sqlspec.config import DatabaseConfigProtocol, DriverT
from sqlspec.typing import ConnectionT, PoolT


__all__ = ("create_connection_provider", "create_pool_provider", "create_session_provider")


def create_pool_provider(
config: "DatabaseConfigProtocol[Any, Any, Any]", pool_key: str
) -> "Callable[[], Awaitable[PoolT]]":
"""Create provider for database pool access.

Args:
config: The database configuration object.
pool_key: The key used to store the connection pool.

Returns:
The pool provider function.
"""

async def provide_pool() -> "PoolT":
"""Provide the database pool.

Returns:
The database connection pool.
"""
db_pool = await ensure_async_(config.create_pool)()
return cast("PoolT", db_pool)

return provide_pool


def create_connection_provider(
config: "DatabaseConfigProtocol[Any, Any, Any]", pool_key: str, connection_key: str
) -> "Callable[[], AsyncGenerator[ConnectionT, None]]":
"""Create provider for database connections.

Args:
config: The database configuration object.
pool_key: The key used to store the connection pool.
connection_key: The key used to store the connection.

Returns:
The connection provider function.
"""

async def provide_connection() -> "AsyncGenerator[ConnectionT, None]":
"""Provide a database connection.

Yields:
Database connection instance.
"""
db_pool = await ensure_async_(config.create_pool)()

try:
connection_cm = config.provide_connection(db_pool)

# Handle both context managers and direct connections
if hasattr(connection_cm, "__aenter__"):
async with connection_cm as conn:
yield cast("ConnectionT", conn)
else:
conn = await connection_cm if hasattr(connection_cm, "__await__") else connection_cm
yield cast("ConnectionT", conn)
finally:
with contextlib.suppress(Exception):
await ensure_async_(config.close_pool)()

return provide_connection


def create_session_provider(
config: "DatabaseConfigProtocol[Any, Any, Any]", connection_key: str
) -> "Callable[[ConnectionT], AsyncGenerator[DriverT, None]]":
"""Create provider for database sessions/drivers.

Args:
config: The database configuration object.
connection_key: The key used to access the connection.

Returns:
The session provider function.
"""

async def provide_session(connection: "ConnectionT") -> "AsyncGenerator[DriverT, None]":
"""Provide a database session/driver.

Args:
connection: The database connection.

Yields:
Database driver/session instance.
"""
yield cast("DriverT", config.driver_type(connection=connection))

return provide_session
62 changes: 62 additions & 0 deletions sqlspec/extensions/fastapi/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""FastAPI CLI integration for SQLSpec migrations."""

from contextlib import suppress
from typing import TYPE_CHECKING, cast

from sqlspec.cli import add_migration_commands

try:
import rich_click as click
except ImportError:
import click # type: ignore[no-redef]

if TYPE_CHECKING:
from fastapi import FastAPI

from sqlspec.extensions.fastapi.extension import SQLSpec

__all__ = ("get_database_migration_plugin", "register_database_commands")


def get_database_migration_plugin(app: "FastAPI") -> "SQLSpec":
"""Retrieve the SQLSpec plugin from the FastAPI application.

Args:
app: The FastAPI application

Returns:
The SQLSpec plugin

Raises:
ImproperConfigurationError: If the SQLSpec plugin is not found
"""
from sqlspec.exceptions import ImproperConfigurationError

# FastAPI doesn't have a built-in plugin system like Litestar
# Check if SQLSpec was stored in app.state
with suppress(AttributeError):
if hasattr(app.state, "sqlspec"):
return cast("SQLSpec", app.state.sqlspec)

msg = "Failed to initialize database migrations. The required SQLSpec plugin is missing."
raise ImproperConfigurationError(msg)


def register_database_commands(app: "FastAPI") -> click.Group:
"""Register database commands with a FastAPI application.

Args:
app: The FastAPI application instance

Returns:
Click group with database commands
"""

@click.group(name="db")
def database_group() -> None:
"""Manage SQLSpec database components."""

# Add migration commands to the group
add_migration_commands(database_group)

return database_group
Loading
Loading