diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index eb62eda..313dc51 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -20,6 +20,7 @@ Cql2ApplyFilterBodyMiddleware, Cql2ApplyFilterQueryStringMiddleware, Cql2BuildFilterMiddleware, + Cql2RewriteLinksFilterMiddleware, Cql2ValidateResponseBodyMiddleware, EnforceAuthMiddleware, OpenApiMiddleware, @@ -137,6 +138,7 @@ async def lifespan(app: FastAPI): app.add_middleware(Cql2ValidateResponseBodyMiddleware) app.add_middleware(Cql2ApplyFilterBodyMiddleware) app.add_middleware(Cql2ApplyFilterQueryStringMiddleware) + app.add_middleware(Cql2RewriteLinksFilterMiddleware) app.add_middleware( Cql2BuildFilterMiddleware, items_filter=settings.items_filter() if settings.items_filter else None, diff --git a/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py new file mode 100644 index 0000000..1909bfd --- /dev/null +++ b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py @@ -0,0 +1,108 @@ +"""Middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state.""" + +import json +from dataclasses import dataclass +from logging import getLogger +from typing import Optional +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse + +from cql2 import Expr +from starlette.requests import Request +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +logger = getLogger(__name__) + + +@dataclass(frozen=True) +class Cql2RewriteLinksFilterMiddleware: + """ASGI middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state.""" + + app: ASGIApp + state_key: str = "cql2_filter" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Replace 'filter' in .links of the JSON response to state before we had applied the filter.""" + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope) + original_filter = request.query_params.get("filter") + cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) + if cql2_filter is None: + # No filter set, just pass through + return await self.app(scope, receive, send) + + # Intercept the response + response_start = None + body_chunks = [] + more_body = True + + async def send_wrapper(message: Message): + nonlocal response_start, body_chunks, more_body + if message["type"] == "http.response.start": + response_start = message + elif message["type"] == "http.response.body": + body_chunks.append(message.get("body", b"")) + more_body = message.get("more_body", False) + if not more_body: + await self._process_and_send_response( + response_start, body_chunks, send, original_filter + ) + else: + await send(message) + + await self.app(scope, receive, send_wrapper) + + async def _process_and_send_response( + self, + response_start: Message, + body_chunks: list[bytes], + send: Send, + original_filter: Optional[str], + ): + body = b"".join(body_chunks) + try: + data = json.loads(body) + except Exception: + await send(response_start) + await send({"type": "http.response.body", "body": body, "more_body": False}) + return + + cql2_filter = Expr(original_filter) if original_filter else None + links = data.get("links") + if isinstance(links, list): + for link in links: + # Handle filter in query string + if "href" in link: + url = urlparse(link["href"]) + qs = parse_qs(url.query) + if "filter" in qs: + if cql2_filter: + qs["filter"] = [cql2_filter.to_text()] + else: + qs.pop("filter", None) + qs.pop("filter-lang", None) + new_query = urlencode(qs, doseq=True) + link["href"] = urlunparse(url._replace(query=new_query)) + + # Handle filter in body (for POST links) + if "body" in link and isinstance(link["body"], dict): + if "filter" in link["body"]: + if cql2_filter: + link["body"]["filter"] = cql2_filter.to_json() + else: + link["body"].pop("filter", None) + link["body"].pop("filter-lang", None) + + # Send the modified response + new_body = json.dumps(data).encode("utf-8") + + # Patch content-length + headers = [ + (k, v) for k, v in response_start["headers"] if k != b"content-length" + ] + headers.append((b"content-length", str(len(new_body)).encode("latin1"))) + response_start = dict(response_start) + response_start["headers"] = headers + await send(response_start) + await send({"type": "http.response.body", "body": new_body, "more_body": False}) diff --git a/src/stac_auth_proxy/middleware/__init__.py b/src/stac_auth_proxy/middleware/__init__.py index bc1ae4f..c5dc005 100644 --- a/src/stac_auth_proxy/middleware/__init__.py +++ b/src/stac_auth_proxy/middleware/__init__.py @@ -5,6 +5,7 @@ from .Cql2ApplyFilterBodyMiddleware import Cql2ApplyFilterBodyMiddleware from .Cql2ApplyFilterQueryStringMiddleware import Cql2ApplyFilterQueryStringMiddleware from .Cql2BuildFilterMiddleware import Cql2BuildFilterMiddleware +from .Cql2RewriteLinksFilterMiddleware import Cql2RewriteLinksFilterMiddleware from .Cql2ValidateResponseBodyMiddleware import Cql2ValidateResponseBodyMiddleware from .EnforceAuthMiddleware import EnforceAuthMiddleware from .ProcessLinksMiddleware import ProcessLinksMiddleware @@ -17,6 +18,7 @@ "Cql2ApplyFilterBodyMiddleware", "Cql2ApplyFilterQueryStringMiddleware", "Cql2BuildFilterMiddleware", + "Cql2RewriteLinksFilterMiddleware", "Cql2ValidateResponseBodyMiddleware", "EnforceAuthMiddleware", "OpenApiMiddleware", diff --git a/tests/test_cql2_rewrite_links_filter_middleware.py b/tests/test_cql2_rewrite_links_filter_middleware.py new file mode 100644 index 0000000..c81649e --- /dev/null +++ b/tests/test_cql2_rewrite_links_filter_middleware.py @@ -0,0 +1,110 @@ +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import FastAPI, Request, Response +from starlette.testclient import TestClient + +from stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware import ( + Cql2RewriteLinksFilterMiddleware, +) + + +@pytest.fixture +def app_with_middleware(): + app = FastAPI() + app.add_middleware(Cql2RewriteLinksFilterMiddleware) + + @app.get("/test") + async def test_endpoint(request: Request): + # Simulate a response with links containing a filter in the query and body + return { + "links": [ + { + "rel": "self", + "href": "http://example.com/search?filter=foo&filter-lang=cql2-text", + }, + { + "rel": "post", + "body": {"filter": "foo", "filter-lang": "cql2-json"}, + }, + ] + } + + return app + + +def test_rewrite_links_with_filter(app_with_middleware): + # Patch cql2.Expr to simulate to_text and to_json + with patch( + "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" + ) as MockExpr: + mock_expr = MagicMock() + mock_expr.to_text.return_value = "bar" + mock_expr.to_json.return_value = {"foo": "bar"} + MockExpr.return_value = mock_expr + + client = TestClient(app_with_middleware) + response = client.get("/test?filter=foo") + assert response.status_code == 200 + data = response.json() + # The filter in the href should be rewritten + assert any( + "filter=bar" in link["href"] for link in data["links"] if "href" in link + ) + # The filter in the body should be rewritten + assert any( + link.get("body", {}).get("filter") == {"foo": "bar"} + for link in data["links"] + ) + + +def test_remove_filter_from_links(app_with_middleware): + # Patch cql2.Expr to return None (no filter) + with patch( + "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" + ) as MockExpr: + MockExpr.return_value = None + client = TestClient(app_with_middleware) + response = client.get("/test") + assert response.status_code == 200 + data = response.json() + # The filter should be removed from href and body + for link in data["links"]: + if "href" in link: + assert "filter=" not in link["href"] + if "body" in link: + assert "filter" not in link["body"] + assert "filter-lang" not in link["body"] + + +def test_passthrough_when_no_filter_state(app_with_middleware): + # Simulate no filter in request.state + with patch( + "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" + ) as MockExpr: + MockExpr.return_value = None + client = TestClient(app_with_middleware) + response = client.get("/test") + assert response.status_code == 200 + data = response.json() + # Should be unchanged (filter removed) + for link in data["links"]: + if "href" in link: + assert "filter=" not in link["href"] + if "body" in link: + assert "filter" not in link["body"] + assert "filter-lang" not in link["body"] + + +def test_non_json_response(app_with_middleware): + # Add a route that returns plain text + app = app_with_middleware + + @app.get("/plain") + async def plain(): + return Response(content="not json", media_type="text/plain") + + client = TestClient(app) + response = client.get("/plain") + assert response.status_code == 200 + assert response.text == "not json"