|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | """Common functions for fastapi middlewares""" |
| 16 | +import json |
16 | 17 | import os |
17 | 18 | import traceback |
18 | 19 | from collections.abc import Callable |
19 | 20 | from typing import ParamSpec, TypedDict |
| 21 | +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse |
20 | 22 |
|
21 | | -from fastapi import FastAPI, Request, status |
| 23 | +from fastapi import FastAPI, Request, Response, status |
22 | 24 | from fastapi.responses import JSONResponse |
23 | 25 | from rs_server_common import settings as common_settings |
24 | 26 | from rs_server_common.authentication import authentication, oauth2 |
|
30 | 32 | from starlette.middleware.base import BaseHTTPMiddleware |
31 | 33 | from starlette.middleware.sessions import SessionMiddleware |
32 | 34 |
|
| 35 | +REL_TITLES = { |
| 36 | + "collection": "Collection", |
| 37 | + "item": "Item", |
| 38 | + "parent": "Parent Catalog", |
| 39 | + "root": "STAC Root Catalog", |
| 40 | + "conformance": "Conformance link", |
| 41 | + "service-desc": "Service description", |
| 42 | + "service-doc": "Service documentation", |
| 43 | + "search": "Search endpoint", |
| 44 | + "data": "Data link", |
| 45 | + "items": "This collection items", |
| 46 | + "self": "This collection", |
| 47 | + "license": "License description", |
| 48 | + "describedby": "Described by link", |
| 49 | + "next": "Next link", |
| 50 | + "previous": "Previous link", |
| 51 | +} |
| 52 | +# pylint: disable = too-few-public-methods, too-many-return-statements |
33 | 53 | logger = Logging.default(__name__) |
34 | 54 | P = ParamSpec("P") |
35 | 55 |
|
@@ -111,6 +131,114 @@ def is_bad_request(self, request: Request, e: Exception) -> bool: |
111 | 131 | ) |
112 | 132 |
|
113 | 133 |
|
| 134 | +def get_link_title(link: dict, entity: dict) -> str: |
| 135 | + """ |
| 136 | + Determine a human-readable STAC link title based on the link relation and context. |
| 137 | + """ |
| 138 | + rel = link.get("rel") |
| 139 | + href = link.get("href", "") |
| 140 | + if "title" in link: |
| 141 | + # don't overwrite |
| 142 | + return link["title"] |
| 143 | + match rel: |
| 144 | + # --- special cases needing entity context --- |
| 145 | + case "collection": |
| 146 | + return entity.get("title") or entity.get("id") or REL_TITLES["collection"] |
| 147 | + case "item": |
| 148 | + return entity.get("title") or entity.get("id") or REL_TITLES["item"] |
| 149 | + case "self" if entity.get("type") == "Catalog": |
| 150 | + return "STAC Landing Page" |
| 151 | + case "self" if href.endswith("/collections"): |
| 152 | + return "All Collections" |
| 153 | + case "child": |
| 154 | + path = urlparse(href).path |
| 155 | + collection_id = path.split("/")[-1] if path else "unknown" |
| 156 | + return f"All from collection {collection_id}" |
| 157 | + # --- all others: just lookup in REL_TITLES --- |
| 158 | + case _: |
| 159 | + return REL_TITLES.get(rel, href or "Unknown Entity") # type: ignore |
| 160 | + |
| 161 | + |
| 162 | +def normalize_href(href: str) -> str: |
| 163 | + """Encode query parameters in href to match expected STAC format.""" |
| 164 | + parsed = urlparse(href) |
| 165 | + query = urlencode(parse_qsl(parsed.query), safe="") # encode ":" -> "%3A" |
| 166 | + return urlunparse(parsed._replace(query=query)) |
| 167 | + |
| 168 | + |
| 169 | +class StacLinksTitleMiddleware(BaseHTTPMiddleware): |
| 170 | + """Middleware used to update links with title""" |
| 171 | + |
| 172 | + def __init__(self, app: FastAPI, title: str = "Default Title"): |
| 173 | + """ |
| 174 | + Initialize the middleware. |
| 175 | +
|
| 176 | + Args: |
| 177 | + app: The FastAPI application instance to attach the middleware to. |
| 178 | + title: Default title to use for STAC links if no specific title is provided. |
| 179 | + """ |
| 180 | + super().__init__(app) |
| 181 | + self.title = title |
| 182 | + |
| 183 | + async def dispatch(self, request: Request, call_next): |
| 184 | + """ |
| 185 | + Intercept and modify outgoing responses to ensure all STAC links have proper titles. |
| 186 | +
|
| 187 | + This middleware method: |
| 188 | + 1. Awaits the response from the next handler. |
| 189 | + 2. Reads and parses the response body as JSON. |
| 190 | + 3. Updates the "title" property of each link using `get_link_title`. |
| 191 | + 4. Rebuilds the response without the original Content-Length header to prevent mismatches. |
| 192 | + 5. If the response body is not JSON, returns it unchanged. |
| 193 | +
|
| 194 | + Args: |
| 195 | + request: The incoming FastAPI Request object. |
| 196 | + call_next: The next ASGI handler in the middleware chain. |
| 197 | +
|
| 198 | + Returns: |
| 199 | + A FastAPI Response object with updated STAC link titles. |
| 200 | + """ |
| 201 | + response = await call_next(request) |
| 202 | + |
| 203 | + body = b"" |
| 204 | + async for chunk in response.body_iterator: |
| 205 | + body += chunk |
| 206 | + |
| 207 | + try: |
| 208 | + data = json.loads(body) |
| 209 | + |
| 210 | + if isinstance(data, dict) and "links" in data: |
| 211 | + for link in data["links"]: |
| 212 | + if isinstance(link, dict): |
| 213 | + # normalize href to decode any %xx |
| 214 | + if "href" in link: |
| 215 | + link["href"] = normalize_href(link["href"]) |
| 216 | + # update title |
| 217 | + link["title"] = get_link_title(link, data) |
| 218 | + |
| 219 | + headers = dict(response.headers) |
| 220 | + headers.pop("content-length", None) |
| 221 | + |
| 222 | + response = Response( |
| 223 | + content=json.dumps(data, ensure_ascii=False).encode("utf-8"), |
| 224 | + status_code=response.status_code, |
| 225 | + headers=headers, |
| 226 | + media_type="application/json", |
| 227 | + ) |
| 228 | + except Exception: # pylint: disable = broad-exception-caught |
| 229 | + headers = dict(response.headers) |
| 230 | + headers.pop("content-length", None) |
| 231 | + |
| 232 | + response = Response( |
| 233 | + content=body, |
| 234 | + status_code=response.status_code, |
| 235 | + headers=headers, |
| 236 | + media_type=response.headers.get("content-type"), |
| 237 | + ) |
| 238 | + |
| 239 | + return response |
| 240 | + |
| 241 | + |
114 | 242 | def insert_middleware_at(app: FastAPI, index: int, middleware: Middleware): |
115 | 243 | """Insert the given middleware at the specified index in a FastAPI application. |
116 | 244 |
|
|
0 commit comments