11"""Utilities for middleware response handling."""
22
33import json
4- import re
54from abc import ABC , abstractmethod
65from typing import Any , Optional
76
@@ -14,25 +13,20 @@ class JsonResponseMiddleware(ABC):
1413 """Base class for middleware that transforms JSON response bodies."""
1514
1615 app : ASGIApp
17- json_content_type_expr : str = (
18- r"application/vnd\.oai\.openapi\+json;.*|application/json|application/geo\+json"
19- )
2016
2117 @abstractmethod
22- def should_transform_response (self , request : Request ) -> bool :
18+ def should_transform_response (
19+ self , request : Request , response_headers : Headers
20+ ) -> bool : # mypy: ignore
2321 """
24- Determine if this request's response should be transformed.
25-
26- Args:
27- request: The incoming request
22+ Determine if this response should be transformed. At a minimum, this
23+ should check the request's path and content type.
2824
2925 Returns
3026 -------
3127 bool: True if the response should be transformed
3228 """
33- return bool (
34- re .match (self .json_content_type_expr , request .headers .get ("accept" , "" ))
35- )
29+ ...
3630
3731 @abstractmethod
3832 def transform_json (self , data : Any ) -> Any :
@@ -46,36 +40,31 @@ def transform_json(self, data: Any) -> Any:
4640 -------
4741 The transformed JSON data
4842 """
49- pass
43+ ...
5044
5145 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
5246 """Process the request/response."""
5347 if scope ["type" ] != "http" :
5448 return await self .app (scope , receive , send )
5549
56- request = Request (scope )
57- if not self .should_transform_response (request ):
58- return await self .app (scope , receive , send )
59-
6050 start_message : Optional [Message ] = None
6151 body = b""
62- not_json = False
6352
64- async def process_message (message : Message ) -> None :
53+ async def transform_response (message : Message ) -> None :
6554 nonlocal start_message
6655 nonlocal body
67- nonlocal not_json
56+
6857 if message ["type" ] == "http.response.start" :
6958 # Delay sending start message until we've processed the body
70- if not re .match (
71- self .json_content_type_expr ,
72- Headers (scope = message ).get ("content-type" , "" ),
73- ):
74- not_json = True
75- return await send (message )
7659 start_message = message
7760 return
78- elif message ["type" ] != "http.response.body" or not_json :
61+ assert start_message is not None
62+ if not self .should_transform_response (
63+ request = Request (scope ),
64+ response_headers = Headers (scope = start_message ),
65+ ):
66+ return await send (message )
67+ if message ["type" ] != "http.response.body" :
7968 return await send (message )
8069
8170 body += message ["body" ]
@@ -109,4 +98,4 @@ async def process_message(message: Message) -> None:
10998 }
11099 )
111100
112- return await self .app (scope , receive , process_message )
101+ return await self .app (scope , receive , transform_response )
0 commit comments