Skip to content

Commit 12f2973

Browse files
authored
feat: check upstream API health at startup (#32)
* add lifespan tooling * check health status of upstream APIs before starting proxy (if `config.wait_for_upstream` is True) * reorg app slightly, place handlers before middleware * better support for internal OIDC routes (via `config.oidc_discovery_internal_url`)
1 parent 4c9f4f9 commit 12f2973

File tree

5 files changed

+146
-13
lines changed

5 files changed

+146
-13
lines changed

src/stac_auth_proxy/app.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from .config import Settings
1414
from .handlers import HealthzHandler, ReverseProxyHandler
15+
from .lifespan import LifespanManager, ServerHealthCheck
1516
from .middleware import (
1617
AddProcessTimeHeaderMiddleware,
1718
ApplyCql2FilterMiddleware,
@@ -27,12 +28,44 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
2728
"""FastAPI Application Factory."""
2829
settings = settings or Settings()
2930

31+
#
32+
# Application
33+
#
34+
upstream_urls = [
35+
settings.upstream_url,
36+
settings.oidc_discovery_internal_url or settings.oidc_discovery_url,
37+
]
38+
lifespan = LifespanManager(
39+
on_startup=(
40+
[ServerHealthCheck(url=url) for url in upstream_urls]
41+
if settings.wait_for_upstream
42+
else []
43+
)
44+
)
45+
3046
app = FastAPI(
3147
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
48+
lifespan=lifespan,
3249
)
3350

34-
app.add_middleware(AddProcessTimeHeaderMiddleware)
51+
#
52+
# Handlers (place catch-all proxy handler last)
53+
#
54+
if settings.healthz_prefix:
55+
app.include_router(
56+
HealthzHandler(upstream_url=str(settings.upstream_url)).router,
57+
prefix=settings.healthz_prefix,
58+
)
59+
60+
app.add_api_route(
61+
"/{path:path}",
62+
ReverseProxyHandler(upstream=str(settings.upstream_url)).stream,
63+
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
64+
)
3565

66+
#
67+
# Middleware (order is important, last added = first to run)
68+
#
3669
if settings.openapi_spec_endpoint:
3770
app.add_middleware(
3871
OpenApiMiddleware,
@@ -44,10 +77,11 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
4477
)
4578

4679
if settings.items_filter:
47-
app.add_middleware(ApplyCql2FilterMiddleware)
80+
app.add_middleware(
81+
ApplyCql2FilterMiddleware,
82+
)
4883
app.add_middleware(
4984
BuildCql2FilterMiddleware,
50-
# collections_filter=settings.collections_filter,
5185
items_filter=settings.items_filter(),
5286
)
5387

@@ -57,18 +91,11 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
5791
private_endpoints=settings.private_endpoints,
5892
default_public=settings.default_public,
5993
oidc_config_url=settings.oidc_discovery_url,
94+
oidc_config_internal_url=settings.oidc_discovery_internal_url,
6095
)
6196

62-
if settings.healthz_prefix:
63-
healthz_handler = HealthzHandler(upstream_url=str(settings.upstream_url))
64-
app.include_router(healthz_handler.router, prefix="/healthz")
65-
66-
# Catchall for any endpoint
67-
proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
68-
app.add_api_route(
69-
"/{path:path}",
70-
proxy_handler.stream,
71-
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
97+
app.add_middleware(
98+
AddProcessTimeHeaderMiddleware,
7299
)
73100

74101
return app

src/stac_auth_proxy/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class Settings(BaseSettings):
3434
# External URLs
3535
upstream_url: HttpUrl
3636
oidc_discovery_url: HttpUrl
37+
oidc_discovery_internal_url: Optional[HttpUrl] = None
38+
39+
wait_for_upstream: bool = True
3740

3841
# Endpoints
3942
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Lifespan manager for FastAPI applications."""
2+
3+
import logging
4+
from contextlib import asynccontextmanager
5+
from dataclasses import dataclass, field
6+
from typing import AsyncGenerator, Awaitable, Callable, List
7+
8+
from fastapi import FastAPI
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
@dataclass
14+
class LifespanManager:
15+
"""Manager for FastAPI lifespan events."""
16+
17+
on_startup: List[Callable[[], Awaitable[None]]] = field(default_factory=list)
18+
on_teardown: List[Callable[[], Awaitable[None]]] = field(default_factory=list)
19+
20+
@asynccontextmanager
21+
async def __call__(self, app: FastAPI) -> AsyncGenerator[None, None]:
22+
"""FastAPI lifespan event handler."""
23+
for i, task in enumerate(self.on_startup):
24+
logger.debug(f"Executing startup task {i+1}/{len(self.on_startup)}")
25+
await task()
26+
27+
logger.debug("All startup tasks completed successfully")
28+
29+
yield
30+
31+
# Execute teardown tasks
32+
for i, task in enumerate(self.on_teardown):
33+
try:
34+
logger.debug(f"Executing teardown task {i+1}/{len(self.on_teardown)}")
35+
await task()
36+
except Exception as e:
37+
logger.error(f"Teardown task failed: {e}")
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""Health check implementations for lifespan events."""
2+
3+
import asyncio
4+
import logging
5+
from dataclasses import dataclass
6+
7+
import httpx
8+
from pydantic import HttpUrl
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
@dataclass
14+
class ServerHealthCheck:
15+
"""Health check for upstream API."""
16+
17+
url: str | HttpUrl
18+
max_retries: int = 5
19+
retry_delay: float = 0.25
20+
retry_delay_max: float = 10.0
21+
timeout: float = 5.0
22+
23+
def __post_init__(self):
24+
"""Convert url to string if it's a HttpUrl."""
25+
if isinstance(self.url, HttpUrl):
26+
self.url = str(self.url)
27+
28+
async def _check_health(self) -> bool:
29+
"""Check if upstream API is responding."""
30+
try:
31+
async with httpx.AsyncClient() as client:
32+
response = await client.get(
33+
self.url, timeout=self.timeout, follow_redirects=True
34+
)
35+
response.raise_for_status()
36+
return True
37+
except Exception as e:
38+
logger.warning(f"Upstream health check for {self.url!r} failed: {e}")
39+
return False
40+
41+
async def __call__(self) -> None:
42+
"""Wait for upstream API to become available."""
43+
for attempt in range(self.max_retries):
44+
if await self._check_health():
45+
logger.info(f"Upstream API {self.url!r} is healthy")
46+
return
47+
48+
retry_in = min(self.retry_delay * (2**attempt), self.retry_delay_max)
49+
logger.warning(
50+
f"Upstream API {self.url!r} not healthy, retrying in {retry_in:.1f}s "
51+
f"(attempt {attempt + 1}/{self.max_retries})"
52+
)
53+
await asyncio.sleep(retry_in)
54+
55+
raise RuntimeError(
56+
f"Upstream API {self.url!r} failed to respond after {self.max_retries} attempts"
57+
)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Lifespan event handlers for the STAC Auth Proxy."""
2+
3+
from .LifespanManager import LifespanManager
4+
from .ServerHealthCheck import ServerHealthCheck
5+
6+
__all__ = [
7+
"ServerHealthCheck",
8+
"LifespanManager",
9+
]

0 commit comments

Comments
 (0)