From 7a4e9843d6a7790e68117429a31378864da09e65 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Mon, 21 Jul 2025 22:49:16 -0700 Subject: [PATCH 1/5] fix: handle empty search body --- src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py b/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py index aa4f8d5..871fe6e 100644 --- a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py +++ b/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py @@ -89,7 +89,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # Modify body try: - body = json.loads(body) + body = json.loads(body) if body else {} except json.JSONDecodeError as e: logger.warning("Failed to parse request body as JSON") # TODO: Return a 400 error From 7e4729ab43194232d89c5870f6b1f68abbe60401 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Mon, 21 Jul 2025 23:27:19 -0700 Subject: [PATCH 2/5] refactor: breakup CQL2 middleware --- src/stac_auth_proxy/app.py | 14 +- .../middleware/ApplyCql2FilterMiddleware.py | 202 ------------------ .../Cql2ApplyFilterBodyMiddleware.py | 98 +++++++++ .../Cql2ApplyFilterQueryStringMiddleware.py | 56 +++++ ...leware.py => Cql2BuildFilterMiddleware.py} | 2 +- .../Cql2ValidateResponseBodyMiddleware.py | 133 ++++++++++++ src/stac_auth_proxy/middleware/__init__.py | 14 +- 7 files changed, 305 insertions(+), 214 deletions(-) delete mode 100644 src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py create mode 100644 src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py create mode 100644 src/stac_auth_proxy/middleware/Cql2ApplyFilterQueryStringMiddleware.py rename src/stac_auth_proxy/middleware/{BuildCql2FilterMiddleware.py => Cql2BuildFilterMiddleware.py} (99%) create mode 100644 src/stac_auth_proxy/middleware/Cql2ValidateResponseBodyMiddleware.py diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index 6329ab1..eb62eda 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -16,9 +16,11 @@ from .handlers import HealthzHandler, ReverseProxyHandler, SwaggerUI from .middleware import ( AddProcessTimeHeaderMiddleware, - ApplyCql2FilterMiddleware, AuthenticationExtensionMiddleware, - BuildCql2FilterMiddleware, + Cql2ApplyFilterBodyMiddleware, + Cql2ApplyFilterQueryStringMiddleware, + Cql2BuildFilterMiddleware, + Cql2ValidateResponseBodyMiddleware, EnforceAuthMiddleware, OpenApiMiddleware, ProcessLinksMiddleware, @@ -132,11 +134,11 @@ async def lifespan(app: FastAPI): ) if settings.items_filter or settings.collections_filter: + app.add_middleware(Cql2ValidateResponseBodyMiddleware) + app.add_middleware(Cql2ApplyFilterBodyMiddleware) + app.add_middleware(Cql2ApplyFilterQueryStringMiddleware) app.add_middleware( - ApplyCql2FilterMiddleware, - ) - app.add_middleware( - BuildCql2FilterMiddleware, + Cql2BuildFilterMiddleware, items_filter=settings.items_filter() if settings.items_filter else None, collections_filter=( settings.collections_filter() if settings.collections_filter else None diff --git a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py b/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py deleted file mode 100644 index 871fe6e..0000000 --- a/src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Middleware to apply CQL2 filters.""" - -import json -import re -from dataclasses import dataclass -from logging import getLogger -from typing import Optional - -from cql2 import Expr -from starlette.datastructures import MutableHeaders -from starlette.requests import Request -from starlette.types import ASGIApp, Message, Receive, Scope, Send - -from ..utils import filters -from ..utils.middleware import required_conformance - -logger = getLogger(__name__) - - -@required_conformance( - r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", - r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", - r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", -) -@dataclass(frozen=True) -class ApplyCql2FilterMiddleware: - """Middleware to apply the Cql2Filter to the request.""" - - app: ASGIApp - state_key: str = "cql2_filter" - - single_record_endpoints = [ - r"^/collections/([^/]+)/items/([^/]+)$", - r"^/collections/([^/]+)$", - ] - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Add the Cql2Filter to the request.""" - if scope["type"] != "http": - return await self.app(scope, receive, send) - - request = Request(scope) - - cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) - - if not cql2_filter: - return await self.app(scope, receive, send) - - # Handle POST, PUT, PATCH - if request.method in ["POST", "PUT", "PATCH"]: - req_body_handler = Cql2RequestBodyAugmentor( - app=self.app, - cql2_filter=cql2_filter, - ) - return await req_body_handler(scope, receive, send) - - # Handle single record requests (ie non-filterable endpoints) - if any( - re.match(expr, request.url.path) for expr in self.single_record_endpoints - ): - res_body_validator = Cql2ResponseBodyValidator( - app=self.app, - cql2_filter=cql2_filter, - ) - return await res_body_validator(scope, send, receive) - - scope["query_string"] = filters.append_qs_filter(request.url.query, cql2_filter) - return await self.app(scope, receive, send) - - -@dataclass(frozen=True) -class Cql2RequestBodyAugmentor: - """Handler to augment the request body with a CQL2 filter.""" - - app: ASGIApp - cql2_filter: Expr - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Augment the request body with a CQL2 filter.""" - body = b"" - more_body = True - - # Read the body - while more_body: - message = await receive() - if message["type"] == "http.request": - body += message.get("body", b"") - more_body = message.get("more_body", False) - - # Modify body - try: - body = json.loads(body) if body else {} - except json.JSONDecodeError as e: - logger.warning("Failed to parse request body as JSON") - # TODO: Return a 400 error - raise e - - # Augment the body - assert isinstance(body, dict), "Request body must be a JSON object" - new_body = json.dumps( - filters.append_body_filter(body, self.cql2_filter) - ).encode("utf-8") - - # Patch content-length in the headers - headers = dict(scope["headers"]) - headers[b"content-length"] = str(len(new_body)).encode("latin1") - scope["headers"] = list(headers.items()) - - async def new_receive(): - return { - "type": "http.request", - "body": new_body, - "more_body": False, - } - - await self.app(scope, new_receive, send) - - -@dataclass -class Cql2ResponseBodyValidator: - """Handler to validate response body with CQL2.""" - - app: ASGIApp - cql2_filter: Expr - - async def __call__(self, scope: Scope, send: Send, receive: Receive) -> None: - """Process a response message and apply filtering if needed.""" - if scope["type"] != "http": - return await self.app(scope, send, receive) - - body = b"" - initial_message: Optional[Message] = None - - async def _send_error_response(status: int, code: str, message: str) -> None: - """Send an error response with the given status and message.""" - assert initial_message, "Initial message not set" - response_dict = { - "code": code, - "description": message, - } - response_bytes = json.dumps(response_dict).encode("utf-8") - headers = MutableHeaders(scope=initial_message) - headers["content-length"] = str(len(response_bytes)) - initial_message["status"] = status - await send(initial_message) - await send( - { - "type": "http.response.body", - "body": response_bytes, - "more_body": False, - } - ) - - async def buffered_send(message: Message) -> None: - """Process a response message and apply filtering if needed.""" - nonlocal body - nonlocal initial_message - initial_message = initial_message or message - # NOTE: to avoid data-leak, we process 404s so their responses are the same as rejected 200s - should_process = initial_message["status"] in [200, 404] - - if not should_process: - return await send(message) - - if message["type"] == "http.response.start": - # Hold off on sending response headers until we've validated the response body - return - - body += message["body"] - if message.get("more_body"): - return - - try: - body_json = json.loads(body) - except json.JSONDecodeError: - msg = "Failed to parse response body as JSON" - logger.warning(msg) - await _send_error_response(status=502, code="ParseError", message=msg) - return - - try: - cql2_matches = self.cql2_filter.matches(body_json) - except Exception as e: - cql2_matches = False - logger.warning("Failed to apply filter: %s", e) - - if cql2_matches: - logger.debug("Response matches filter, returning record") - await send(initial_message) - return await send( - { - "type": "http.response.body", - "body": json.dumps(body_json).encode("utf-8"), - "more_body": False, - } - ) - logger.debug("Response did not match filter, returning 404") - return await _send_error_response( - status=404, code="NotFoundError", message="Record not found." - ) - - return await self.app(scope, receive, buffered_send) diff --git a/src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py b/src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py new file mode 100644 index 0000000..b1d46d4 --- /dev/null +++ b/src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py @@ -0,0 +1,98 @@ +"""Middleware to augment the request body with a CQL2 filter for POST/PUT/PATCH requests.""" + +import json +from dataclasses import dataclass +from logging import getLogger +from typing import Optional + +from cql2 import Expr +from starlette.requests import Request +from starlette.types import ASGIApp, Receive, Scope, Send + +from ..utils import filters +from ..utils.middleware import required_conformance + +logger = getLogger(__name__) + + +@required_conformance( + r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", +) +@dataclass(frozen=True) +class Cql2ApplyFilterBodyMiddleware: + """Middleware to augment the request body with a CQL2 filter for POST/PUT/PATCH requests.""" + + app: ASGIApp + state_key: str = "cql2_filter" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Apply the CQL2 filter to the request body.""" + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope) + cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) + if not cql2_filter: + return await self.app(scope, receive, send) + + if request.method not in ["POST", "PUT", "PATCH"]: + return await self.app(scope, receive, send) + + body = b"" + more_body = True + while more_body: + message = await receive() + if message["type"] == "http.request": + body += message.get("body", b"") + more_body = message.get("more_body", False) + + try: + body_json = json.loads(body) if body else {} + except json.JSONDecodeError: + logger.warning("Failed to parse request body as JSON") + from starlette.responses import JSONResponse + + response = JSONResponse( + { + "code": "ParseError", + "description": "Request body must be valid JSON.", + }, + status_code=400, + ) + await response(scope, receive, send) + return + + if not isinstance(body_json, dict): + logger.warning("Request body must be a JSON object") + from starlette.responses import JSONResponse + + response = JSONResponse( + { + "code": "TypeError", + "description": "Request body must be a JSON object.", + }, + status_code=400, + ) + await response(scope, receive, send) + return + + new_body = json.dumps( + filters.append_body_filter(body_json, cql2_filter) + ).encode("utf-8") + + # Patch content-length in the headers + headers = dict(scope["headers"]) + headers[b"content-length"] = str(len(new_body)).encode("latin1") + scope = dict(scope) + scope["headers"] = list(headers.items()) + + async def new_receive(): + return { + "type": "http.request", + "body": new_body, + "more_body": False, + } + + await self.app(scope, new_receive, send) diff --git a/src/stac_auth_proxy/middleware/Cql2ApplyFilterQueryStringMiddleware.py b/src/stac_auth_proxy/middleware/Cql2ApplyFilterQueryStringMiddleware.py new file mode 100644 index 0000000..539731e --- /dev/null +++ b/src/stac_auth_proxy/middleware/Cql2ApplyFilterQueryStringMiddleware.py @@ -0,0 +1,56 @@ +"""Middleware to inject CQL2 filters into the query string for GET/list endpoints.""" + +import re +from dataclasses import dataclass +from logging import getLogger +from typing import Optional + +from cql2 import Expr +from starlette.requests import Request +from starlette.types import ASGIApp, Receive, Scope, Send + +from ..utils import filters +from ..utils.middleware import required_conformance + +logger = getLogger(__name__) + + +@required_conformance( + r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", +) +@dataclass(frozen=True) +class Cql2ApplyFilterQueryStringMiddleware: + """Middleware to inject CQL2 filters into the query string for GET/list endpoints.""" + + app: ASGIApp + state_key: str = "cql2_filter" + + single_record_endpoints = [ + r"^/collections/([^/]+)/items/([^/]+)$", + r"^/collections/([^/]+)$", + ] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Apply the CQL2 filter to the query string.""" + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope) + cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) + if not cql2_filter: + return await self.app(scope, receive, send) + + # Only handle GET requests that are not single-record endpoints + if request.method != "GET": + return await self.app(scope, receive, send) + if any( + re.match(expr, request.url.path) for expr in self.single_record_endpoints + ): + return await self.app(scope, receive, send) + + # Inject filter into query string + scope = dict(scope) + scope["query_string"] = filters.append_qs_filter(request.url.query, cql2_filter) + return await self.app(scope, receive, send) diff --git a/src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py b/src/stac_auth_proxy/middleware/Cql2BuildFilterMiddleware.py similarity index 99% rename from src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py rename to src/stac_auth_proxy/middleware/Cql2BuildFilterMiddleware.py index cfa153d..03083e6 100644 --- a/src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py +++ b/src/stac_auth_proxy/middleware/Cql2BuildFilterMiddleware.py @@ -22,7 +22,7 @@ "http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", ) @dataclass(frozen=True) -class BuildCql2FilterMiddleware: +class Cql2BuildFilterMiddleware: """Middleware to build the Cql2Filter.""" app: ASGIApp diff --git a/src/stac_auth_proxy/middleware/Cql2ValidateResponseBodyMiddleware.py b/src/stac_auth_proxy/middleware/Cql2ValidateResponseBodyMiddleware.py new file mode 100644 index 0000000..c55a9a0 --- /dev/null +++ b/src/stac_auth_proxy/middleware/Cql2ValidateResponseBodyMiddleware.py @@ -0,0 +1,133 @@ +"""Middleware to validate the response body with a CQL2 filter for single-record endpoints.""" + +import json +import re +from dataclasses import dataclass +from logging import getLogger +from typing import Optional + +from cql2 import Expr +from starlette.requests import Request +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +from ..utils.middleware import required_conformance + +logger = getLogger(__name__) + + +@required_conformance( + r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", + r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", +) +@dataclass +class Cql2ValidateResponseBodyMiddleware: + """ASGI middleware to validate the response body with a CQL2 filter for single-record endpoints.""" + + app: ASGIApp + state_key: str = "cql2_filter" + + single_record_endpoints = [ + r"^/collections/([^/]+)/items/([^/]+)$", + r"^/collections/([^/]+)$", + ] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Validate the response body with a CQL2 filter for single-record endpoints.""" + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope) + cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) + if not cql2_filter: + return await self.app(scope, receive, send) + + if not any( + re.match(expr, request.url.path) for expr in self.single_record_endpoints + ): + return await self.app(scope, receive, send) + + # Intercept the response + response_start = None + body_chunks = [] + more_body = True + + async def send_wrapper(message: Message): + nonlocal response_start, body_chunks, more_body + if message["type"] == "http.response.start": + response_start = message + elif message["type"] == "http.response.body": + body_chunks.append(message.get("body", b"")) + more_body = message.get("more_body", False) + if not more_body: + await self._process_and_send_response( + response_start, body_chunks, send, cql2_filter + ) + else: + await send(message) + + await self.app(scope, receive, send_wrapper) + + async def _process_and_send_response( + self, response_start, body_chunks, send, cql2_filter + ): + body = b"".join(body_chunks) + try: + body_json = json.loads(body) + except json.JSONDecodeError: + logger.warning("Failed to parse response body as JSON") + await self._send_json_response( + send, + status=502, + content={ + "code": "ParseError", + "description": "Failed to parse response body as JSON", + }, + ) + return + + try: + cql2_matches = cql2_filter.matches(body_json) + except Exception as e: + cql2_matches = False + logger.warning("Failed to apply filter: %s", e) + + if cql2_matches: + logger.debug("Response matches filter, returning record") + # Send the original response start + await send(response_start) + # Send the filtered body + await send( + { + "type": "http.response.body", + "body": json.dumps(body_json).encode("utf-8"), + "more_body": False, + } + ) + else: + logger.debug("Response did not match filter, returning 404") + await self._send_json_response( + send, + status=404, + content={"code": "NotFoundError", "description": "Record not found."}, + ) + + async def _send_json_response(self, send, status, content): + response_bytes = json.dumps(content).encode("utf-8") + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(response_bytes)).encode("latin1")), + ], + } + ) + await send( + { + "type": "http.response.body", + "body": response_bytes, + "more_body": False, + } + ) diff --git a/src/stac_auth_proxy/middleware/__init__.py b/src/stac_auth_proxy/middleware/__init__.py index 6ad1875..bc1ae4f 100644 --- a/src/stac_auth_proxy/middleware/__init__.py +++ b/src/stac_auth_proxy/middleware/__init__.py @@ -1,9 +1,11 @@ """Custom middleware.""" from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware -from .ApplyCql2FilterMiddleware import ApplyCql2FilterMiddleware from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware -from .BuildCql2FilterMiddleware import BuildCql2FilterMiddleware +from .Cql2ApplyFilterBodyMiddleware import Cql2ApplyFilterBodyMiddleware +from .Cql2ApplyFilterQueryStringMiddleware import Cql2ApplyFilterQueryStringMiddleware +from .Cql2BuildFilterMiddleware import Cql2BuildFilterMiddleware +from .Cql2ValidateResponseBodyMiddleware import Cql2ValidateResponseBodyMiddleware from .EnforceAuthMiddleware import EnforceAuthMiddleware from .ProcessLinksMiddleware import ProcessLinksMiddleware from .RemoveRootPathMiddleware import RemoveRootPathMiddleware @@ -11,11 +13,13 @@ __all__ = [ "AddProcessTimeHeaderMiddleware", - "ApplyCql2FilterMiddleware", "AuthenticationExtensionMiddleware", - "BuildCql2FilterMiddleware", + "Cql2ApplyFilterBodyMiddleware", + "Cql2ApplyFilterQueryStringMiddleware", + "Cql2BuildFilterMiddleware", + "Cql2ValidateResponseBodyMiddleware", "EnforceAuthMiddleware", + "OpenApiMiddleware", "ProcessLinksMiddleware", "RemoveRootPathMiddleware", - "OpenApiMiddleware", ] From 729104e4ed6ffc3b28f6830544d9d47add76f10e Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Tue, 22 Jul 2025 08:22:11 -0700 Subject: [PATCH 3/5] feat: remove added filters from response links #65 --- src/stac_auth_proxy/app.py | 2 + .../Cql2RewriteLinksFilterMiddleware.py | 108 ++++++++++++++++++ src/stac_auth_proxy/middleware/__init__.py | 2 + 3 files changed, 112 insertions(+) create mode 100644 src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index eb62eda..313dc51 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -20,6 +20,7 @@ Cql2ApplyFilterBodyMiddleware, Cql2ApplyFilterQueryStringMiddleware, Cql2BuildFilterMiddleware, + Cql2RewriteLinksFilterMiddleware, Cql2ValidateResponseBodyMiddleware, EnforceAuthMiddleware, OpenApiMiddleware, @@ -137,6 +138,7 @@ async def lifespan(app: FastAPI): app.add_middleware(Cql2ValidateResponseBodyMiddleware) app.add_middleware(Cql2ApplyFilterBodyMiddleware) app.add_middleware(Cql2ApplyFilterQueryStringMiddleware) + app.add_middleware(Cql2RewriteLinksFilterMiddleware) app.add_middleware( Cql2BuildFilterMiddleware, items_filter=settings.items_filter() if settings.items_filter else None, diff --git a/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py new file mode 100644 index 0000000..1909bfd --- /dev/null +++ b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py @@ -0,0 +1,108 @@ +"""Middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state.""" + +import json +from dataclasses import dataclass +from logging import getLogger +from typing import Optional +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse + +from cql2 import Expr +from starlette.requests import Request +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +logger = getLogger(__name__) + + +@dataclass(frozen=True) +class Cql2RewriteLinksFilterMiddleware: + """ASGI middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state.""" + + app: ASGIApp + state_key: str = "cql2_filter" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Replace 'filter' in .links of the JSON response to state before we had applied the filter.""" + if scope["type"] != "http": + return await self.app(scope, receive, send) + + request = Request(scope) + original_filter = request.query_params.get("filter") + cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) + if cql2_filter is None: + # No filter set, just pass through + return await self.app(scope, receive, send) + + # Intercept the response + response_start = None + body_chunks = [] + more_body = True + + async def send_wrapper(message: Message): + nonlocal response_start, body_chunks, more_body + if message["type"] == "http.response.start": + response_start = message + elif message["type"] == "http.response.body": + body_chunks.append(message.get("body", b"")) + more_body = message.get("more_body", False) + if not more_body: + await self._process_and_send_response( + response_start, body_chunks, send, original_filter + ) + else: + await send(message) + + await self.app(scope, receive, send_wrapper) + + async def _process_and_send_response( + self, + response_start: Message, + body_chunks: list[bytes], + send: Send, + original_filter: Optional[str], + ): + body = b"".join(body_chunks) + try: + data = json.loads(body) + except Exception: + await send(response_start) + await send({"type": "http.response.body", "body": body, "more_body": False}) + return + + cql2_filter = Expr(original_filter) if original_filter else None + links = data.get("links") + if isinstance(links, list): + for link in links: + # Handle filter in query string + if "href" in link: + url = urlparse(link["href"]) + qs = parse_qs(url.query) + if "filter" in qs: + if cql2_filter: + qs["filter"] = [cql2_filter.to_text()] + else: + qs.pop("filter", None) + qs.pop("filter-lang", None) + new_query = urlencode(qs, doseq=True) + link["href"] = urlunparse(url._replace(query=new_query)) + + # Handle filter in body (for POST links) + if "body" in link and isinstance(link["body"], dict): + if "filter" in link["body"]: + if cql2_filter: + link["body"]["filter"] = cql2_filter.to_json() + else: + link["body"].pop("filter", None) + link["body"].pop("filter-lang", None) + + # Send the modified response + new_body = json.dumps(data).encode("utf-8") + + # Patch content-length + headers = [ + (k, v) for k, v in response_start["headers"] if k != b"content-length" + ] + headers.append((b"content-length", str(len(new_body)).encode("latin1"))) + response_start = dict(response_start) + response_start["headers"] = headers + await send(response_start) + await send({"type": "http.response.body", "body": new_body, "more_body": False}) diff --git a/src/stac_auth_proxy/middleware/__init__.py b/src/stac_auth_proxy/middleware/__init__.py index bc1ae4f..c5dc005 100644 --- a/src/stac_auth_proxy/middleware/__init__.py +++ b/src/stac_auth_proxy/middleware/__init__.py @@ -5,6 +5,7 @@ from .Cql2ApplyFilterBodyMiddleware import Cql2ApplyFilterBodyMiddleware from .Cql2ApplyFilterQueryStringMiddleware import Cql2ApplyFilterQueryStringMiddleware from .Cql2BuildFilterMiddleware import Cql2BuildFilterMiddleware +from .Cql2RewriteLinksFilterMiddleware import Cql2RewriteLinksFilterMiddleware from .Cql2ValidateResponseBodyMiddleware import Cql2ValidateResponseBodyMiddleware from .EnforceAuthMiddleware import EnforceAuthMiddleware from .ProcessLinksMiddleware import ProcessLinksMiddleware @@ -17,6 +18,7 @@ "Cql2ApplyFilterBodyMiddleware", "Cql2ApplyFilterQueryStringMiddleware", "Cql2BuildFilterMiddleware", + "Cql2RewriteLinksFilterMiddleware", "Cql2ValidateResponseBodyMiddleware", "EnforceAuthMiddleware", "OpenApiMiddleware", From 771b02b9293b3130e2af46262eac4f531c515e0a Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Thu, 24 Jul 2025 11:34:28 -0700 Subject: [PATCH 4/5] in progress --- ...st_cql2_rewrite_links_filter_middleware.py | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 tests/test_cql2_rewrite_links_filter_middleware.py diff --git a/tests/test_cql2_rewrite_links_filter_middleware.py b/tests/test_cql2_rewrite_links_filter_middleware.py new file mode 100644 index 0000000..d80e238 --- /dev/null +++ b/tests/test_cql2_rewrite_links_filter_middleware.py @@ -0,0 +1,110 @@ +from unittest.mock import patch, MagicMock + +import pytest +from fastapi import FastAPI, Request, Response +from starlette.testclient import TestClient + +from stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware import ( + Cql2RewriteLinksFilterMiddleware, +) + + +@pytest.fixture +def app_with_middleware(): + app = FastAPI() + app.add_middleware(Cql2RewriteLinksFilterMiddleware) + + @app.get("/test") + async def test_endpoint(request: Request): + # Simulate a response with links containing a filter in the query and body + return { + "links": [ + { + "rel": "self", + "href": "http://example.com/search?filter=foo&filter-lang=cql2-text", + }, + { + "rel": "post", + "body": {"filter": "foo", "filter-lang": "cql2-json"}, + }, + ] + } + + return app + + +def test_rewrite_links_with_filter(app_with_middleware): + # Patch cql2.Expr to simulate to_text and to_json + with patch( + "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" + ) as MockExpr: + mock_expr = MagicMock() + mock_expr.to_text.return_value = "bar" + mock_expr.to_json.return_value = {"foo": "bar"} + MockExpr.return_value = mock_expr + + client = TestClient(app_with_middleware) + response = client.get("/test?filter=foo") + assert response.status_code == 200 + data = response.json() + # The filter in the href should be rewritten + assert any( + "filter=bar" in link["href"] for link in data["links"] if "href" in link + ) + # The filter in the body should be rewritten + assert any( + link.get("body", {}).get("filter") == {"foo": "bar"} + for link in data["links"] + ) + + +def test_remove_filter_from_links(app_with_middleware): + # Patch cql2.Expr to return None (no filter) + with patch( + "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" + ) as MockExpr: + MockExpr.return_value = None + client = TestClient(app_with_middleware) + response = client.get("/test") + assert response.status_code == 200 + data = response.json() + # The filter should be removed from href and body + for link in data["links"]: + if "href" in link: + assert "filter=" not in link["href"] + if "body" in link: + assert "filter" not in link["body"] + assert "filter-lang" not in link["body"] + + +def test_passthrough_when_no_filter_state(app_with_middleware): + # Simulate no filter in request.state + with patch( + "stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr" + ) as MockExpr: + MockExpr.return_value = None + client = TestClient(app_with_middleware) + response = client.get("/test") + assert response.status_code == 200 + data = response.json() + # Should be unchanged (filter removed) + for link in data["links"]: + if "href" in link: + assert "filter=" not in link["href"] + if "body" in link: + assert "filter" not in link["body"] + assert "filter-lang" not in link["body"] + + +def test_non_json_response(app_with_middleware): + # Add a route that returns plain text + app = app_with_middleware + + @app.get("/plain") + async def plain(): + return Response(content="not json", media_type="text/plain") + + client = TestClient(app) + response = client.get("/plain") + assert response.status_code == 200 + assert response.text == "not json" From ca5eafabcaf6824db501aa22672f7aed6a92dfeb Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Thu, 24 Jul 2025 11:34:38 -0700 Subject: [PATCH 5/5] in progress --- tests/test_cql2_rewrite_links_filter_middleware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cql2_rewrite_links_filter_middleware.py b/tests/test_cql2_rewrite_links_filter_middleware.py index d80e238..c81649e 100644 --- a/tests/test_cql2_rewrite_links_filter_middleware.py +++ b/tests/test_cql2_rewrite_links_filter_middleware.py @@ -1,4 +1,4 @@ -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import pytest from fastapi import FastAPI, Request, Response