Skip to content

Commit 66a03ad

Browse files
committed
Merge branch 'feature/authentication-extension' into authentication-ext/asset-signing
2 parents 2586214 + b8fbb63 commit 66a03ad

File tree

3 files changed

+51
-37
lines changed

3 files changed

+51
-37
lines changed

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import httpx
1111
from pydantic import HttpUrl
12+
from starlette.datastructures import Headers
1213
from starlette.requests import Request
1314
from starlette.types import ASGIApp
1415

@@ -40,6 +41,8 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
4041
"https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
4142
)
4243

44+
json_content_type_expr: str = r"application/json|geo\+json)"
45+
4346
def __post_init__(self):
4447
"""Load after initialization."""
4548
if self.oidc_config_url and not self.auth_scheme:
@@ -60,15 +63,23 @@ def __post_init__(self):
6063
},
6164
}
6265

63-
def should_transform_response(self, request: Request) -> bool:
66+
def should_transform_response(
67+
self, request: Request, response_headers: Headers
68+
) -> bool:
6469
"""Determine if the response should be transformed."""
6570
# Match STAC catalog, collection, or item URLs with a single regex
66-
return bool(
67-
re.match(
68-
# catalog, collections, collection, items, item, search
69-
r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$",
70-
request.url.path,
71-
)
71+
return all(
72+
[
73+
re.match(
74+
# catalog, collections, collection, items, item, search
75+
r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$",
76+
request.url.path,
77+
),
78+
re.match(
79+
self.json_content_type_expr,
80+
response_headers.get("content-type", ""),
81+
),
82+
]
7283
)
7384

7485
def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
22

3+
import re
34
from dataclasses import dataclass
45
from typing import Any
56

7+
from starlette.datastructures import Headers
68
from starlette.requests import Request
79
from starlette.types import ASGIApp
810

@@ -23,9 +25,21 @@ class OpenApiMiddleware(JsonResponseMiddleware):
2325
default_public: bool
2426
oidc_auth_scheme_name: str = "oidcAuth"
2527

26-
def should_transform_response(self, request: Request) -> bool:
28+
json_content_type_expr: str = r"application/(vnd\.oai\.openapi\+json?|json)"
29+
30+
def should_transform_response(
31+
self, request: Request, response_headers: Headers
32+
) -> bool:
2733
"""Only transform responses for the OpenAPI spec path."""
28-
return request.url.path == self.openapi_spec_path
34+
return all(
35+
[
36+
re.match(self.openapi_spec_path, request.url.path),
37+
re.match(
38+
self.json_content_type_expr,
39+
response_headers.get("content-type", ""),
40+
),
41+
]
42+
)
2943

3044
def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
3145
"""Augment the OpenAPI spec with auth information."""

src/stac_auth_proxy/utils/middleware.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Utilities for middleware response handling."""
22

33
import json
4-
import re
54
from abc import ABC, abstractmethod
65
from typing import Any, Optional
76

@@ -14,25 +13,20 @@ class JsonResponseMiddleware(ABC):
1413
"""Base class for middleware that transforms JSON response bodies."""
1514

1615
app: ASGIApp
17-
json_content_type_expr: str = (
18-
r"application/vnd\.oai\.openapi\+json;.*|application/json|application/geo\+json"
19-
)
2016

2117
@abstractmethod
22-
def should_transform_response(self, request: Request) -> bool:
18+
def should_transform_response(
19+
self, request: Request, response_headers: Headers
20+
) -> bool: # mypy: ignore
2321
"""
24-
Determine if this request's response should be transformed.
25-
26-
Args:
27-
request: The incoming request
22+
Determine if this response should be transformed. At a minimum, this
23+
should check the request's path and content type.
2824
2925
Returns
3026
-------
3127
bool: True if the response should be transformed
3228
"""
33-
return bool(
34-
re.match(self.json_content_type_expr, request.headers.get("accept", ""))
35-
)
29+
...
3630

3731
@abstractmethod
3832
def transform_json(self, data: Any) -> Any:
@@ -46,36 +40,31 @@ def transform_json(self, data: Any) -> Any:
4640
-------
4741
The transformed JSON data
4842
"""
49-
pass
43+
...
5044

5145
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5246
"""Process the request/response."""
5347
if scope["type"] != "http":
5448
return await self.app(scope, receive, send)
5549

56-
request = Request(scope)
57-
if not self.should_transform_response(request):
58-
return await self.app(scope, receive, send)
59-
6050
start_message: Optional[Message] = None
6151
body = b""
62-
not_json = False
6352

64-
async def process_message(message: Message) -> None:
53+
async def transform_response(message: Message) -> None:
6554
nonlocal start_message
6655
nonlocal body
67-
nonlocal not_json
56+
6857
if message["type"] == "http.response.start":
6958
# Delay sending start message until we've processed the body
70-
if not re.match(
71-
self.json_content_type_expr,
72-
Headers(scope=message).get("content-type", ""),
73-
):
74-
not_json = True
75-
return await send(message)
7659
start_message = message
7760
return
78-
elif message["type"] != "http.response.body" or not_json:
61+
assert start_message is not None
62+
if not self.should_transform_response(
63+
request=Request(scope),
64+
response_headers=Headers(scope=start_message),
65+
):
66+
return await send(message)
67+
if message["type"] != "http.response.body":
7968
return await send(message)
8069

8170
body += message["body"]
@@ -109,4 +98,4 @@ async def process_message(message: Message) -> None:
10998
}
11099
)
111100

112-
return await self.app(scope, receive, process_message)
101+
return await self.app(scope, receive, transform_response)

0 commit comments

Comments
 (0)