|
| 1 | +"""Middleware to add auth information to the OpenAPI spec served by upstream API.""" |
| 2 | + |
| 3 | +import gzip |
| 4 | +import zlib |
| 5 | +from dataclasses import dataclass, field |
| 6 | +from typing import Any |
| 7 | + |
| 8 | +import brotli |
| 9 | +from starlette.requests import Request |
| 10 | +from starlette.types import ASGIApp, Receive, Scope, Send |
| 11 | + |
| 12 | +from ..utils.requests import find_match |
| 13 | + |
| 14 | +ENCODING_HANDLERS = { |
| 15 | + "gzip": gzip, |
| 16 | + "deflate": zlib, |
| 17 | + "br": brotli, |
| 18 | +} |
| 19 | + |
| 20 | + |
| 21 | +@dataclass(frozen=True) |
| 22 | +class AuthorizationExtension: |
| 23 | + """Middleware to add the OpenAPI spec to the response.""" |
| 24 | + |
| 25 | + app: ASGIApp |
| 26 | + signers: dict[str, str] = field(default_factory=dict) |
| 27 | + |
| 28 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 29 | + """Add the OpenAPI spec to the response.""" |
| 30 | + if scope["type"] != "http" or Request(scope).url.path != self.openapi_spec_path: |
| 31 | + return await self.app(scope, receive, send) |
| 32 | + |
| 33 | + # TODO: test if asset path matches |
| 34 | + # start_message: Optional[Message] = None |
| 35 | + # body = b"" |
| 36 | + |
| 37 | + # async def augment_oidc_spec(message: Message): |
| 38 | + # nonlocal start_message |
| 39 | + # nonlocal body |
| 40 | + # if message["type"] == "http.response.start": |
| 41 | + # # NOTE: Because we are modifying the response body, we will need to update |
| 42 | + # # the content-length header. However, headers are sent before we see the |
| 43 | + # # body. To handle this, we delay sending the http.response.start message |
| 44 | + # # until after we alter the body. |
| 45 | + # start_message = message |
| 46 | + # return |
| 47 | + # elif message["type"] != "http.response.body": |
| 48 | + # return await send(message) |
| 49 | + |
| 50 | + # body += message["body"] |
| 51 | + |
| 52 | + # # Skip body chunks until all chunks have been received |
| 53 | + # if message["more_body"]: |
| 54 | + # return |
| 55 | + |
| 56 | + # # Maybe decompress the body |
| 57 | + # headers = MutableHeaders(scope=start_message) |
| 58 | + # content_encoding = headers.get("content-encoding", "").lower() |
| 59 | + # handler = None |
| 60 | + # if content_encoding: |
| 61 | + # handler = ENCODING_HANDLERS.get(content_encoding) |
| 62 | + # assert handler, f"Unsupported content encoding: {content_encoding}" |
| 63 | + # body = ( |
| 64 | + # handler.decompress(body) |
| 65 | + # if content_encoding != "deflate" |
| 66 | + # else handler.decompress(body, -zlib.MAX_WBITS) |
| 67 | + # ) |
| 68 | + |
| 69 | + # # Augment the spec |
| 70 | + # body = dict_to_bytes(self.augment_spec(json.loads(body))) |
| 71 | + |
| 72 | + # # Maybe re-compress the body |
| 73 | + # if handler: |
| 74 | + # body = handler.compress(body) |
| 75 | + |
| 76 | + # # Update the content-length header |
| 77 | + # headers["content-length"] = str(len(body)) |
| 78 | + # assert start_message, "Expected start_message to be set" |
| 79 | + # start_message["headers"] = [ |
| 80 | + # (key.encode(), value.encode()) for key, value in headers.items() |
| 81 | + # ] |
| 82 | + |
| 83 | + # # Send http.response.start |
| 84 | + # await send(start_message) |
| 85 | + |
| 86 | + # # Send http.response.body |
| 87 | + # await send( |
| 88 | + # { |
| 89 | + # "type": "http.response.body", |
| 90 | + # "body": body, |
| 91 | + # "more_body": False, |
| 92 | + # } |
| 93 | + # ) |
| 94 | + |
| 95 | + return await self.app(scope, receive, augment_oidc_spec) |
| 96 | + |
| 97 | + def augment_spec(self, openapi_spec) -> dict[str, Any]: |
| 98 | + """Augment the OpenAPI spec with auth information.""" |
| 99 | + components = openapi_spec.setdefault("components", {}) |
| 100 | + securitySchemes = components.setdefault("securitySchemes", {}) |
| 101 | + securitySchemes[self.oidc_auth_scheme_name] = { |
| 102 | + "type": "openIdConnect", |
| 103 | + "openIdConnectUrl": self.oidc_config_url, |
| 104 | + } |
| 105 | + for path, method_config in openapi_spec["paths"].items(): |
| 106 | + for method, config in method_config.items(): |
| 107 | + match = find_match( |
| 108 | + path, |
| 109 | + method, |
| 110 | + self.private_endpoints, |
| 111 | + self.public_endpoints, |
| 112 | + self.default_public, |
| 113 | + ) |
| 114 | + if match.is_private: |
| 115 | + config.setdefault("security", []).append( |
| 116 | + {self.oidc_auth_scheme_name: match.required_scopes} |
| 117 | + ) |
| 118 | + return openapi_spec |
0 commit comments