1
1
"""Utilities for middleware response handling."""
2
2
3
3
import json
4
- import re
5
4
from abc import ABC , abstractmethod
6
5
from typing import Any , Optional
7
6
@@ -14,25 +13,20 @@ class JsonResponseMiddleware(ABC):
14
13
"""Base class for middleware that transforms JSON response bodies."""
15
14
16
15
app : ASGIApp
17
- json_content_type_expr : str = (
18
- r"application/vnd\.oai\.openapi\+json;.*|application/json|application/geo\+json"
19
- )
20
16
21
17
@abstractmethod
22
- def should_transform_response (self , request : Request ) -> bool :
18
+ def should_transform_response (
19
+ self , request : Request , response_headers : Headers
20
+ ) -> bool : # mypy: ignore
23
21
"""
24
- Determine if this request's response should be transformed.
25
-
26
- Args:
27
- request: The incoming request
22
+ Determine if this response should be transformed. At a minimum, this
23
+ should check the request's path and content type.
28
24
29
25
Returns
30
26
-------
31
27
bool: True if the response should be transformed
32
28
"""
33
- return bool (
34
- re .match (self .json_content_type_expr , request .headers .get ("accept" , "" ))
35
- )
29
+ ...
36
30
37
31
@abstractmethod
38
32
def transform_json (self , data : Any ) -> Any :
@@ -46,36 +40,31 @@ def transform_json(self, data: Any) -> Any:
46
40
-------
47
41
The transformed JSON data
48
42
"""
49
- pass
43
+ ...
50
44
51
45
async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
52
46
"""Process the request/response."""
53
47
if scope ["type" ] != "http" :
54
48
return await self .app (scope , receive , send )
55
49
56
- request = Request (scope )
57
- if not self .should_transform_response (request ):
58
- return await self .app (scope , receive , send )
59
-
60
50
start_message : Optional [Message ] = None
61
51
body = b""
62
- not_json = False
63
52
64
- async def process_message (message : Message ) -> None :
53
+ async def transform_response (message : Message ) -> None :
65
54
nonlocal start_message
66
55
nonlocal body
67
- nonlocal not_json
56
+
68
57
if message ["type" ] == "http.response.start" :
69
58
# Delay sending start message until we've processed the body
70
- if not re .match (
71
- self .json_content_type_expr ,
72
- Headers (scope = message ).get ("content-type" , "" ),
73
- ):
74
- not_json = True
75
- return await send (message )
76
59
start_message = message
77
60
return
78
- elif message ["type" ] != "http.response.body" or not_json :
61
+ assert start_message is not None
62
+ if not self .should_transform_response (
63
+ request = Request (scope ),
64
+ response_headers = Headers (scope = start_message ),
65
+ ):
66
+ return await send (message )
67
+ if message ["type" ] != "http.response.body" :
79
68
return await send (message )
80
69
81
70
body += message ["body" ]
@@ -109,4 +98,4 @@ async def process_message(message: Message) -> None:
109
98
}
110
99
)
111
100
112
- return await self .app (scope , receive , process_message )
101
+ return await self .app (scope , receive , transform_response )
0 commit comments