77from dataclasses import dataclass
88from typing import Any
99
10+ from starlette .datastructures import MutableHeaders
1011from starlette .requests import Request
1112from starlette .types import ASGIApp , Message , Receive , Scope , Send
1213
1314from ..config import EndpointMethods
1415from ..utils .requests import dict_to_bytes
1516
1617
18+ ENCODING_HANDLERS = {
19+ "gzip" : gzip ,
20+ "deflate" : zlib ,
21+ "br" : brotli ,
22+ }
23+
24+
1725@dataclass (frozen = True )
1826class OpenApiMiddleware :
1927 """Middleware to add the OpenAPI spec to the response."""
@@ -30,41 +38,56 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3038 if scope ["type" ] != "http" or Request (scope ).url .path != self .openapi_spec_path :
3139 return await self .app (scope , receive , send )
3240
33- total_body = b""
41+ start_message = None
42+ body = b""
3443
3544 async def augment_oidc_spec (message : Message ):
36- if message ["type" ] != "http.response.body" :
45+ nonlocal start_message
46+ nonlocal body
47+ if message ["type" ] == "http.response.start" :
48+ # NOTE: Because we are modifying the response body, we will need to update
49+ # the content-length header. However, headers are sent before we see the
50+ # body. To handle this, we delay sending the http.response.start message
51+ # until after we alter the body.
52+ start_message = message
53+ return
54+ elif message ["type" ] != "http.response.body" :
3755 return await send (message )
3856
39- # TODO: Make more robust to handle non-JSON responses
57+ body += message [ "body" ]
4058
41- nonlocal total_body
59+ # Skip body chunks until all chunks have been received
60+ if message ["more_body" ]:
61+ return
4262
43- total_body += message ["body" ]
63+ # Maybe decompress the body
64+ headers = MutableHeaders (scope = start_message )
65+ content_encoding = headers .get ("content-encoding" , "" ).lower ()
66+ handler = None
67+ if content_encoding :
68+ handler = ENCODING_HANDLERS .get (content_encoding )
69+ assert handler , f"Unsupported content encoding: { content_encoding } "
70+ body = handler .decompress (body )
4471
45- # Pass empty body chunks until all chunks have been received
46- if message ["more_body" ]:
47- return await send ({** message , "body" : b"" })
48-
49- # Handle compressed responses
50- # content_encoding = (
51- # message.get("headers", {})
52- # .get(b"content-encoding", b"")
53- # .decode()
54- # .lower()
55- # )
56- # if content_encoding:
57- # if "gzip" in content_encoding:
58- # total_body = gzip.decompress(total_body)
59- # elif "deflate" in content_encoding:
60- # total_body = zlib.decompress(total_body)
61- # elif "br" in content_encoding:
62- # total_body = brotli.decompress(total_body)
63- # print(f"{message=}")
72+ # Augment the spec
73+ body = dict_to_bytes (self .augment_spec (json .loads (body )))
74+
75+ # Maybe re-compress the body
76+ if handler :
77+ body = handler .compress (body )
78+
79+ # Update the content-length header
80+ headers ["content-length" ] = str (len (body ))
81+ start_message ["headers" ] = headers .items ()
82+
83+ # Send http.response.start
84+ await send (start_message )
85+
86+ # Send http.response.body
6487 await send (
6588 {
6689 "type" : "http.response.body" ,
67- "body" : dict_to_bytes ( self . augment_spec ( json . loads ( total_body ))) ,
90+ "body" : body ,
6891 "more_body" : False ,
6992 }
7093 )
0 commit comments