Skip to content

Commit ac981ce

Browse files
committed
feat: augment openapi spec
1 parent 3abf298 commit ac981ce

File tree

5 files changed

+103
-10
lines changed

5 files changed

+103
-10
lines changed

src/stac_auth_proxy/app.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .proxy import ReverseProxy
1313
from .config import Settings
1414
from .middleware import AddProcessTimeHeaderMiddleware
15+
from .handlers import OpenApiSpecHandler
1516

1617

1718
def create_app(settings: Optional[Settings] = None) -> FastAPI:
@@ -41,12 +42,21 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
4142
}.items():
4243
app.add_api_route(
4344
path,
44-
proxy.passthrough,
45+
proxy.stream,
4546
methods=methods,
46-
dependencies=[Depends(open_id_connect_scheme)],
47+
)
48+
49+
# Endpoint with special OpenAPI transformation functionality
50+
if settings.openapi_spec_endpoint:
51+
app.add_api_route(
52+
settings.openapi_spec_endpoint,
53+
OpenApiSpecHandler(
54+
proxy=proxy, oidc_config_url=str(settings.oidc_discovery_url)
55+
).dispatch,
56+
methods=["GET"],
4757
)
4858

4959
# Catchall proxy
50-
app.add_route("/{path:path}", proxy.passthrough)
60+
app.add_route("/{path:path}", proxy.stream)
5161

5262
return app

src/stac_auth_proxy/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from typing import Optional
12
from pydantic.networks import HttpUrl
23
from pydantic_settings import BaseSettings
34

45

56
class Settings(BaseSettings):
67
upstream_url: HttpUrl = "https://earth-search.aws.element84.com/v1"
78
oidc_discovery_url: HttpUrl
9+
openapi_spec_endpoint: Optional[str] = None
810

911
class Config:
1012
env_prefix = "STAC_AUTH_PROXY_"

src/stac_auth_proxy/handlers.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from dataclasses import dataclass
2+
import logging
3+
4+
from fastapi import Request, Response
5+
from fastapi.routing import APIRoute
6+
7+
from .proxy import ReverseProxy
8+
from .utils import safe_headers
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
@dataclass
14+
class OpenApiSpecHandler:
15+
proxy: ReverseProxy
16+
oidc_config_url: str
17+
auth_scheme_name: str = "oidcAuth"
18+
19+
async def dispatch(self, req: Request, res: Response):
20+
"""
21+
Proxy the OpenAPI spec from the upstream STAC API, updating it with OIDC security
22+
requirements.
23+
"""
24+
oidc_spec_response = await self.proxy.proxy_request(req)
25+
openapi_spec = oidc_spec_response.json()
26+
27+
# Pass along the response headers
28+
res.headers.update(safe_headers(oidc_spec_response.headers))
29+
30+
# Add the OIDC security scheme to the components
31+
openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[
32+
self.auth_scheme_name
33+
] = {
34+
"type": "openIdConnect",
35+
"openIdConnectUrl": self.oidc_config_url,
36+
}
37+
38+
proxy_auth_routes = [
39+
r
40+
for r in req.app.routes
41+
# Ignore non-APIRoutes (we can't check their security dependencies)
42+
if isinstance(r, APIRoute)
43+
# Ignore routes that don't have security requirements
44+
and (
45+
r.dependant.security_requirements
46+
or any(d.security_requirements for d in r.dependant.dependencies)
47+
)
48+
]
49+
50+
# Update the paths with the specified security requirements
51+
for path, method_config in openapi_spec["paths"].items():
52+
for method, config in method_config.items():
53+
for route in proxy_auth_routes:
54+
match, _ = route.matches(
55+
{"type": "http", "method": method.upper(), "path": path}
56+
)
57+
if match.name != "FULL":
58+
continue
59+
# Add the OIDC security requirement
60+
config.setdefault("security", []).append(
61+
[{self.auth_scheme_name: []}]
62+
)
63+
break
64+
65+
return openapi_spec

src/stac_auth_proxy/proxy.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
import logging
22
import time
33
from dataclasses import dataclass
4-
from urllib.parse import urlparse
54

5+
import httpx
66
from fastapi import Request
7-
87
from starlette.datastructures import MutableHeaders
98
from starlette.responses import StreamingResponse
109
from starlette.background import BackgroundTask
1110

12-
import httpx
1311

1412
logger = logging.getLogger(__name__)
1513

@@ -25,9 +23,7 @@ def __post_init__(self):
2523
timeout=httpx.Timeout(timeout=15.0),
2624
)
2725

28-
async def passthrough(self, request: Request):
29-
"""Transparently proxy a request to the upstream STAC API."""
30-
26+
async def proxy_request(self, request: Request, *, stream=False) -> httpx.Response:
3127
headers = MutableHeaders(request.headers)
3228

3329
# https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
@@ -43,14 +39,18 @@ async def passthrough(self, request: Request):
4339
logger.debug(f"Proxying request to {rp_req.url}")
4440

4541
start_time = time.perf_counter()
46-
rp_resp = await self.client.send(rp_req, stream=True)
42+
rp_resp = await self.client.send(rp_req, stream=stream)
4743
proxy_time = time.perf_counter() - start_time
4844

4945
logger.debug(
5046
f"Received response status {rp_resp.status_code!r} from {rp_req.url} in {proxy_time:.3f}s"
5147
)
5248
rp_resp.headers["X-Upstream-Time"] = f"{proxy_time:.3f}"
49+
return rp_resp
5350

51+
async def stream(self, request: Request) -> StreamingResponse:
52+
"""Transparently proxy a request to the upstream STAC API."""
53+
rp_resp = await self.proxy_request(request, stream=True)
5454
return StreamingResponse(
5555
rp_resp.aiter_raw(),
5656
status_code=rp_resp.status_code,

src/stac_auth_proxy/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from httpx import Headers
2+
3+
4+
def safe_headers(headers: Headers) -> dict[str, str]:
5+
"""
6+
Scrub headers that should not be proxied to the client.
7+
"""
8+
excluded_headers = [
9+
"content-length",
10+
"content-encoding",
11+
]
12+
return {
13+
key: value
14+
for key, value in headers.items()
15+
if key.lower() not in excluded_headers
16+
}

0 commit comments

Comments
 (0)