Skip to content

Commit eb8c8fb

Browse files
committed
fix: only transform non-errors
1 parent f0ec9a5 commit eb8c8fb

File tree

3 files changed

+35
-31
lines changed

3 files changed

+35
-31
lines changed

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from starlette.datastructures import Headers
1111
from starlette.requests import Request
12-
from starlette.types import ASGIApp
12+
from starlette.types import ASGIApp, Scope
1313

1414
from ..config import EndpointMethods
1515
from ..utils.middleware import JsonResponseMiddleware
@@ -38,24 +38,27 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
3838

3939
state_key: str = "oidc_metadata"
4040

41-
def should_transform_response(
42-
self, request: Request, response_headers: Headers
43-
) -> bool:
41+
def should_transform_response(self, request: Request, scope: Scope) -> bool:
4442
"""Determine if the response should be transformed."""
4543
# Match STAC catalog, collection, or item URLs with a single regex
46-
return all(
47-
re.match(expr, val)
48-
for expr, val in [
44+
return (
45+
all(
4946
(
50-
# catalog, collections, collection, items, item, search
51-
r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$",
52-
request.url.path,
47+
re.match(expr, val)
48+
for expr, val in [
49+
(
50+
# catalog, collections, collection, items, item, search
51+
r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$",
52+
request.url.path,
53+
),
54+
(
55+
self.json_content_type_expr,
56+
Headers(scope=scope).get("content-type", ""),
57+
),
58+
]
5359
),
54-
(
55-
self.json_content_type_expr,
56-
response_headers.get("content-type", ""),
57-
),
58-
]
60+
)
61+
and 200 >= scope["status"] < 300
5962
)
6063

6164
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from starlette.datastructures import Headers
88
from starlette.requests import Request
9-
from starlette.types import ASGIApp
9+
from starlette.types import ASGIApp, Scope
1010

1111
from ..config import EndpointMethods
1212
from ..utils.middleware import JsonResponseMiddleware
@@ -27,19 +27,20 @@ class OpenApiMiddleware(JsonResponseMiddleware):
2727

2828
json_content_type_expr: str = r"application/(vnd\.oai\.openapi\+json?|json)"
2929

30-
def should_transform_response(
31-
self, request: Request, response_headers: Headers
32-
) -> bool:
30+
def should_transform_response(self, request: Request, scope: Scope) -> bool:
3331
"""Only transform responses for the OpenAPI spec path."""
34-
return all(
35-
re.match(expr, val)
36-
for expr, val in [
37-
(self.openapi_spec_path, request.url.path),
38-
(
39-
self.json_content_type_expr,
40-
response_headers.get("content-type", ""),
41-
),
42-
]
32+
return (
33+
all(
34+
re.match(expr, val)
35+
for expr, val in [
36+
(self.openapi_spec_path, request.url.path),
37+
(
38+
self.json_content_type_expr,
39+
Headers(scope=scope).get("content-type", ""),
40+
),
41+
]
42+
)
43+
and 200 >= scope["status"] < 300
4344
)
4445

4546
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:

src/stac_auth_proxy/utils/middleware.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from abc import ABC, abstractmethod
55
from typing import Any, Optional
66

7-
from starlette.datastructures import Headers, MutableHeaders
7+
from starlette.datastructures import MutableHeaders
88
from starlette.requests import Request
99
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1010

@@ -16,7 +16,7 @@ class JsonResponseMiddleware(ABC):
1616

1717
@abstractmethod
1818
def should_transform_response(
19-
self, request: Request, response_headers: Headers
19+
self, request: Request, scope: Scope
2020
) -> bool: # mypy: ignore
2121
"""
2222
Determine if this response should be transformed. At a minimum, this
@@ -60,7 +60,7 @@ async def transform_response(message: Message) -> None:
6060

6161
if not self.should_transform_response(
6262
request=request,
63-
response_headers=headers,
63+
scope=start_message,
6464
):
6565
# For non-JSON responses, send the start message immediately
6666
await send(message)

0 commit comments

Comments
 (0)