Skip to content

Commit 91f5e74

Browse files
committed
refactor: simplify middleware util logic
1 parent 9f8f057 commit 91f5e74

File tree

2 files changed

+33
-30
lines changed

2 files changed

+33
-30
lines changed

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
22

3+
import re
34
from dataclasses import dataclass
45
from typing import Any
56

7+
from starlette.datastructures import Headers
68
from starlette.requests import Request
79
from starlette.types import ASGIApp
810

@@ -23,9 +25,21 @@ class OpenApiMiddleware(JsonResponseMiddleware):
2325
default_public: bool
2426
oidc_auth_scheme_name: str = "oidcAuth"
2527

26-
def should_transform_response(self, request: Request) -> bool:
28+
json_content_type_expr: str = r"application/(vnd\.oai\.openapi\+json?|json)"
29+
30+
def should_transform_response(
31+
self, request: Request, response_headers: Headers
32+
) -> bool:
2733
"""Only transform responses for the OpenAPI spec path."""
28-
return request.url.path == self.openapi_spec_path
34+
return all(
35+
[
36+
re.match(self.openapi_spec_path, request.url.path),
37+
re.match(
38+
self.json_content_type_expr,
39+
response_headers.get("content-type", ""),
40+
),
41+
]
42+
)
2943

3044
def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
3145
"""Augment the OpenAPI spec with auth information."""

src/stac_auth_proxy/utils/middleware.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Utilities for middleware response handling."""
22

33
import json
4-
import re
54
from abc import ABC, abstractmethod
65
from typing import Any, Optional
76

@@ -14,25 +13,20 @@ class JsonResponseMiddleware(ABC):
1413
"""Base class for middleware that transforms JSON response bodies."""
1514

1615
app: ASGIApp
17-
json_content_type_expr: str = (
18-
r"application/vnd\.oai\.openapi\+json;.*|application/json|application/geo\+json"
19-
)
2016

2117
@abstractmethod
22-
def should_transform_response(self, request: Request) -> bool:
18+
def should_transform_response(
19+
self, request: Request, response_headers: Headers
20+
) -> bool: # mypy: ignore
2321
"""
24-
Determine if this request's response should be transformed.
25-
26-
Args:
27-
request: The incoming request
22+
Determine if this response should be transformed. At a minimum, this
23+
should check the request's path and content type.
2824
2925
Returns
3026
-------
3127
bool: True if the response should be transformed
3228
"""
33-
return bool(
34-
re.match(self.json_content_type_expr, request.headers.get("accept", ""))
35-
)
29+
...
3630

3731
@abstractmethod
3832
def transform_json(self, data: Any) -> Any:
@@ -46,36 +40,31 @@ def transform_json(self, data: Any) -> Any:
4640
-------
4741
The transformed JSON data
4842
"""
49-
pass
43+
...
5044

5145
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5246
"""Process the request/response."""
5347
if scope["type"] != "http":
5448
return await self.app(scope, receive, send)
5549

56-
request = Request(scope)
57-
if not self.should_transform_response(request):
58-
return await self.app(scope, receive, send)
59-
6050
start_message: Optional[Message] = None
6151
body = b""
62-
not_json = False
6352

64-
async def process_message(message: Message) -> None:
53+
async def transform_response(message: Message) -> None:
6554
nonlocal start_message
6655
nonlocal body
67-
nonlocal not_json
56+
6857
if message["type"] == "http.response.start":
6958
# Delay sending start message until we've processed the body
70-
if not re.match(
71-
self.json_content_type_expr,
72-
Headers(scope=message).get("content-type", ""),
73-
):
74-
not_json = True
75-
return await send(message)
7659
start_message = message
7760
return
78-
elif message["type"] != "http.response.body" or not_json:
61+
assert start_message is not None
62+
if not self.should_transform_response(
63+
request=Request(scope),
64+
response_headers=Headers(scope=start_message),
65+
):
66+
return await send(message)
67+
if message["type"] != "http.response.body":
7968
return await send(message)
8069

8170
body += message["body"]
@@ -109,4 +98,4 @@ async def process_message(message: Message) -> None:
10998
}
11099
)
111100

112-
return await self.app(scope, receive, process_message)
101+
return await self.app(scope, receive, transform_response)

0 commit comments

Comments
 (0)