Skip to content

Commit 341bdc6

Browse files
committed
Refactor API spec augmentation to middleware
1 parent 1e3d823 commit 341bdc6

File tree

4 files changed

+107
-97
lines changed

4 files changed

+107
-97
lines changed

src/stac_auth_proxy/app.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88
import logging
99
from typing import Optional
1010

11-
from fastapi import FastAPI, Security
11+
from fastapi import FastAPI
1212

1313
from .auth import OpenIdConnectAuth
1414
from .config import Settings
15-
from .handlers import ReverseProxyHandler, build_openapi_spec_handler
16-
from .middleware import AddProcessTimeHeaderMiddleware
15+
from .handlers import ReverseProxyHandler
16+
from .middleware import (
17+
AddProcessTimeHeaderMiddleware,
18+
EnforceAuthMiddleware,
19+
OpenApiMiddleware,
20+
)
1721

1822
logger = logging.getLogger(__name__)
1923

@@ -25,7 +29,17 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
2529
app = FastAPI(
2630
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
2731
)
32+
2833
app.add_middleware(AddProcessTimeHeaderMiddleware)
34+
if settings.openapi_spec_endpoint:
35+
app.add_middleware(
36+
OpenApiMiddleware,
37+
openapi_spec_path=settings.openapi_spec_endpoint,
38+
oidc_config_url=str(settings.oidc_discovery_url),
39+
private_endpoints=settings.private_endpoints,
40+
default_public=settings.default_public,
41+
)
42+
app.add_middleware(EnforceAuthMiddleware)
2943

3044
if settings.debug:
3145
app.add_api_route(
@@ -44,37 +58,33 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
4458
collections_filter=settings.collections_filter,
4559
items_filter=settings.items_filter,
4660
)
47-
openapi_handler = build_openapi_spec_handler(
48-
proxy=proxy_handler,
49-
oidc_config_url=str(settings.oidc_discovery_url),
50-
)
5161

52-
# Configure security dependency for explicitely specified endpoints
53-
for path_methods, dependencies in [
54-
(settings.private_endpoints, [Security(auth_scheme.validated_user)]),
55-
(settings.public_endpoints, []),
56-
]:
57-
for path, methods in path_methods.items():
58-
endpoint = (
59-
openapi_handler
60-
if path == settings.openapi_spec_endpoint
61-
else proxy_handler.stream
62-
)
63-
app.add_api_route(
64-
path,
65-
endpoint=endpoint,
66-
methods=methods,
67-
dependencies=dependencies,
68-
)
62+
# # Configure security dependency for explicitely specified endpoints
63+
# for path_methods, dependencies in [
64+
# (settings.private_endpoints, [Security(auth_scheme.validated_user)]),
65+
# (settings.public_endpoints, []),
66+
# ]:
67+
# for path, methods in path_methods.items():
68+
# endpoint = (
69+
# openapi_handler
70+
# if path == settings.openapi_spec_endpoint
71+
# else proxy_handler.stream
72+
# )
73+
# app.add_api_route(
74+
# path,
75+
# endpoint=endpoint,
76+
# methods=methods,
77+
# dependencies=dependencies,
78+
# )
6979

7080
# Catchall for remainder of the endpoints
7181
app.add_api_route(
7282
"/{path:path}",
7383
proxy_handler.stream,
7484
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
75-
dependencies=(
76-
[] if settings.default_public else [Security(auth_scheme.validated_user)]
77-
),
85+
# dependencies=(
86+
# [] if settings.default_public else [Security(auth_scheme.validated_user)]
87+
# ),
7888
)
7989

8090
return app
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Handlers to process requests."""
22

3-
from .open_api_spec import build_openapi_spec_handler
43
from .reverse_proxy import ReverseProxyHandler
54

6-
__all__ = ["build_openapi_spec_handler", "ReverseProxyHandler"]
5+
__all__ = ["ReverseProxyHandler"]

src/stac_auth_proxy/handlers/open_api_spec.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

src/stac_auth_proxy/middleware.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
"""Custom middleware."""
22

3+
from dataclasses import dataclass
4+
from typing import Any
5+
import json
36
import time
47

58
from fastapi import Request, Response
69
from starlette.middleware.base import BaseHTTPMiddleware
10+
from starlette.types import ASGIApp, Message, Scope, Receive, Send
11+
12+
from .config import EndpointMethods
713

814

915
class AddProcessTimeHeaderMiddleware(BaseHTTPMiddleware):
@@ -16,3 +22,66 @@ async def dispatch(self, request: Request, call_next) -> Response:
1622
process_time = time.perf_counter() - start_time
1723
response.headers["X-Process-Time"] = f"{process_time:.3f}"
1824
return response
25+
26+
27+
@dataclass(frozen=True)
28+
class OpenApiMiddleware:
29+
"""Middleware to add the OpenAPI spec to the response."""
30+
31+
app: ASGIApp
32+
openapi_spec_path: str
33+
oidc_config_url: str
34+
private_endpoints: EndpointMethods
35+
default_public: bool
36+
oidc_auth_scheme_name: str = "oidcAuth"
37+
38+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
39+
"""Add the OpenAPI spec to the response."""
40+
if scope["type"] != "http":
41+
return await self.app(scope, receive, send)
42+
43+
async def augment_oidc_spec(message: Message):
44+
if message["type"] != "http.response.body":
45+
return await send(message)
46+
47+
# TODO: Make more robust to handle non-JSON responses
48+
body = json.loads(message["body"])
49+
50+
await send(
51+
{
52+
"type": "http.response.body",
53+
"body": dict_to_bytes(self.augment_spec(body)),
54+
}
55+
)
56+
57+
return await self.app(scope, receive, augment_oidc_spec)
58+
59+
def augment_spec(self, openapi_spec) -> dict[str, Any]:
60+
components = openapi_spec.setdefault("components", {})
61+
securitySchemes = components.setdefault("securitySchemes", {})
62+
securitySchemes[self.oidc_auth_scheme_name] = {
63+
"type": "openIdConnect",
64+
"openIdConnectUrl": self.oidc_config_url,
65+
}
66+
for path, method_config in openapi_spec["paths"].items():
67+
for method, config in method_config.items():
68+
for private_method in self.private_endpoints.get(path, []):
69+
if method.lower() == private_method.lower():
70+
config.setdefault("security", []).append(
71+
{self.oidc_auth_scheme_name: []}
72+
)
73+
return openapi_spec
74+
75+
76+
# TODO
77+
class EnforceAuthMiddleware(BaseHTTPMiddleware):
78+
"""Middleware to enforce authentication."""
79+
80+
async def dispatch(self, request: Request, call_next) -> Response:
81+
"""Enforce authentication."""
82+
return await call_next(request)
83+
84+
85+
def dict_to_bytes(d: dict) -> bytes:
86+
"""Convert a dictionary to a body."""
87+
return json.dumps(d, separators=(",", ":")).encode("utf-8")

0 commit comments

Comments
 (0)