22
33import gzip
44import json
5+ import re
56import zlib
67from abc import ABC , abstractmethod
78from typing import Any , Optional
89
910import brotli
10- from starlette .datastructures import MutableHeaders
11+ from starlette .datastructures import Headers , MutableHeaders
1112from starlette .requests import Request
1213from starlette .types import ASGIApp , Message , Receive , Scope , Send
1314
15+ # TODO: Consider using a single middleware to handle all compression/decompression
1416ENCODING_HANDLERS = {
1517 "gzip" : gzip ,
1618 "deflate" : zlib ,
@@ -22,6 +24,9 @@ class JsonResponseMiddleware(ABC):
2224 """Base class for middleware that transforms JSON response bodies."""
2325
2426 app : ASGIApp
27+ json_content_type_expr : str = (
28+ r"application/vnd\.oai\.openapi\+json;.*|application/json|application/geo\+json"
29+ )
2530
2631 @abstractmethod
2732 def should_transform_response (self , request : Request ) -> bool :
@@ -35,7 +40,7 @@ def should_transform_response(self, request: Request) -> bool:
3540 -------
3641 bool: True if the response should be transformed
3742 """
38- pass
43+ return request . headers . get ( "accept" ) == "application/json"
3944
4045 @abstractmethod
4146 def transform_json (self , data : Any ) -> Any :
@@ -62,16 +67,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6267
6368 start_message : Optional [Message ] = None
6469 body = b""
70+ not_json = False
6571
6672 async def process_message (message : Message ) -> None :
6773 nonlocal start_message
6874 nonlocal body
69-
75+ nonlocal not_json
7076 if message ["type" ] == "http.response.start" :
7177 # Delay sending start message until we've processed the body
78+ if not re .match (
79+ self .json_content_type_expr ,
80+ Headers (scope = message ).get ("content-type" , "" ),
81+ ):
82+ not_json = True
83+ return await send (message )
7284 start_message = message
7385 return
74- elif message ["type" ] != "http.response.body" :
86+ elif message ["type" ] != "http.response.body" or not_json :
7587 return await send (message )
7688
7789 body += message ["body" ]
@@ -94,9 +106,10 @@ async def process_message(message: Message) -> None:
94106 )
95107
96108 # Transform the JSON body
97- data = json .loads (body )
98- transformed = self .transform_json (data )
99- body = json .dumps (transformed ).encode ()
109+ if body :
110+ data = json .loads (body )
111+ transformed = self .transform_json (data )
112+ body = json .dumps (transformed ).encode ()
100113
101114 # Re-compress if necessary
102115 if handler :
0 commit comments