11"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
22
3- import json
43from dataclasses import dataclass
5- from typing import Any , Optional
4+ from typing import Any
65
7- from starlette .datastructures import MutableHeaders
86from starlette .requests import Request
9- from starlette .types import ASGIApp , Message , Receive , Scope , Send
7+ from starlette .types import ASGIApp
108
119from ..config import EndpointMethods
12- from ..utils .requests import dict_to_bytes , find_match
10+ from ..utils .middleware import JsonResponseMiddleware
11+ from ..utils .requests import find_match
1312
1413
1514@dataclass (frozen = True )
16- class OpenApiMiddleware :
15+ class OpenApiMiddleware ( JsonResponseMiddleware ) :
1716 """Middleware to add the OpenAPI spec to the response."""
1817
1918 app : ASGIApp
@@ -24,61 +23,11 @@ class OpenApiMiddleware:
2423 default_public : bool
2524 oidc_auth_scheme_name : str = "oidcAuth"
2625
27- async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
28- """Add the OpenAPI spec to the response."""
29- if scope ["type" ] != "http" or Request (scope ).url .path != self .openapi_spec_path :
30- return await self .app (scope , receive , send )
26+ def should_transform_response (self , request : Request ) -> bool :
27+ """Only transform responses for the OpenAPI spec path."""
28+ return request .url .path == self .openapi_spec_path
3129
32- start_message : Optional [Message ] = None
33- body = b""
34-
35- async def augment_oidc_spec (message : Message ):
36- nonlocal start_message
37- nonlocal body
38- if message ["type" ] == "http.response.start" :
39- # NOTE: Because we are modifying the response body, we will need to update
40- # the content-length header. However, headers are sent before we see the
41- # body. To handle this, we delay sending the http.response.start message
42- # until after we alter the body.
43- start_message = message
44- return
45- elif message ["type" ] != "http.response.body" :
46- return await send (message )
47-
48- body += message ["body" ]
49-
50- # Skip body chunks until all chunks have been received
51- if message .get ("more_body" ):
52- return
53-
54- # Maybe decompress the body
55- headers = MutableHeaders (scope = start_message )
56-
57- # Augment the spec
58- body = dict_to_bytes (self .augment_spec (json .loads (body )))
59-
60- # Update the content-length header
61- headers ["content-length" ] = str (len (body ))
62- assert start_message , "Expected start_message to be set"
63- start_message ["headers" ] = [
64- (key .encode (), value .encode ()) for key , value in headers .items ()
65- ]
66-
67- # Send http.response.start
68- await send (start_message )
69-
70- # Send http.response.body
71- await send (
72- {
73- "type" : "http.response.body" ,
74- "body" : body ,
75- "more_body" : False ,
76- }
77- )
78-
79- return await self .app (scope , receive , augment_oidc_spec )
80-
81- def augment_spec (self , openapi_spec ) -> dict [str , Any ]:
30+ def transform_json (self , openapi_spec : dict [str , Any ]) -> dict [str , Any ]:
8231 """Augment the OpenAPI spec with auth information."""
8332 components = openapi_spec .setdefault ("components" , {})
8433 securitySchemes = components .setdefault ("securitySchemes" , {})
0 commit comments