Skip to content

Commit b6cf0a5

Browse files
committed
Working
1 parent 4ac85af commit b6cf0a5

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

src/stac_auth_proxy/utils/middleware.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22

33
import gzip
44
import json
5+
import re
56
import zlib
67
from abc import ABC, abstractmethod
78
from typing import Any, Optional
89

910
import brotli
10-
from starlette.datastructures import MutableHeaders
11+
from starlette.datastructures import Headers, MutableHeaders
1112
from starlette.requests import Request
1213
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1314

15+
# TODO: Consider using a single middleware to handle all compression/decompression
1416
ENCODING_HANDLERS = {
1517
"gzip": gzip,
1618
"deflate": zlib,
@@ -22,6 +24,9 @@ class JsonResponseMiddleware(ABC):
2224
"""Base class for middleware that transforms JSON response bodies."""
2325

2426
app: ASGIApp
27+
json_content_type_expr: str = (
28+
r"application/vnd\.oai\.openapi\+json;.*|application/json|application/geo\+json"
29+
)
2530

2631
@abstractmethod
2732
def should_transform_response(self, request: Request) -> bool:
@@ -35,7 +40,7 @@ def should_transform_response(self, request: Request) -> bool:
3540
-------
3641
bool: True if the response should be transformed
3742
"""
38-
pass
43+
return request.headers.get("accept") == "application/json"
3944

4045
@abstractmethod
4146
def transform_json(self, data: Any) -> Any:
@@ -62,16 +67,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6267

6368
start_message: Optional[Message] = None
6469
body = b""
70+
not_json = False
6571

6672
async def process_message(message: Message) -> None:
6773
nonlocal start_message
6874
nonlocal body
69-
75+
nonlocal not_json
7076
if message["type"] == "http.response.start":
7177
# Delay sending start message until we've processed the body
78+
if not re.match(
79+
self.json_content_type_expr,
80+
Headers(scope=message).get("content-type", ""),
81+
):
82+
not_json = True
83+
return await send(message)
7284
start_message = message
7385
return
74-
elif message["type"] != "http.response.body":
86+
elif message["type"] != "http.response.body" or not_json:
7587
return await send(message)
7688

7789
body += message["body"]
@@ -94,9 +106,10 @@ async def process_message(message: Message) -> None:
94106
)
95107

96108
# Transform the JSON body
97-
data = json.loads(body)
98-
transformed = self.transform_json(data)
99-
body = json.dumps(transformed).encode()
109+
if body:
110+
data = json.loads(body)
111+
transformed = self.transform_json(data)
112+
body = json.dumps(transformed).encode()
100113

101114
# Re-compress if necessary
102115
if handler:

0 commit comments

Comments
 (0)