Skip to content

Commit d4b30c7

Browse files
committed
refactor: expose reusable lifespan handler
1 parent 301b39c commit d4b30c7

File tree

4 files changed

+85
-49
lines changed

4 files changed

+85
-49
lines changed

src/stac_auth_proxy/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88

99
from .app import configure_app, create_app
1010
from .config import Settings
11-
from .lifespan import check_conformance, check_server_health
11+
from .lifespan import lifespan
1212

1313
__all__ = [
1414
"create_app",
1515
"configure_app",
16-
"check_conformance",
17-
"check_server_health",
16+
"lifespan",
1817
"Settings",
1918
]

src/stac_auth_proxy/app.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77

88
import logging
9-
from contextlib import asynccontextmanager
109
from typing import Optional
1110

1211
from fastapi import FastAPI
@@ -26,7 +25,7 @@
2625
ProcessLinksMiddleware,
2726
RemoveRootPathMiddleware,
2827
)
29-
from .utils.lifespan import check_conformance, check_server_health
28+
from .lifespan import lifespan
3029

3130
logger = logging.getLogger(__name__)
3231

@@ -136,48 +135,13 @@ def configure_app(app: FastAPI, settings: Optional[Settings] = None) -> FastAPI:
136135
)
137136

138137
return app
139-
140-
141138
def create_app(settings: Optional[Settings] = None) -> FastAPI:
142139
"""FastAPI Application Factory."""
143140
settings = settings or Settings()
144141

145-
#
146-
# Application
147-
#
148-
149-
@asynccontextmanager
150-
async def lifespan(app: FastAPI):
151-
assert settings
152-
153-
# Wait for upstream servers to become available
154-
if settings.wait_for_upstream:
155-
logger.info("Running upstream server health checks...")
156-
urls = [settings.upstream_url, settings.oidc_discovery_internal_url]
157-
for url in urls:
158-
await check_server_health(url=url)
159-
logger.info(
160-
"Upstream servers are healthy:\n%s",
161-
"\n".join([f" - {url}" for url in urls]),
162-
)
163-
164-
# Log all middleware connected to the app
165-
logger.info(
166-
"Connected middleware:\n%s",
167-
"\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]),
168-
)
169-
170-
if settings.check_conformance:
171-
await check_conformance(
172-
app.user_middleware,
173-
str(settings.upstream_url),
174-
)
175-
176-
yield
177-
178142
app = FastAPI(
179143
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
180-
lifespan=lifespan,
144+
lifespan=lifespan(settings),
181145
root_path=settings.root_path,
182146
)
183147
if app.root_path:

src/stac_auth_proxy/lifespan.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,58 @@
1-
"""Public access to lifespan health checks.
1+
"""Reusable lifespan handler for FastAPI applications."""
22

3-
This module re-exports the ``check_server_health`` and ``check_conformance``
4-
utilities so that library users can import them without reaching into the
5-
internal ``utils`` package.
6-
"""
3+
from contextlib import asynccontextmanager
4+
import logging
5+
from fastapi import FastAPI
76

7+
from .config import Settings
88
from .utils.lifespan import check_conformance, check_server_health
99

10-
__all__ = ["check_server_health", "check_conformance"]
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def lifespan(settings: Settings | None = None):
14+
"""Create a lifespan handler that runs startup checks.
15+
16+
Parameters
17+
----------
18+
settings : Settings | None
19+
Configuration for health and conformance checks. If omitted, default
20+
settings are loaded.
21+
22+
Returns
23+
-------
24+
Callable[[FastAPI], AsyncContextManager[Any]]
25+
A callable suitable for the ``lifespan`` parameter of ``FastAPI``.
26+
"""
27+
28+
settings = settings or Settings()
29+
30+
@asynccontextmanager
31+
async def _lifespan(app: FastAPI):
32+
# Wait for upstream servers to become available
33+
if settings.wait_for_upstream:
34+
logger.info("Running upstream server health checks...")
35+
urls = [settings.upstream_url, settings.oidc_discovery_internal_url]
36+
for url in urls:
37+
await check_server_health(url=url)
38+
logger.info(
39+
"Upstream servers are healthy:\n%s",
40+
"\n".join([f" - {url}" for url in urls]),
41+
)
42+
43+
# Log all middleware connected to the app
44+
logger.info(
45+
"Connected middleware:\n%s",
46+
"\n".join([f" - {m.cls.__name__}" for m in app.user_middleware]),
47+
)
48+
49+
if settings.check_conformance:
50+
await check_conformance(app.user_middleware, str(settings.upstream_url))
51+
52+
yield
53+
54+
return _lifespan
55+
56+
57+
__all__ = ["lifespan"]
58+

tests/test_lifespan.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
"""Tests for lifespan module."""
22

33
from dataclasses import dataclass
4-
from unittest.mock import patch
4+
from unittest.mock import AsyncMock, patch
55

66
import pytest
7+
from fastapi import FastAPI
8+
from fastapi.testclient import TestClient
79
from starlette.middleware import Middleware
810
from starlette.types import ASGIApp
911

10-
from stac_auth_proxy import check_conformance, check_server_health
12+
from stac_auth_proxy import Settings, lifespan as lifespan_handler
13+
from stac_auth_proxy.utils.lifespan import check_conformance, check_server_health
1114
from stac_auth_proxy.utils.middleware import required_conformance
1215

1316

@@ -80,3 +83,25 @@ def __init__(self, app):
8083

8184
middleware = [Middleware(NoConformanceMiddleware)]
8285
await check_conformance(middleware, source_api_server)
86+
87+
88+
def test_lifespan_reusable():
89+
"""Ensure the public lifespan handler runs health and conformance checks."""
90+
settings = Settings(
91+
upstream_url="https://example.com",
92+
oidc_discovery_url="https://example.com/.well-known/openid-configuration",
93+
)
94+
with patch(
95+
"stac_auth_proxy.lifespan.check_server_health",
96+
new=AsyncMock(),
97+
) as mock_health, patch(
98+
"stac_auth_proxy.lifespan.check_conformance",
99+
new=AsyncMock(),
100+
) as mock_conf:
101+
app = FastAPI(lifespan=lifespan_handler(settings))
102+
with TestClient(app):
103+
pass
104+
assert mock_health.await_count == 2
105+
mock_conf.assert_awaited_once_with(
106+
app.user_middleware, str(settings.upstream_url)
107+
)

0 commit comments

Comments
 (0)