Skip to content

Commit 410cf9a

Browse files
committed
feat: add middleware conformance checks
1 parent fe46940 commit 410cf9a

File tree

5 files changed

+90
-1
lines changed

5 files changed

+90
-1
lines changed

src/stac_auth_proxy/app.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
EnforceAuthMiddleware,
2222
OpenApiMiddleware,
2323
)
24-
from .utils.lifespan import check_server_health
24+
from .utils.lifespan import (
25+
check_conformance,
26+
check_server_health,
27+
log_middleware_classes,
28+
)
2529

2630
logger = logging.getLogger(__name__)
2731

@@ -40,9 +44,18 @@ async def lifespan(app: FastAPI):
4044

4145
# Wait for upstream servers to become available
4246
if settings.wait_for_upstream:
47+
logger.info("Running upstream server health checks...")
4348
for url in [settings.upstream_url, settings.oidc_discovery_internal_url]:
4449
await check_server_health(url=url)
4550

51+
# Log all middleware connected to the app
52+
await log_middleware_classes(app.user_middleware)
53+
if settings.check_conformance:
54+
await check_conformance(
55+
app.user_middleware,
56+
str(settings.upstream_url),
57+
)
58+
4659
yield
4760

4861
app = FastAPI(

src/stac_auth_proxy/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class Settings(BaseSettings):
3939
oidc_discovery_internal_url: HttpUrl
4040

4141
wait_for_upstream: bool = True
42+
check_conformance: bool = True
4243

4344
# Endpoints
4445
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,18 @@
1313
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1414

1515
from ..utils import filters
16+
from ..utils.middleware import required_conformance
1617

1718
logger = getLogger(__name__)
1819

1920

21+
@required_conformance(
22+
r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
23+
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
24+
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
25+
r"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
26+
r"https://api.stacspec.org/v1\.\d+\.\d+(?:-[\w\.]+)?/item-search#filter",
27+
)
2028
@dataclass(frozen=True)
2129
class ApplyCql2FilterMiddleware:
2230
"""Middleware to apply the Cql2Filter to the request."""

src/stac_auth_proxy/utils/lifespan.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import asyncio
44
import logging
5+
import re
56

67
import httpx
78
from pydantic import HttpUrl
9+
from starlette.middleware import Middleware
810

911
logger = logging.getLogger(__name__)
1012

@@ -40,3 +42,58 @@ async def check_server_health(
4042
raise RuntimeError(
4143
f"Upstream API {url!r} failed to respond after {max_retries} attempts"
4244
)
45+
46+
47+
async def log_middleware_classes(middleware_classes: list[Middleware]):
48+
"""Log the middleware classes connected to the application."""
49+
logger.debug(
50+
"Connected middleware:\n%s",
51+
"\n".join(
52+
[f"- {middleware.cls.__name__}" for middleware in middleware_classes]
53+
),
54+
)
55+
56+
57+
async def check_conformance(
58+
middleware_classes: list[Middleware],
59+
api_url: str,
60+
attr_name: str = "__required_conformances__",
61+
):
62+
"""Check if the upstream API supports a given conformance class."""
63+
required_conformances: dict[str, list[str]] = {}
64+
for middleware in middleware_classes:
65+
66+
for conformance in getattr(middleware.cls, attr_name, []):
67+
required_conformances.setdefault(conformance, []).append(
68+
middleware.cls.__name__
69+
)
70+
71+
async with httpx.AsyncClient() as client:
72+
response = await client.get(api_url)
73+
response.raise_for_status()
74+
api_conforms_to = response.json().get("conformsTo", [])
75+
missing = [
76+
req_conformance
77+
for req_conformance in required_conformances.keys()
78+
if not any(
79+
re.match(req_conformance, conformance) for conformance in api_conforms_to
80+
)
81+
]
82+
83+
def print_conformance(conformance):
84+
return f" - {conformance} [{','.join(required_conformances[conformance])}]"
85+
86+
if missing:
87+
missing_str = [print_conformance(c) for c in missing]
88+
raise RuntimeError(
89+
"\n".join(
90+
[
91+
"Upstream catalog is missing the following conformance classes:",
92+
*missing_str,
93+
]
94+
)
95+
)
96+
logger.debug(
97+
"Upstream catalog conforms to the following required conformance classes: \n%s",
98+
"\n".join([print_conformance(c) for c in required_conformances]),
99+
)

src/stac_auth_proxy/utils/middleware.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,13 @@ async def transform_response(message: Message) -> None:
9999
)
100100

101101
return await self.app(scope, receive, transform_response)
102+
103+
104+
def required_conformance(*conformances: str):
105+
"""Register required conformance classes with a middleware class."""
106+
107+
def decorator(func):
108+
func.__required_conformances__ = list(conformances)
109+
return func
110+
111+
return decorator

0 commit comments

Comments
 (0)