Skip to content

Commit 16139c8

Browse files
committed
Add healthcheck during startup
1 parent 1b2dd81 commit 16139c8

File tree

5 files changed

+114
-0
lines changed

5 files changed

+114
-0
lines changed

src/stac_auth_proxy/app.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from .config import Settings
1414
from .handlers import HealthzHandler, ReverseProxyHandler
15+
from .lifespan import LifespanManager, ServerHealthCheck
1516
from .middleware import (
1617
AddProcessTimeHeaderMiddleware,
1718
ApplyCql2FilterMiddleware,
@@ -27,8 +28,22 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
2728
"""FastAPI Application Factory."""
2829
settings = settings or Settings()
2930

31+
url_checks = [
32+
settings.upstream_url,
33+
settings.oidc_discovery_internal_url or settings.oidc_discovery_url,
34+
]
35+
lifespan = LifespanManager(
36+
on_startup=(
37+
[ServerHealthCheck(url=str(url)) for url in url_checks]
38+
if settings.wait_for_upstream
39+
else []
40+
)
41+
)
42+
3043
app = FastAPI(
3144
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
45+
lifespan=lifespan,
46+
)
3247

3348
# Add catchall proxy handler
3449
proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))

src/stac_auth_proxy/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class Settings(BaseSettings):
3636
oidc_discovery_url: HttpUrl
3737
oidc_discovery_internal_url: Optional[HttpUrl] = None
3838

39+
wait_for_upstream: bool = True
40+
3941
# Endpoints
4042
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
4143
openapi_spec_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Lifespan manager for FastAPI applications."""
2+
3+
import logging
4+
from contextlib import asynccontextmanager
5+
from dataclasses import dataclass, field
6+
from typing import AsyncGenerator, Awaitable, Callable, List
7+
8+
from fastapi import FastAPI
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
@dataclass
14+
class LifespanManager:
15+
"""Manager for FastAPI lifespan events."""
16+
17+
on_startup: List[Callable[[], Awaitable[None]]] = field(default_factory=list)
18+
on_teardown: List[Callable[[], Awaitable[None]]] = field(default_factory=list)
19+
20+
@asynccontextmanager
21+
async def __call__(self, app: FastAPI) -> AsyncGenerator[None, None]:
22+
"""FastAPI lifespan event handler."""
23+
for i, task in enumerate(self.on_startup):
24+
logger.debug(f"Executing startup task {i+1}/{len(self.on_startup)}")
25+
await task()
26+
27+
logger.debug("All startup tasks completed successfully")
28+
29+
yield
30+
31+
# Execute teardown tasks
32+
for i, task in enumerate(self.on_teardown):
33+
try:
34+
logger.debug(f"Executing teardown task {i+1}/{len(self.on_teardown)}")
35+
await task()
36+
except Exception as e:
37+
logger.error(f"Teardown task failed: {e}")
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Health check implementations for lifespan events."""
2+
3+
import asyncio
4+
import logging
5+
from dataclasses import dataclass
6+
7+
import httpx
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
@dataclass
13+
class ServerHealthCheck:
14+
"""Health check for upstream API."""
15+
16+
url: str
17+
max_retries: int = 5
18+
retry_delay: float = 0.25
19+
retry_delay_max: float = 10.0
20+
timeout: float = 5.0
21+
22+
async def _check_health(self) -> bool:
23+
"""Check if upstream API is responding."""
24+
try:
25+
async with httpx.AsyncClient() as client:
26+
response = await client.get(
27+
self.url, timeout=self.timeout, follow_redirects=True
28+
)
29+
response.raise_for_status()
30+
return True
31+
except Exception as e:
32+
logger.warning(f"Upstream health check failed: {e}")
33+
return False
34+
35+
async def __call__(self) -> None:
36+
"""Wait for upstream API to become available."""
37+
for attempt in range(self.max_retries):
38+
if await self._check_health():
39+
logger.info(f"Upstream API at {self.url} is healthy")
40+
return
41+
42+
retry_in = min(self.retry_delay * (2**attempt), self.retry_delay_max)
43+
logger.warning(
44+
f"Upstream API not healthy, retrying in {retry_in:.1f}s "
45+
f"(attempt {attempt + 1}/{self.max_retries})"
46+
)
47+
await asyncio.sleep(retry_in)
48+
49+
raise RuntimeError(
50+
f"Upstream API at {self.url} failed to respond after {self.max_retries} attempts"
51+
)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Lifespan event handlers for the STAC Auth Proxy."""
2+
3+
from .LifespanManager import LifespanManager
4+
from .ServerHealthCheck import ServerHealthCheck
5+
6+
__all__ = [
7+
"ServerHealthCheck",
8+
"LifespanManager",
9+
]

0 commit comments

Comments
 (0)