Skip to content

refactor: allow configuring external FastAPI app #85

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/stac_auth_proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
with some internal STAC API.
"""

from .app import create_app
from .app import configure_app, create_app
from .config import Settings
from .lifespan import check_conformance, check_server_health, lifespan

__all__ = ["create_app", "Settings"]
__all__ = [
"create_app",
"configure_app",
"lifespan",
"check_conformance",
"check_server_health",
"Settings",
]
81 changes: 27 additions & 54 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""

import logging
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI
Expand All @@ -26,56 +25,15 @@
ProcessLinksMiddleware,
RemoveRootPathMiddleware,
)
from .utils.lifespan import check_conformance, check_server_health
from .lifespan import lifespan

logger = logging.getLogger(__name__)


def create_app(settings: Optional[Settings] = None) -> FastAPI:
"""FastAPI Application Factory."""
def configure_app(app: FastAPI, settings: Optional[Settings] = None) -> FastAPI:
"""Apply routes and middleware to an existing FastAPI app."""
settings = settings or Settings()

#
# Application
#

@asynccontextmanager
async def lifespan(app: FastAPI):
assert settings

# Wait for upstream servers to become available
if settings.wait_for_upstream:
logger.info("Running upstream server health checks...")
urls = [settings.upstream_url, settings.oidc_discovery_internal_url]
for url in urls:
await check_server_health(url=url)
logger.info(
"Upstream servers are healthy:\n%s",
"\n".join([f" - {url}" for url in urls]),
)

# Log all middleware connected to the app
logger.info(
"Connected middleware:\n%s",
"\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]),
)

if settings.check_conformance:
await check_conformance(
app.user_middleware,
str(settings.upstream_url),
)

yield

app = FastAPI(
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
lifespan=lifespan,
root_path=settings.root_path,
)
if app.root_path:
logger.debug("Mounted app at %s", app.root_path)

#
# Handlers (place catch-all proxy handler last)
#
Expand Down Expand Up @@ -105,15 +63,6 @@ async def lifespan(app: FastAPI):
prefix=settings.healthz_prefix,
)

app.add_api_route(
"/{path:path}",
ReverseProxyHandler(
upstream=str(settings.upstream_url),
override_host=settings.override_host,
).proxy_request,
methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
)

#
# Middleware (order is important, last added = first to run)
#
Expand Down Expand Up @@ -185,4 +134,28 @@ async def lifespan(app: FastAPI):
CompressionMiddleware,
)

return app
def create_app(settings: Optional[Settings] = None) -> FastAPI:
"""FastAPI Application Factory."""
settings = settings or Settings()

app = FastAPI(
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
lifespan=lifespan(settings=settings),
root_path=settings.root_path,
)
if app.root_path:
logger.debug("Mounted app at %s", app.root_path)

configure_app(app, settings)

app.add_api_route(
"/{path:path}",
ReverseProxyHandler(
upstream=str(settings.upstream_url),
override_host=settings.override_host,
).proxy_request,
methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
)

return app
64 changes: 64 additions & 0 deletions src/stac_auth_proxy/lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Reusable lifespan handler for FastAPI applications."""

from contextlib import asynccontextmanager
import logging
from typing import Any

from fastapi import FastAPI

from .config import Settings
from .utils.lifespan import check_conformance, check_server_health

logger = logging.getLogger(__name__)


def lifespan(settings: Settings | None = None, **settings_kwargs: Any):
"""Create a lifespan handler that runs startup checks.

Parameters
----------
settings : Settings | None, optional
Pre-built settings instance. If omitted, a new one is constructed from
``settings_kwargs``.
**settings_kwargs : Any
Keyword arguments used to configure the health and conformance checks if
``settings`` is not provided.

Returns
-------
Callable[[FastAPI], AsyncContextManager[Any]]
A callable suitable for the ``lifespan`` parameter of ``FastAPI``.
"""

if settings is None:
settings = Settings(**settings_kwargs)

@asynccontextmanager
async def _lifespan(app: FastAPI):
# Wait for upstream servers to become available
if settings.wait_for_upstream:
logger.info("Running upstream server health checks...")
urls = [settings.upstream_url, settings.oidc_discovery_internal_url]
for url in urls:
await check_server_health(url=url)
logger.info(
"Upstream servers are healthy:\n%s",
"\n".join([f" - {url}" for url in urls]),
)

# Log all middleware connected to the app
logger.info(
"Connected middleware:\n%s",
"\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]),
)

if settings.check_conformance:
await check_conformance(app.user_middleware, str(settings.upstream_url))

yield

return _lifespan


__all__ = ["lifespan", "check_conformance", "check_server_health"]

24 changes: 24 additions & 0 deletions tests/test_configure_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Tests for configuring an external FastAPI application."""

from fastapi import FastAPI
from fastapi.routing import APIRoute

from stac_auth_proxy import Settings, configure_app


def test_configure_app_excludes_proxy_route():
"""Ensure `configure_app` adds health route and omits proxy route."""
app = FastAPI()
settings = Settings(
upstream_url="https://example.com",
oidc_discovery_url="https://example.com/.well-known/openid-configuration",
wait_for_upstream=False,
check_conformance=False,
default_public=True,
)

configure_app(app, settings)

routes = [r.path for r in app.router.routes if isinstance(r, APIRoute)]
assert settings.healthz_prefix in routes
assert "/{path:path}" not in routes
36 changes: 34 additions & 2 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""Tests for lifespan module."""

from dataclasses import dataclass
from unittest.mock import patch
from unittest.mock import AsyncMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.middleware import Middleware
from starlette.types import ASGIApp

from stac_auth_proxy.utils.lifespan import check_conformance, check_server_health
from stac_auth_proxy import (
check_conformance,
check_server_health,
lifespan as lifespan_handler,
)
from stac_auth_proxy.utils.middleware import required_conformance


Expand Down Expand Up @@ -80,3 +86,29 @@ def __init__(self, app):

middleware = [Middleware(NoConformanceMiddleware)]
await check_conformance(middleware, source_api_server)


def test_lifespan_reusable():
"""Ensure the public lifespan handler runs health and conformance checks."""
upstream_url = "https://example.com"
oidc_discovery_url = "https://example.com/.well-known/openid-configuration"
with patch(
"stac_auth_proxy.lifespan.check_server_health",
new=AsyncMock(),
) as mock_health, patch(
"stac_auth_proxy.lifespan.check_conformance",
new=AsyncMock(),
) as mock_conf:
app = FastAPI(
lifespan=lifespan_handler(
upstream_url=upstream_url,
oidc_discovery_url=oidc_discovery_url,
)
)
with TestClient(app):
pass
assert mock_health.await_count == 2
expected_upstream = upstream_url.rstrip("/") + "/"
mock_conf.assert_awaited_once_with(
app.user_middleware, expected_upstream
)
Loading