Skip to content

Commit eb1435a

Browse files
committed
pass in request rather than scope, use state_key
1 parent d583363 commit eb1435a

File tree

4 files changed

+23
-22
lines changed

4 files changed

+23
-22
lines changed

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from starlette.datastructures import Headers
1111
from starlette.requests import Request
12-
from starlette.types import ASGIApp, Scope
12+
from starlette.types import ASGIApp
1313

1414
from ..config import EndpointMethods
1515
from ..utils.middleware import JsonResponseMiddleware
@@ -36,6 +36,8 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
3636

3737
json_content_type_expr: str = r"(application/json|geo\+json)"
3838

39+
state_key: str = "oidc_metadata"
40+
3941
def should_transform_response(
4042
self, request: Request, response_headers: Headers
4143
) -> bool:
@@ -55,9 +57,9 @@ def should_transform_response(
5557
]
5658
)
5759

58-
def transform_json(self, doc: dict[str, Any], scope: Scope) -> dict[str, Any]:
60+
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
5961
"""Augment the STAC Item with auth information."""
60-
extensions = doc.setdefault("stac_extensions", [])
62+
extensions = data.setdefault("stac_extensions", [])
6163
if self.extension_url not in extensions:
6264
extensions.append(self.extension_url)
6365

@@ -70,30 +72,30 @@ def transform_json(self, doc: dict[str, Any], scope: Scope) -> dict[str, Any]:
7072
# - Collections
7173
# - Item Properties
7274

73-
if "oidc_metadata" not in scope:
75+
if self.state_key not in request.state:
7476
logger.error(
7577
"OIDC metadata not found in scope. "
7678
"Skipping authentication extension."
7779
)
78-
return doc
80+
return data
7981

80-
scheme_loc = doc["properties"] if "properties" in doc else doc
82+
scheme_loc = data["properties"] if "properties" in data else data
8183
schemes = scheme_loc.setdefault("auth:schemes", {})
8284
schemes[self.auth_scheme_name] = self.parse_oidc_config(
83-
scope.get("oidc_metadata", {})
85+
request.state.get(self.state_key, {})
8486
)
8587

8688
# auth:refs
8789
# ---
8890
# Annotate links with "auth:refs": [auth_scheme]
8991
links = chain(
9092
# Item/Collection
91-
doc.get("links", []),
93+
data.get("links", []),
9294
# Collections/Items/Search
9395
(
9496
link
9597
for prop in ["features", "collections"]
96-
for object_with_links in doc.get(prop, [])
98+
for object_with_links in data.get(prop, [])
9799
for link in object_with_links.get("links", [])
98100
),
99101
)
@@ -111,7 +113,7 @@ def transform_json(self, doc: dict[str, Any], scope: Scope) -> dict[str, Any]:
111113
if match.is_private:
112114
link.setdefault("auth:refs", []).append(self.auth_scheme_name)
113115

114-
return doc
116+
return data
115117

116118
def parse_oidc_config(self, oidc_config: dict[str, Any]) -> dict[str, Any]:
117119
"""Parse the OIDC configuration."""

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from starlette.datastructures import Headers
88
from starlette.requests import Request
9-
from starlette.types import ASGIApp, Scope
9+
from starlette.types import ASGIApp
1010

1111
from ..config import EndpointMethods
1212
from ..utils.middleware import JsonResponseMiddleware
@@ -41,17 +41,15 @@ def should_transform_response(
4141
]
4242
)
4343

44-
def transform_json(
45-
self, openapi_spec: dict[str, Any], scope: Scope
46-
) -> dict[str, Any]:
44+
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
4745
"""Augment the OpenAPI spec with auth information."""
48-
components = openapi_spec.setdefault("components", {})
46+
components = data.setdefault("components", {})
4947
securitySchemes = components.setdefault("securitySchemes", {})
5048
securitySchemes[self.oidc_auth_scheme_name] = {
5149
"type": "openIdConnect",
5250
"openIdConnectUrl": self.oidc_config_url,
5351
}
54-
for path, method_config in openapi_spec["paths"].items():
52+
for path, method_config in data["paths"].items():
5553
for method, config in method_config.items():
5654
match = find_match(
5755
path,
@@ -64,4 +62,4 @@ def transform_json(
6462
config.setdefault("security", []).append(
6563
{self.oidc_auth_scheme_name: match.required_scopes}
6664
)
67-
return openapi_spec
65+
return data

src/stac_auth_proxy/utils/middleware.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def should_transform_response(
2929
...
3030

3131
@abstractmethod
32-
def transform_json(self, data: Any, scope: Scope) -> Any:
32+
def transform_json(self, data: Any, request: Request) -> Any:
3333
"""
3434
Transform the JSON data.
3535
@@ -56,9 +56,10 @@ async def transform_response(message: Message) -> None:
5656

5757
start_message = start_message or message
5858
headers = MutableHeaders(scope=start_message)
59+
request = Request(scope)
5960

6061
if not self.should_transform_response(
61-
request=Request(scope),
62+
request=request,
6263
response_headers=headers,
6364
):
6465
# For non-JSON responses, send the start message immediately
@@ -78,7 +79,7 @@ async def transform_response(message: Message) -> None:
7879
# Transform the JSON body
7980
if body:
8081
data = json.loads(body)
81-
transformed = self.transform_json(data, scope=scope)
82+
transformed = self.transform_json(data, request=request)
8283
body = json.dumps(transformed).encode()
8384

8485
# Update content-length header

tests/test_middleware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from starlette.datastructures import Headers
77
from starlette.requests import Request
88
from starlette.testclient import TestClient
9-
from starlette.types import ASGIApp, Scope
9+
from starlette.types import ASGIApp
1010

1111
from stac_auth_proxy.utils.middleware import JsonResponseMiddleware
1212

@@ -24,7 +24,7 @@ def should_transform_response(
2424
"""Transform JSON responses based on content type."""
2525
return response_headers.get("content-type", "") == "application/json"
2626

27-
def transform_json(self, data: Any, scope: Scope) -> Any:
27+
def transform_json(self, data: Any, request: Request) -> Any:
2828
"""Add a test field to the response."""
2929
if isinstance(data, dict):
3030
data["transformed"] = True

0 commit comments

Comments
 (0)