Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 9 additions & 60 deletions src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
"""Middleware to add auth information to the OpenAPI spec served by upstream API."""

import json
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any

from starlette.datastructures import MutableHeaders
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.types import ASGIApp

from ..config import EndpointMethods
from ..utils.requests import dict_to_bytes, find_match
from ..utils.middleware import JsonResponseMiddleware
from ..utils.requests import find_match


@dataclass(frozen=True)
class OpenApiMiddleware:
class OpenApiMiddleware(JsonResponseMiddleware):
"""Middleware to add the OpenAPI spec to the response."""

app: ASGIApp
Expand All @@ -24,61 +23,11 @@ class OpenApiMiddleware:
default_public: bool
oidc_auth_scheme_name: str = "oidcAuth"

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Add the OpenAPI spec to the response."""
if scope["type"] != "http" or Request(scope).url.path != self.openapi_spec_path:
return await self.app(scope, receive, send)
def should_transform_response(self, request: Request) -> bool:
"""Only transform responses for the OpenAPI spec path."""
return request.url.path == self.openapi_spec_path

start_message: Optional[Message] = None
body = b""

async def augment_oidc_spec(message: Message):
nonlocal start_message
nonlocal body
if message["type"] == "http.response.start":
# NOTE: Because we are modifying the response body, we will need to update
# the content-length header. However, headers are sent before we see the
# body. To handle this, we delay sending the http.response.start message
# until after we alter the body.
start_message = message
return
elif message["type"] != "http.response.body":
return await send(message)

body += message["body"]

# Skip body chunks until all chunks have been received
if message.get("more_body"):
return

# Maybe decompress the body
headers = MutableHeaders(scope=start_message)

# Augment the spec
body = dict_to_bytes(self.augment_spec(json.loads(body)))

# Update the content-length header
headers["content-length"] = str(len(body))
assert start_message, "Expected start_message to be set"
start_message["headers"] = [
(key.encode(), value.encode()) for key, value in headers.items()
]

# Send http.response.start
await send(start_message)

# Send http.response.body
await send(
{
"type": "http.response.body",
"body": body,
"more_body": False,
}
)

return await self.app(scope, receive, augment_oidc_spec)

def augment_spec(self, openapi_spec) -> dict[str, Any]:
def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
"""Augment the OpenAPI spec with auth information."""
components = openapi_spec.setdefault("components", {})
securitySchemes = components.setdefault("securitySchemes", {})
Expand Down
112 changes: 112 additions & 0 deletions src/stac_auth_proxy/utils/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Utilities for middleware response handling."""

import json
import re
from abc import ABC, abstractmethod
from typing import Any, Optional

from starlette.datastructures import Headers, MutableHeaders
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send


class JsonResponseMiddleware(ABC):
"""Base class for middleware that transforms JSON response bodies."""

app: ASGIApp
json_content_type_expr: str = (
r"application/vnd\.oai\.openapi\+json;.*|application/json|application/geo\+json"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alukach you may also add application/schema+json (queryables)

)

@abstractmethod
def should_transform_response(self, request: Request) -> bool:
"""
Determine if this request's response should be transformed.

Args:
request: The incoming request

Returns
-------
bool: True if the response should be transformed
"""
return bool(
re.match(self.json_content_type_expr, request.headers.get("accept", ""))
)

@abstractmethod
def transform_json(self, data: Any) -> Any:
"""
Transform the JSON data.

Args:
data: The parsed JSON data

Returns
-------
The transformed JSON data
"""
pass

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process the request/response."""
if scope["type"] != "http":
return await self.app(scope, receive, send)

request = Request(scope)
if not self.should_transform_response(request):
return await self.app(scope, receive, send)

start_message: Optional[Message] = None
body = b""
not_json = False

async def process_message(message: Message) -> None:
nonlocal start_message
nonlocal body
nonlocal not_json
if message["type"] == "http.response.start":
# Delay sending start message until we've processed the body
if not re.match(
self.json_content_type_expr,
Headers(scope=message).get("content-type", ""),
):
not_json = True
return await send(message)
start_message = message
return
elif message["type"] != "http.response.body" or not_json:
return await send(message)

body += message["body"]

# Skip body chunks until all chunks have been received
if message.get("more_body"):
return

headers = MutableHeaders(scope=start_message)

# Transform the JSON body
if body:
data = json.loads(body)
transformed = self.transform_json(data)
body = json.dumps(transformed).encode()

# Update content-length header
headers["content-length"] = str(len(body))
assert start_message, "Expected start_message to be set"
start_message["headers"] = [
(key.encode(), value.encode()) for key, value in headers.items()
]

# Send response
await send(start_message)
await send(
{
"type": "http.response.body",
"body": body,
"more_body": False,
}
)

return await self.app(scope, receive, process_message)