Skip to content

Commit f3494d8

Browse files
committed
Add healthcheck integration in lifespan
1 parent 1b2dd81 commit f3494d8

File tree

5 files changed

+120
-1
lines changed

5 files changed

+120
-1
lines changed

src/stac_auth_proxy/app.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from .config import Settings
1414
from .handlers import HealthzHandler, ReverseProxyHandler
15+
from .lifespan import LifespanManager, ServerHealthCheck
1516
from .middleware import (
1617
AddProcessTimeHeaderMiddleware,
1718
ApplyCql2FilterMiddleware,
@@ -27,8 +28,22 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
2728
"""FastAPI Application Factory."""
2829
settings = settings or Settings()
2930

31+
upstream_urls = [
32+
settings.upstream_url,
33+
settings.oidc_discovery_internal_url or settings.oidc_discovery_url,
34+
]
35+
lifespan = LifespanManager(
36+
on_startup=(
37+
[ServerHealthCheck(url=url) for url in upstream_urls]
38+
if settings.wait_for_upstream
39+
else []
40+
)
41+
)
42+
3043
app = FastAPI(
3144
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
45+
lifespan=lifespan,
46+
)
3247

3348
# Add catchall proxy handler
3449
proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
@@ -59,7 +74,6 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
5974
app.add_middleware(ApplyCql2FilterMiddleware)
6075
app.add_middleware(
6176
BuildCql2FilterMiddleware,
62-
# collections_filter=settings.collections_filter,
6377
items_filter=settings.items_filter(),
6478
)
6579

src/stac_auth_proxy/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class Settings(BaseSettings):
3636
oidc_discovery_url: HttpUrl
3737
oidc_discovery_internal_url: Optional[HttpUrl] = None
3838

39+
wait_for_upstream: bool = True
40+
3941
# Endpoints
4042
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
4143
openapi_spec_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Lifespan manager for FastAPI applications."""
2+
3+
import logging
4+
from contextlib import asynccontextmanager
5+
from dataclasses import dataclass, field
6+
from typing import AsyncGenerator, Awaitable, Callable, List
7+
8+
from fastapi import FastAPI
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
@dataclass
14+
class LifespanManager:
15+
"""Manager for FastAPI lifespan events."""
16+
17+
on_startup: List[Callable[[], Awaitable[None]]] = field(default_factory=list)
18+
on_teardown: List[Callable[[], Awaitable[None]]] = field(default_factory=list)
19+
20+
@asynccontextmanager
21+
async def __call__(self, app: FastAPI) -> AsyncGenerator[None, None]:
22+
"""FastAPI lifespan event handler."""
23+
for i, task in enumerate(self.on_startup):
24+
logger.debug(f"Executing startup task {i+1}/{len(self.on_startup)}")
25+
await task()
26+
27+
logger.debug("All startup tasks completed successfully")
28+
29+
yield
30+
31+
# Execute teardown tasks
32+
for i, task in enumerate(self.on_teardown):
33+
try:
34+
logger.debug(f"Executing teardown task {i+1}/{len(self.on_teardown)}")
35+
await task()
36+
except Exception as e:
37+
logger.error(f"Teardown task failed: {e}")
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""Health check implementations for lifespan events."""
2+
3+
import asyncio
4+
import logging
5+
from dataclasses import dataclass
6+
7+
import httpx
8+
from pydantic import HttpUrl
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
@dataclass
14+
class ServerHealthCheck:
15+
"""Health check for upstream API."""
16+
17+
url: str | HttpUrl
18+
max_retries: int = 5
19+
retry_delay: float = 0.25
20+
retry_delay_max: float = 10.0
21+
timeout: float = 5.0
22+
23+
def __post_init__(self):
24+
"""Convert url to string if it's a HttpUrl."""
25+
if isinstance(self.url, HttpUrl):
26+
self.url = str(self.url)
27+
28+
async def _check_health(self) -> bool:
29+
"""Check if upstream API is responding."""
30+
try:
31+
async with httpx.AsyncClient() as client:
32+
response = await client.get(
33+
self.url, timeout=self.timeout, follow_redirects=True
34+
)
35+
response.raise_for_status()
36+
return True
37+
except Exception as e:
38+
logger.warning(f"Upstream health check for {self.url!r} failed: {e}")
39+
return False
40+
41+
async def __call__(self) -> None:
42+
"""Wait for upstream API to become available."""
43+
for attempt in range(self.max_retries):
44+
if await self._check_health():
45+
logger.info(f"Upstream API {self.url!r} is healthy")
46+
return
47+
48+
retry_in = min(self.retry_delay * (2**attempt), self.retry_delay_max)
49+
logger.warning(
50+
f"Upstream API {self.url!r} not healthy, retrying in {retry_in:.1f}s "
51+
f"(attempt {attempt + 1}/{self.max_retries})"
52+
)
53+
await asyncio.sleep(retry_in)
54+
55+
raise RuntimeError(
56+
f"Upstream API {self.url!r} failed to respond after {self.max_retries} attempts"
57+
)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Lifespan event handlers for the STAC Auth Proxy."""
2+
3+
from .LifespanManager import LifespanManager
4+
from .ServerHealthCheck import ServerHealthCheck
5+
6+
__all__ = [
7+
"ServerHealthCheck",
8+
"LifespanManager",
9+
]

0 commit comments

Comments
 (0)