Skip to content

Commit 9a66800

Browse files
committed
Rework lifespan
1 parent dd80e74 commit 9a66800

File tree

4 files changed

+108
-115
lines changed

4 files changed

+108
-115
lines changed

src/stac_auth_proxy/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,9 @@
88

99
from .app import configure_app, create_app
1010
from .config import Settings
11-
from .lifespan import check_conformance, check_server_health, lifespan
1211

1312
__all__ = [
1413
"create_app",
1514
"configure_app",
16-
"lifespan",
17-
"check_conformance",
18-
"check_server_health",
1915
"Settings",
2016
]

src/stac_auth_proxy/app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from .config import Settings
1515
from .handlers import HealthzHandler, ReverseProxyHandler, SwaggerUI
16-
from .lifespan import lifespan
16+
from .lifespan import build_lifespan
1717
from .middleware import (
1818
AddProcessTimeHeaderMiddleware,
1919
AuthenticationExtensionMiddleware,
@@ -31,11 +31,11 @@
3131

3232

3333
def configure_app(app: FastAPI, settings: Optional[Settings] = None) -> FastAPI:
34-
"""Apply routes and middleware to an existing FastAPI app."""
34+
"""Apply routes and middleware to a FastAPI app."""
3535
settings = settings or Settings()
3636

3737
#
38-
# Handlers (place catch-all proxy handler last)
38+
# Route Handlers
3939
#
4040

4141
# If we have customized Swagger UI Init settings (e.g. a provided client_id)
@@ -143,7 +143,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
143143

144144
app = FastAPI(
145145
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
146-
lifespan=lifespan(settings=settings),
146+
lifespan=build_lifespan(settings=settings),
147147
root_path=settings.root_path,
148148
)
149149
if app.root_path:

src/stac_auth_proxy/lifespan.py

Lines changed: 104 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,116 @@
11
"""Reusable lifespan handler for FastAPI applications."""
22

3+
import asyncio
34
import logging
5+
import re
46
from contextlib import asynccontextmanager
57
from typing import Any
68

9+
import httpx
710
from fastapi import FastAPI
11+
from pydantic import HttpUrl
12+
from starlette.middleware import Middleware
813

914
from .config import Settings
10-
from .utils.lifespan import check_conformance, check_server_health
1115

1216
logger = logging.getLogger(__name__)
17+
__all__ = ["build_lifespan", "check_conformance", "check_server_health"]
18+
19+
20+
async def check_server_healths(*urls: str | HttpUrl) -> None:
21+
"""Wait for upstream APIs to become available."""
22+
logger.info("Running upstream server health checks...")
23+
for url in urls:
24+
await check_server_health(url)
25+
logger.info(
26+
"Upstream servers are healthy:\n%s",
27+
"\n".join([f" - {url}" for url in urls]),
28+
)
29+
30+
31+
async def check_server_health(
32+
url: str | HttpUrl,
33+
max_retries: int = 10,
34+
retry_delay: float = 1.0,
35+
retry_delay_max: float = 5.0,
36+
timeout: float = 5.0,
37+
) -> None:
38+
"""Wait for upstream API to become available."""
39+
# Convert url to string if it's a HttpUrl
40+
if isinstance(url, HttpUrl):
41+
url = str(url)
42+
43+
async with httpx.AsyncClient(
44+
base_url=url, timeout=timeout, follow_redirects=True
45+
) as client:
46+
for attempt in range(max_retries):
47+
try:
48+
response = await client.get("/")
49+
response.raise_for_status()
50+
logger.info(f"Upstream API {url!r} is healthy")
51+
return
52+
except httpx.ConnectError as e:
53+
logger.warning(f"Upstream health check for {url!r} failed: {e}")
54+
retry_in = min(retry_delay * (2**attempt), retry_delay_max)
55+
logger.warning(
56+
f"Upstream API {url!r} not healthy, retrying in {retry_in:.1f}s "
57+
f"(attempt {attempt + 1}/{max_retries})"
58+
)
59+
await asyncio.sleep(retry_in)
60+
61+
raise RuntimeError(
62+
f"Upstream API {url!r} failed to respond after {max_retries} attempts"
63+
)
64+
65+
66+
async def check_conformance(
67+
middleware_classes: list[Middleware],
68+
api_url: str,
69+
attr_name: str = "__required_conformances__",
70+
endpoint: str = "/conformance",
71+
):
72+
"""Check if the upstream API supports a given conformance class."""
73+
required_conformances: dict[str, list[str]] = {}
74+
for middleware in middleware_classes:
75+
76+
for conformance in getattr(middleware.cls, attr_name, []):
77+
required_conformances.setdefault(conformance, []).append(
78+
middleware.cls.__name__
79+
)
80+
81+
async with httpx.AsyncClient(base_url=api_url) as client:
82+
response = await client.get(endpoint)
83+
response.raise_for_status()
84+
api_conforms_to = response.json().get("conformsTo", [])
85+
86+
missing = [
87+
req_conformance
88+
for req_conformance in required_conformances.keys()
89+
if not any(
90+
re.match(req_conformance, conformance) for conformance in api_conforms_to
91+
)
92+
]
93+
94+
def conformance_str(conformance: str) -> str:
95+
return f" - {conformance} [{','.join(required_conformances[conformance])}]"
96+
97+
if missing:
98+
missing_str = [conformance_str(c) for c in missing]
99+
raise RuntimeError(
100+
"\n".join(
101+
[
102+
"Upstream catalog is missing the following conformance classes:",
103+
*missing_str,
104+
]
105+
)
106+
)
107+
logger.info(
108+
"Upstream catalog conforms to the following required conformance classes: \n%s",
109+
"\n".join([conformance_str(c) for c in required_conformances]),
110+
)
13111

14112

15-
def lifespan(settings: Settings | None = None, **settings_kwargs: Any):
113+
def build_lifespan(settings: Settings | None = None, **settings_kwargs: Any):
16114
"""
17115
Create a lifespan handler that runs startup checks.
18116
@@ -34,18 +132,13 @@ def lifespan(settings: Settings | None = None, **settings_kwargs: Any):
34132
settings = Settings(**settings_kwargs)
35133

36134
@asynccontextmanager
37-
async def _lifespan(app: FastAPI):
135+
async def lifespan(app: "FastAPI"):
38136
assert settings is not None # Required for type checking
39137

40138
# Wait for upstream servers to become available
41139
if settings.wait_for_upstream:
42-
logger.info("Running upstream server health checks...")
43-
urls = [settings.upstream_url, settings.oidc_discovery_internal_url]
44-
for url in urls:
45-
await check_server_health(url=url)
46-
logger.info(
47-
"Upstream servers are healthy:\n%s",
48-
"\n".join([f" - {url}" for url in urls]),
140+
await check_server_healths(
141+
settings.upstream_url, settings.oidc_discovery_internal_url
49142
)
50143

51144
# Log all middleware connected to the app
@@ -59,7 +152,4 @@ async def _lifespan(app: FastAPI):
59152

60153
yield
61154

62-
return _lifespan
63-
64-
65-
__all__ = ["lifespan", "check_conformance", "check_server_health"]
155+
return lifespan

src/stac_auth_proxy/utils/lifespan.py

Lines changed: 0 additions & 93 deletions
This file was deleted.

0 commit comments

Comments
 (0)