Skip to content

Commit 9b7b9f2

Browse files
committed
fix: prevent JSON middleware from throwing 500s on non-transformed content
1 parent 1b8fa28 commit 9b7b9f2

File tree

2 files changed

+114
-9
lines changed

2 files changed

+114
-9
lines changed

src/stac_auth_proxy/utils/middleware.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,19 @@ async def transform_response(message: Message) -> None:
5454
nonlocal start_message
5555
nonlocal body
5656

57-
if message["type"] == "http.response.start":
58-
# Delay sending start message until we've processed the body
59-
start_message = message
60-
return
61-
assert start_message is not None
57+
start_message = start_message or message
58+
6259
if not self.should_transform_response(
6360
request=Request(scope),
6461
response_headers=Headers(scope=start_message),
6562
):
66-
return await send(message)
67-
if message["type"] != "http.response.body":
68-
return await send(message)
63+
# For non-JSON responses, send the start message immediately
64+
await send(message)
65+
return
66+
67+
# Delay sending start message until we've processed the body
68+
if message["type"] == "http.response.start":
69+
return
6970

7071
body += message["body"]
7172

@@ -83,7 +84,6 @@ async def transform_response(message: Message) -> None:
8384

8485
# Update content-length header
8586
headers["content-length"] = str(len(body))
86-
assert start_message, "Expected start_message to be set"
8787
start_message["headers"] = [
8888
(key.encode(), value.encode()) for key, value in headers.items()
8989
]

tests/test_middleware.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Tests for middleware utilities."""
2+
3+
from typing import Any
4+
5+
from fastapi import FastAPI, Response
6+
from starlette.datastructures import Headers
7+
from starlette.requests import Request
8+
from starlette.testclient import TestClient
9+
from starlette.types import ASGIApp
10+
11+
from stac_auth_proxy.utils.middleware import JsonResponseMiddleware
12+
13+
14+
class ExampleJsonResponseMiddleware(JsonResponseMiddleware):
15+
"""Example implementation of JsonResponseMiddleware."""
16+
17+
def __init__(self, app: ASGIApp):
18+
"""Initialize the middleware."""
19+
self.app = app
20+
21+
def should_transform_response(
22+
self, request: Request, response_headers: Headers
23+
) -> bool:
24+
"""Transform JSON responses based on content type."""
25+
return response_headers.get("content-type", "") == "application/json"
26+
27+
def transform_json(self, data: Any) -> Any:
28+
"""Add a test field to the response."""
29+
if isinstance(data, dict):
30+
data["transformed"] = True
31+
return data
32+
33+
34+
def test_json_response_middleware():
35+
"""Test that JSON responses are properly transformed."""
36+
app = FastAPI()
37+
app.add_middleware(ExampleJsonResponseMiddleware)
38+
39+
@app.get("/test")
40+
async def test_endpoint():
41+
return {"message": "test"}
42+
43+
client = TestClient(app)
44+
response = client.get("/test")
45+
assert response.status_code == 200
46+
assert response.headers["content-type"] == "application/json"
47+
data = response.json()
48+
assert data["message"] == "test"
49+
assert data["transformed"] is True
50+
51+
52+
def test_json_response_middleware_no_transform():
53+
"""Test that responses are not transformed when should_transform_response returns False."""
54+
app = FastAPI()
55+
app.add_middleware(ExampleJsonResponseMiddleware)
56+
57+
@app.get("/test")
58+
async def test_endpoint():
59+
return Response(
60+
content='{"message": "test"}',
61+
media_type="application/x-json", # Different from application/json
62+
)
63+
64+
client = TestClient(app)
65+
response = client.get("/test")
66+
assert response.status_code == 200
67+
assert "application/x-json" in response.headers["content-type"]
68+
data = response.json()
69+
assert data["message"] == "test"
70+
assert "transformed" not in data
71+
72+
73+
def test_json_response_middleware_chunked():
74+
"""Test that chunked JSON responses are properly transformed."""
75+
app = FastAPI()
76+
app.add_middleware(ExampleJsonResponseMiddleware)
77+
78+
@app.get("/test")
79+
async def test_endpoint():
80+
return {"message": "test", "large_field": "x" * 10000}
81+
82+
client = TestClient(app)
83+
response = client.get("/test")
84+
assert response.status_code == 200
85+
assert response.headers["content-type"] == "application/json"
86+
data = response.json()
87+
assert data["message"] == "test"
88+
assert data["transformed"] is True
89+
assert len(data["large_field"]) == 10000
90+
91+
92+
def test_json_response_middleware_error_handling():
93+
"""Test that JSON parsing errors are handled gracefully."""
94+
app = FastAPI()
95+
app.add_middleware(ExampleJsonResponseMiddleware)
96+
97+
@app.get("/test")
98+
async def test_endpoint():
99+
return Response(content="invalid json", media_type="text/plain")
100+
101+
client = TestClient(app)
102+
response = client.get("/test")
103+
assert response.status_code == 200
104+
assert "text/plain" in response.headers["content-type"]
105+
assert response.text == "invalid json"

0 commit comments

Comments
 (0)