11"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
22
3+ import brotli
4+ import gzip
35import json
6+ import re
7+ import zlib
48from dataclasses import dataclass
59from typing import Any
610
11+ from starlette .datastructures import MutableHeaders
712from starlette .requests import Request
813from starlette .types import ASGIApp , Message , Receive , Scope , Send
914
1015from ..config import EndpointMethods
1116from ..utils .requests import dict_to_bytes
1217
1318
19+ ENCODING_HANDLERS = {
20+ "gzip" : gzip ,
21+ "deflate" : zlib ,
22+ "br" : brotli ,
23+ }
24+
25+
1426@dataclass (frozen = True )
1527class OpenApiMiddleware :
1628 """Middleware to add the OpenAPI spec to the response."""
@@ -27,26 +39,58 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2739 if scope ["type" ] != "http" or Request (scope ).url .path != self .openapi_spec_path :
2840 return await self .app (scope , receive , send )
2941
30- total_body = b""
42+ start_message = None
43+ body = b""
3144
3245 async def augment_oidc_spec (message : Message ):
33- if message ["type" ] != "http.response.body" :
46+ nonlocal start_message
47+ nonlocal body
48+ if message ["type" ] == "http.response.start" :
49+ # NOTE: Because we are modifying the response body, we will need to update
50+ # the content-length header. However, headers are sent before we see the
51+ # body. To handle this, we delay sending the http.response.start message
52+ # until after we alter the body.
53+ start_message = message
54+ return
55+ elif message ["type" ] != "http.response.body" :
3456 return await send (message )
3557
36- # TODO: Make more robust to handle non-JSON responses
37-
38- nonlocal total_body
58+ body += message ["body" ]
3959
40- total_body += message ["body" ]
41-
42- # Pass empty body chunks until all chunks have been received
60+ # Skip body chunks until all chunks have been received
4361 if message ["more_body" ]:
44- return await send ({** message , "body" : b"" })
45-
62+ return
63+
64+ # Maybe decompress the body
65+ headers = MutableHeaders (scope = start_message )
66+ content_encoding = headers .get ("content-encoding" , "" ).lower ()
67+ handler = None
68+ if content_encoding :
69+ handler = ENCODING_HANDLERS .get (content_encoding )
70+ assert handler , f"Unsupported content encoding: { content_encoding } "
71+ body = handler .decompress (body )
72+
73+ # Augment the spec
74+ body = dict_to_bytes (self .augment_spec (json .loads (body )))
75+
76+ # Maybe re-compress the body
77+ if handler :
78+ body = handler .compress (body )
79+
80+ # Update the content-length header
81+ headers ["content-length" ] = str (len (body ))
82+ start_message ["headers" ] = [
83+ (key .encode (), value .encode ()) for key , value in headers .items ()
84+ ]
85+
86+ # Send http.response.start
87+ await send (start_message )
88+
89+ # Send http.response.body
4690 await send (
4791 {
4892 "type" : "http.response.body" ,
49- "body" : dict_to_bytes ( self . augment_spec ( json . loads ( total_body ))) ,
93+ "body" : body ,
5094 "more_body" : False ,
5195 }
5296 )
0 commit comments