|
17 | 17 | import os |
18 | 18 | import traceback |
19 | 19 | from collections.abc import Callable |
20 | | -from typing import ParamSpec, TypedDict |
| 20 | +from typing import Any, ParamSpec, TypedDict |
21 | 21 | from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse |
22 | 22 |
|
| 23 | +import brotli |
23 | 24 | from fastapi import FastAPI, Request, Response, status |
24 | 25 | from fastapi.responses import JSONResponse |
25 | 26 | from rs_server_common import settings as common_settings |
@@ -131,6 +132,111 @@ def is_bad_request(self, request: Request, e: Exception) -> bool: |
131 | 132 | ) |
132 | 133 |
|
133 | 134 |
|
| 135 | +class PaginationLinksMiddleware(BaseHTTPMiddleware): |
| 136 | + """ |
| 137 | + Middleware to implement 'first' button's functionality in STAC Browser |
| 138 | + """ |
| 139 | + |
| 140 | + async def dispatch( |
| 141 | + self, |
| 142 | + request: Request, |
| 143 | + call_next: Callable, |
| 144 | + ): # pylint: disable=too-many-branches,too-many-statements |
| 145 | + |
| 146 | + # Only for /search in auxip, prip, cadip |
| 147 | + if request.url.path in ["/auxip/search", "/cadip/search", "/prip/search", "/catalog/search"]: |
| 148 | + |
| 149 | + first_link: dict[str, Any] = { |
| 150 | + "rel": "first", |
| 151 | + "type": "application/geo+json", |
| 152 | + "method": request.method, |
| 153 | + "href": f"{str(request.base_url).rstrip('/')}{request.url.path}", |
| 154 | + "title": "First link", |
| 155 | + } |
| 156 | + |
| 157 | + if common_settings.CLUSTER_MODE: |
| 158 | + first_link["href"] = f"https://{str(request.base_url.hostname).rstrip('/')}{request.url.path}" |
| 159 | + |
| 160 | + if request.method == "GET": |
| 161 | + # parse query params to remove any 'prev' or 'next' |
| 162 | + query_dict = dict(request.query_params) |
| 163 | + |
| 164 | + query_dict.pop("token", None) |
| 165 | + if "page" in query_dict: |
| 166 | + query_dict["page"] = "1" |
| 167 | + new_query_string = urlencode(query_dict, doseq=True) |
| 168 | + first_link["href"] += f"?{new_query_string}" |
| 169 | + |
| 170 | + elif request.method == "POST": |
| 171 | + try: |
| 172 | + query = await request.json() |
| 173 | + body = {} |
| 174 | + |
| 175 | + for key in ["datetime", "limit"]: |
| 176 | + if key in query and query[key] is not None: |
| 177 | + body[key] = query[key] |
| 178 | + |
| 179 | + if "token" in query and request.url.path != "/catalog/search": |
| 180 | + body["token"] = "page=1" # nosec |
| 181 | + |
| 182 | + first_link["body"] = body |
| 183 | + except Exception: # pylint: disable = broad-exception-caught |
| 184 | + logger.error(traceback.format_exc()) |
| 185 | + |
| 186 | + response = await call_next(request) |
| 187 | + |
| 188 | + encoding = response.headers.get("content-encoding", "") |
| 189 | + if encoding == "br": |
| 190 | + body_bytes = b"".join([section async for section in response.body_iterator]) |
| 191 | + response_body = brotli.decompress(body_bytes) |
| 192 | + |
| 193 | + if request.url.path == "/catalog/search": |
| 194 | + first_link["auth:refs"] = ["apikey", "openid", "oauth2"] |
| 195 | + else: |
| 196 | + response_body = b"" |
| 197 | + async for chunk in response.body_iterator: |
| 198 | + response_body += chunk |
| 199 | + |
| 200 | + try: |
| 201 | + data = json.loads(response_body) |
| 202 | + |
| 203 | + links = data.get("links", []) |
| 204 | + has_prev = any(link.get("rel") == "previous" for link in links) |
| 205 | + |
| 206 | + if has_prev is True: |
| 207 | + links.append(first_link) |
| 208 | + data["links"] = links |
| 209 | + |
| 210 | + headers = dict(response.headers) |
| 211 | + headers.pop("content-length", None) |
| 212 | + |
| 213 | + if encoding == "br": |
| 214 | + new_body = brotli.compress(json.dumps(data).encode("utf-8")) |
| 215 | + else: |
| 216 | + new_body = json.dumps(data).encode("utf-8") |
| 217 | + |
| 218 | + response = Response( |
| 219 | + content=new_body, |
| 220 | + status_code=response.status_code, |
| 221 | + headers=headers, |
| 222 | + media_type="application/json", |
| 223 | + ) |
| 224 | + except Exception: # pylint: disable = broad-exception-caught |
| 225 | + headers = dict(response.headers) |
| 226 | + headers.pop("content-length", None) |
| 227 | + |
| 228 | + response = Response( |
| 229 | + content=response_body, |
| 230 | + status_code=response.status_code, |
| 231 | + headers=headers, |
| 232 | + media_type=response.headers.get("content-type"), |
| 233 | + ) |
| 234 | + else: |
| 235 | + return await call_next(request) |
| 236 | + |
| 237 | + return response |
| 238 | + |
| 239 | + |
134 | 240 | def get_link_title(link: dict, entity: dict) -> str: |
135 | 241 | """ |
136 | 242 | Determine a human-readable STAC link title based on the link relation and context. |
|
0 commit comments