Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/stac_auth_proxy/utils/middleware.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Utilities for middleware response handling."""

import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Optional

from starlette.datastructures import MutableHeaders
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

logger = logging.getLogger(__name__)


class JsonResponseMiddleware(ABC):
"""Base class for middleware that transforms JSON response bodies."""
Expand Down Expand Up @@ -78,7 +82,18 @@ async def transform_response(message: Message) -> None:

# Transform the JSON body
if body:
data = json.loads(body)
try:
data = json.loads(body)
except json.JSONDecodeError as e:
logger.error("Error parsing JSON: %s", e)
logger.error("Body: %s", body)
logger.error("Response scope: %s", scope)
response = JSONResponse(
{"error": "Received invalid JSON from upstream server"},
status_code=502,
)
await response(scope, receive, send)
return
transformed = self.transform_json(data, request=request)
body = json.dumps(transformed).encode()

Expand Down
18 changes: 18 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,21 @@ async def test_endpoint():
assert response.status_code == 200
assert "text/plain" in response.headers["content-type"]
assert response.text == "invalid json"


def test_json_response_middleware_invalid_json_upstream():
"""Test that invalid JSON from upstream server returns 502 error."""
app = FastAPI()
app.add_middleware(ExampleJsonResponseMiddleware)

@app.get("/test")
async def test_endpoint():
# Return invalid JSON with JSON content type to trigger the error handling
return Response(content="invalid json content", media_type="application/json")

client = TestClient(app)
response = client.get("/test")
assert response.status_code == 502
assert response.headers["content-type"] == "application/json"
data = response.json()
assert data == {"error": "Received invalid JSON from upstream server"}
Loading