|
| 1 | +"""Middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state.""" |
| 2 | + |
| 3 | +import json |
| 4 | +from dataclasses import dataclass |
| 5 | +from logging import getLogger |
| 6 | +from typing import Optional |
| 7 | +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse |
| 8 | + |
| 9 | +from cql2 import Expr |
| 10 | +from starlette.requests import Request |
| 11 | +from starlette.types import ASGIApp, Message, Receive, Scope, Send |
| 12 | + |
| 13 | +logger = getLogger(__name__) |
| 14 | + |
| 15 | + |
| 16 | +@dataclass(frozen=True) |
| 17 | +class Cql2RewriteLinksFilterMiddleware: |
| 18 | + """ASGI middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state.""" |
| 19 | + |
| 20 | + app: ASGIApp |
| 21 | + state_key: str = "cql2_filter" |
| 22 | + |
| 23 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 24 | + """Replace 'filter' in .links of the JSON response to state before we had applied the filter.""" |
| 25 | + if scope["type"] != "http": |
| 26 | + return await self.app(scope, receive, send) |
| 27 | + |
| 28 | + request = Request(scope) |
| 29 | + original_filter = request.query_params.get("filter") |
| 30 | + cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) |
| 31 | + if cql2_filter is None: |
| 32 | + # No filter set, just pass through |
| 33 | + return await self.app(scope, receive, send) |
| 34 | + |
| 35 | + # Intercept the response |
| 36 | + response_start = None |
| 37 | + body_chunks = [] |
| 38 | + more_body = True |
| 39 | + |
| 40 | + async def send_wrapper(message: Message): |
| 41 | + nonlocal response_start, body_chunks, more_body |
| 42 | + if message["type"] == "http.response.start": |
| 43 | + response_start = message |
| 44 | + elif message["type"] == "http.response.body": |
| 45 | + body_chunks.append(message.get("body", b"")) |
| 46 | + more_body = message.get("more_body", False) |
| 47 | + if not more_body: |
| 48 | + await self._process_and_send_response( |
| 49 | + response_start, body_chunks, send, original_filter |
| 50 | + ) |
| 51 | + else: |
| 52 | + await send(message) |
| 53 | + |
| 54 | + await self.app(scope, receive, send_wrapper) |
| 55 | + |
| 56 | + async def _process_and_send_response( |
| 57 | + self, |
| 58 | + response_start: Message, |
| 59 | + body_chunks: list[bytes], |
| 60 | + send: Send, |
| 61 | + original_filter: Optional[str], |
| 62 | + ): |
| 63 | + body = b"".join(body_chunks) |
| 64 | + try: |
| 65 | + data = json.loads(body) |
| 66 | + except Exception: |
| 67 | + await send(response_start) |
| 68 | + await send({"type": "http.response.body", "body": body, "more_body": False}) |
| 69 | + return |
| 70 | + |
| 71 | + cql2_filter = Expr(original_filter) if original_filter else None |
| 72 | + links = data.get("links") |
| 73 | + if isinstance(links, list): |
| 74 | + for link in links: |
| 75 | + # Handle filter in query string |
| 76 | + if "href" in link: |
| 77 | + url = urlparse(link["href"]) |
| 78 | + qs = parse_qs(url.query) |
| 79 | + if "filter" in qs: |
| 80 | + if cql2_filter: |
| 81 | + qs["filter"] = [cql2_filter.to_text()] |
| 82 | + else: |
| 83 | + qs.pop("filter", None) |
| 84 | + qs.pop("filter-lang", None) |
| 85 | + new_query = urlencode(qs, doseq=True) |
| 86 | + link["href"] = urlunparse(url._replace(query=new_query)) |
| 87 | + |
| 88 | + # Handle filter in body (for POST links) |
| 89 | + if "body" in link and isinstance(link["body"], dict): |
| 90 | + if "filter" in link["body"]: |
| 91 | + if cql2_filter: |
| 92 | + link["body"]["filter"] = cql2_filter.to_json() |
| 93 | + else: |
| 94 | + link["body"].pop("filter", None) |
| 95 | + link["body"].pop("filter-lang", None) |
| 96 | + |
| 97 | + # Send the modified response |
| 98 | + new_body = json.dumps(data).encode("utf-8") |
| 99 | + |
| 100 | + # Patch content-length |
| 101 | + headers = [ |
| 102 | + (k, v) for k, v in response_start["headers"] if k != b"content-length" |
| 103 | + ] |
| 104 | + headers.append((b"content-length", str(len(new_body)).encode("latin1"))) |
| 105 | + response_start = dict(response_start) |
| 106 | + response_start["headers"] = headers |
| 107 | + await send(response_start) |
| 108 | + await send({"type": "http.response.body", "body": new_body, "more_body": False}) |
0 commit comments