diff --git a/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py b/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py index 23b575e1..009243d6 100644 --- a/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py +++ b/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py @@ -1,19 +1,18 @@ """Middleware to add auth information to the OpenAPI spec served by upstream API.""" -import json from dataclasses import dataclass -from typing import Any, Optional +from typing import Any -from starlette.datastructures import MutableHeaders from starlette.requests import Request -from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.types import ASGIApp from ..config import EndpointMethods -from ..utils.requests import dict_to_bytes, find_match +from ..utils.middleware import JsonResponseMiddleware +from ..utils.requests import find_match @dataclass(frozen=True) -class OpenApiMiddleware: +class OpenApiMiddleware(JsonResponseMiddleware): """Middleware to add the OpenAPI spec to the response.""" app: ASGIApp @@ -24,61 +23,11 @@ class OpenApiMiddleware: default_public: bool oidc_auth_scheme_name: str = "oidcAuth" - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Add the OpenAPI spec to the response.""" - if scope["type"] != "http" or Request(scope).url.path != self.openapi_spec_path: - return await self.app(scope, receive, send) + def should_transform_response(self, request: Request) -> bool: + """Only transform responses for the OpenAPI spec path.""" + return request.url.path == self.openapi_spec_path - start_message: Optional[Message] = None - body = b"" - - async def augment_oidc_spec(message: Message): - nonlocal start_message - nonlocal body - if message["type"] == "http.response.start": - # NOTE: Because we are modifying the response body, we will need to update - # the content-length header. However, headers are sent before we see the - # body. To handle this, we delay sending the http.response.start message - # until after we alter the body. - start_message = message - return - elif message["type"] != "http.response.body": - return await send(message) - - body += message["body"] - - # Skip body chunks until all chunks have been received - if message.get("more_body"): - return - - # Maybe decompress the body - headers = MutableHeaders(scope=start_message) - - # Augment the spec - body = dict_to_bytes(self.augment_spec(json.loads(body))) - - # Update the content-length header - headers["content-length"] = str(len(body)) - assert start_message, "Expected start_message to be set" - start_message["headers"] = [ - (key.encode(), value.encode()) for key, value in headers.items() - ] - - # Send http.response.start - await send(start_message) - - # Send http.response.body - await send( - { - "type": "http.response.body", - "body": body, - "more_body": False, - } - ) - - return await self.app(scope, receive, augment_oidc_spec) - - def augment_spec(self, openapi_spec) -> dict[str, Any]: + def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]: """Augment the OpenAPI spec with auth information.""" components = openapi_spec.setdefault("components", {}) securitySchemes = components.setdefault("securitySchemes", {}) diff --git a/src/stac_auth_proxy/utils/middleware.py b/src/stac_auth_proxy/utils/middleware.py new file mode 100644 index 00000000..c1db7b2f --- /dev/null +++ b/src/stac_auth_proxy/utils/middleware.py @@ -0,0 +1,112 @@ +"""Utilities for middleware response handling.""" + +import json +import re +from abc import ABC, abstractmethod +from typing import Any, Optional + +from starlette.datastructures import Headers, MutableHeaders +from starlette.requests import Request +from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +class JsonResponseMiddleware(ABC): + """Base class for middleware that transforms JSON response bodies.""" + + app: ASGIApp + json_content_type_expr: str = ( + r"application/vnd\.oai\.openapi\+json;.*|application/json|application/geo\+json" + ) + + @abstractmethod + def should_transform_response(self, request: Request) -> bool: + """ + Determine if this request's response should be transformed. + + Args: + request: The incoming request + + Returns + ------- + bool: True if the response should be transformed + """ + return bool( + re.match(self.json_content_type_expr, request.headers.get("accept", "")) + ) + + @abstractmethod + def transform_json(self, data: Any) -> Any: + """ + Transform the JSON data. + + Args: + data: The parsed JSON data + + Returns + ------- + The transformed JSON data + """ + pass + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Process the request/response.""" + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope) + if not self.should_transform_response(request): + return await self.app(scope, receive, send) + + start_message: Optional[Message] = None + body = b"" + not_json = False + + async def process_message(message: Message) -> None: + nonlocal start_message + nonlocal body + nonlocal not_json + if message["type"] == "http.response.start": + # Delay sending start message until we've processed the body + if not re.match( + self.json_content_type_expr, + Headers(scope=message).get("content-type", ""), + ): + not_json = True + return await send(message) + start_message = message + return + elif message["type"] != "http.response.body" or not_json: + return await send(message) + + body += message["body"] + + # Skip body chunks until all chunks have been received + if message.get("more_body"): + return + + headers = MutableHeaders(scope=start_message) + + # Transform the JSON body + if body: + data = json.loads(body) + transformed = self.transform_json(data) + body = json.dumps(transformed).encode() + + # Update content-length header + headers["content-length"] = str(len(body)) + assert start_message, "Expected start_message to be set" + start_message["headers"] = [ + (key.encode(), value.encode()) for key, value in headers.items() + ] + + # Send response + await send(start_message) + await send( + { + "type": "http.response.body", + "body": body, + "more_body": False, + } + ) + + return await self.app(scope, receive, process_message)