Skip to content

Commit aa27887

Browse files
committed
fix: prevent JSON middleware from throwing 500s on non-transformed content
1 parent 1b8fa28 commit aa27887

File tree

2 files changed

+116
-12
lines changed

2 files changed

+116
-12
lines changed

src/stac_auth_proxy/utils/middleware.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,27 @@ async def transform_response(message: Message) -> None:
5454
nonlocal start_message
5555
nonlocal body
5656

57-
if message["type"] == "http.response.start":
58-
# Delay sending start message until we've processed the body
59-
start_message = message
60-
return
61-
assert start_message is not None
57+
start_message = start_message or message
58+
headers = MutableHeaders(scope=start_message)
59+
6260
if not self.should_transform_response(
6361
request=Request(scope),
64-
response_headers=Headers(scope=start_message),
62+
response_headers=headers,
6563
):
66-
return await send(message)
67-
if message["type"] != "http.response.body":
68-
return await send(message)
64+
# For non-JSON responses, send the start message immediately
65+
await send(message)
66+
return
67+
68+
# Delay sending start message until we've processed the body
69+
if message["type"] == "http.response.start":
70+
return
6971

7072
body += message["body"]
7173

7274
# Skip body chunks until all chunks have been received
7375
if message.get("more_body"):
7476
return
7577

76-
headers = MutableHeaders(scope=start_message)
77-
7878
# Transform the JSON body
7979
if body:
8080
data = json.loads(body)
@@ -83,7 +83,6 @@ async def transform_response(message: Message) -> None:
8383

8484
# Update content-length header
8585
headers["content-length"] = str(len(body))
86-
assert start_message, "Expected start_message to be set"
8786
start_message["headers"] = [
8887
(key.encode(), value.encode()) for key, value in headers.items()
8988
]

tests/test_middleware.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Tests for middleware utilities."""
2+
3+
from typing import Any
4+
5+
from fastapi import FastAPI, Response
6+
from starlette.datastructures import Headers
7+
from starlette.requests import Request
8+
from starlette.testclient import TestClient
9+
from starlette.types import ASGIApp
10+
11+
from stac_auth_proxy.utils.middleware import JsonResponseMiddleware
12+
13+
14+
class ExampleJsonResponseMiddleware(JsonResponseMiddleware):
15+
"""Example implementation of JsonResponseMiddleware."""
16+
17+
def __init__(self, app: ASGIApp):
18+
"""Initialize the middleware."""
19+
self.app = app
20+
21+
def should_transform_response(
22+
self, request: Request, response_headers: Headers
23+
) -> bool:
24+
"""Transform JSON responses based on content type."""
25+
return response_headers.get("content-type", "") == "application/json"
26+
27+
def transform_json(self, data: Any) -> Any:
28+
"""Add a test field to the response."""
29+
if isinstance(data, dict):
30+
data["transformed"] = True
31+
return data
32+
33+
34+
def test_json_response_middleware():
35+
"""Test that JSON responses are properly transformed."""
36+
app = FastAPI()
37+
app.add_middleware(ExampleJsonResponseMiddleware)
38+
39+
@app.get("/test")
40+
async def test_endpoint():
41+
return {"message": "test"}
42+
43+
client = TestClient(app)
44+
response = client.get("/test")
45+
assert response.status_code == 200
46+
assert response.headers["content-type"] == "application/json"
47+
data = response.json()
48+
assert data["message"] == "test"
49+
assert data["transformed"] is True
50+
51+
52+
def test_json_response_middleware_no_transform():
53+
"""Test that responses are not transformed when should_transform_response returns False."""
54+
app = FastAPI()
55+
app.add_middleware(ExampleJsonResponseMiddleware)
56+
57+
@app.get("/test")
58+
async def test_endpoint():
59+
return Response(
60+
content='{"message": "test"}',
61+
media_type="application/x-json", # Different from application/json
62+
)
63+
64+
client = TestClient(app)
65+
response = client.get("/test")
66+
assert response.status_code == 200
67+
assert "application/x-json" in response.headers["content-type"]
68+
data = response.json()
69+
assert data["message"] == "test"
70+
assert "transformed" not in data
71+
72+
73+
def test_json_response_middleware_chunked():
74+
"""Test that chunked JSON responses are properly transformed."""
75+
app = FastAPI()
76+
app.add_middleware(ExampleJsonResponseMiddleware)
77+
78+
@app.get("/test")
79+
async def test_endpoint():
80+
return {"message": "test", "large_field": "x" * 10000}
81+
82+
client = TestClient(app)
83+
response = client.get("/test")
84+
assert response.status_code == 200
85+
assert response.headers["content-type"] == "application/json"
86+
data = response.json()
87+
assert data["message"] == "test"
88+
assert data["transformed"] is True
89+
assert len(data["large_field"]) == 10000
90+
91+
92+
def test_json_response_middleware_error_handling():
93+
"""Test that JSON parsing errors are handled gracefully."""
94+
app = FastAPI()
95+
app.add_middleware(ExampleJsonResponseMiddleware)
96+
97+
@app.get("/test")
98+
async def test_endpoint():
99+
return Response(content="invalid json", media_type="text/plain")
100+
101+
client = TestClient(app)
102+
response = client.get("/test")
103+
assert response.status_code == 200
104+
assert "text/plain" in response.headers["content-type"]
105+
assert response.text == "invalid json"

0 commit comments

Comments
 (0)