11"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
22
3- import brotli
43import gzip
54import json
65import re
76import zlib
87from dataclasses import dataclass
9- from typing import Any
8+ from typing import Any , Optional
109
10+ import brotli
1111from starlette .datastructures import MutableHeaders
1212from starlette .requests import Request
1313from starlette .types import ASGIApp , Message , Receive , Scope , Send
1414
1515from ..config import EndpointMethods
1616from ..utils .requests import dict_to_bytes
1717
18-
1918ENCODING_HANDLERS = {
2019 "gzip" : gzip ,
2120 "deflate" : zlib ,
@@ -40,7 +39,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
4039 if scope ["type" ] != "http" or Request (scope ).url .path != self .openapi_spec_path :
4140 return await self .app (scope , receive , send )
4241
43- start_message = None
42+ start_message : Optional [ Message ] = None
4443 body = b""
4544
4645 async def augment_oidc_spec (message : Message ):
@@ -80,6 +79,7 @@ async def augment_oidc_spec(message: Message):
8079
8180 # Update the content-length header
8281 headers ["content-length" ] = str (len (body ))
82+ assert start_message , "Expected start_message to be set"
8383 start_message ["headers" ] = [
8484 (key .encode (), value .encode ()) for key , value in headers .items ()
8585 ]
@@ -120,7 +120,7 @@ def augment_spec(self, openapi_spec) -> dict[str, Any]:
120120 return openapi_spec
121121
122122 @staticmethod
123- def path_matches (path : str , method : str , endpoints : dict [ str , list [ str ]] ) -> bool :
123+ def path_matches (path : str , method : str , endpoints : EndpointMethods ) -> bool :
124124 """Check if the given path and method match any of the regex patterns and methods in the endpoints."""
125125 for pattern , endpoint_methods in endpoints .items ():
126126 if not re .match (pattern , path ):
0 commit comments