Skip to content

feat: remove applied filters on response links #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Cql2ApplyFilterBodyMiddleware,
Cql2ApplyFilterQueryStringMiddleware,
Cql2BuildFilterMiddleware,
Cql2RewriteLinksFilterMiddleware,
Cql2ValidateResponseBodyMiddleware,
EnforceAuthMiddleware,
OpenApiMiddleware,
Expand Down Expand Up @@ -137,6 +138,7 @@ async def lifespan(app: FastAPI):
app.add_middleware(Cql2ValidateResponseBodyMiddleware)
app.add_middleware(Cql2ApplyFilterBodyMiddleware)
app.add_middleware(Cql2ApplyFilterQueryStringMiddleware)
app.add_middleware(Cql2RewriteLinksFilterMiddleware)
app.add_middleware(
Cql2BuildFilterMiddleware,
items_filter=settings.items_filter() if settings.items_filter else None,
Expand Down
108 changes: 108 additions & 0 deletions src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state."""

import json
from dataclasses import dataclass
from logging import getLogger
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from cql2 import Expr
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send

logger = getLogger(__name__)


@dataclass(frozen=True)
class Cql2RewriteLinksFilterMiddleware:
"""ASGI middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state."""

app: ASGIApp
state_key: str = "cql2_filter"

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Replace 'filter' in .links of the JSON response to state before we had applied the filter."""
if scope["type"] != "http":
return await self.app(scope, receive, send)

request = Request(scope)
original_filter = request.query_params.get("filter")
cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
if cql2_filter is None:
# No filter set, just pass through
return await self.app(scope, receive, send)

# Intercept the response
response_start = None
body_chunks = []
more_body = True

async def send_wrapper(message: Message):
nonlocal response_start, body_chunks, more_body
if message["type"] == "http.response.start":
response_start = message
elif message["type"] == "http.response.body":
body_chunks.append(message.get("body", b""))
more_body = message.get("more_body", False)
if not more_body:
await self._process_and_send_response(
response_start, body_chunks, send, original_filter
)
else:
await send(message)

await self.app(scope, receive, send_wrapper)

async def _process_and_send_response(
self,
response_start: Message,
body_chunks: list[bytes],
send: Send,
original_filter: Optional[str],
):
body = b"".join(body_chunks)
try:
data = json.loads(body)
except Exception:
await send(response_start)
await send({"type": "http.response.body", "body": body, "more_body": False})
return

cql2_filter = Expr(original_filter) if original_filter else None
links = data.get("links")
if isinstance(links, list):
for link in links:
# Handle filter in query string
if "href" in link:
url = urlparse(link["href"])
qs = parse_qs(url.query)
if "filter" in qs:
if cql2_filter:
qs["filter"] = [cql2_filter.to_text()]
else:
qs.pop("filter", None)
qs.pop("filter-lang", None)
new_query = urlencode(qs, doseq=True)
link["href"] = urlunparse(url._replace(query=new_query))

# Handle filter in body (for POST links)
if "body" in link and isinstance(link["body"], dict):
if "filter" in link["body"]:
if cql2_filter:
link["body"]["filter"] = cql2_filter.to_json()
else:
link["body"].pop("filter", None)
link["body"].pop("filter-lang", None)

# Send the modified response
new_body = json.dumps(data).encode("utf-8")

# Patch content-length
headers = [
(k, v) for k, v in response_start["headers"] if k != b"content-length"
]
headers.append((b"content-length", str(len(new_body)).encode("latin1")))
response_start = dict(response_start)
response_start["headers"] = headers
await send(response_start)
await send({"type": "http.response.body", "body": new_body, "more_body": False})
2 changes: 2 additions & 0 deletions src/stac_auth_proxy/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .Cql2ApplyFilterBodyMiddleware import Cql2ApplyFilterBodyMiddleware
from .Cql2ApplyFilterQueryStringMiddleware import Cql2ApplyFilterQueryStringMiddleware
from .Cql2BuildFilterMiddleware import Cql2BuildFilterMiddleware
from .Cql2RewriteLinksFilterMiddleware import Cql2RewriteLinksFilterMiddleware
from .Cql2ValidateResponseBodyMiddleware import Cql2ValidateResponseBodyMiddleware
from .EnforceAuthMiddleware import EnforceAuthMiddleware
from .ProcessLinksMiddleware import ProcessLinksMiddleware
Expand All @@ -17,6 +18,7 @@
"Cql2ApplyFilterBodyMiddleware",
"Cql2ApplyFilterQueryStringMiddleware",
"Cql2BuildFilterMiddleware",
"Cql2RewriteLinksFilterMiddleware",
"Cql2ValidateResponseBodyMiddleware",
"EnforceAuthMiddleware",
"OpenApiMiddleware",
Expand Down
110 changes: 110 additions & 0 deletions tests/test_cql2_rewrite_links_filter_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from unittest.mock import MagicMock, patch

import pytest
from fastapi import FastAPI, Request, Response
from starlette.testclient import TestClient

from stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware import (
Cql2RewriteLinksFilterMiddleware,
)


@pytest.fixture
def app_with_middleware():
app = FastAPI()
app.add_middleware(Cql2RewriteLinksFilterMiddleware)

@app.get("/test")
async def test_endpoint(request: Request):
# Simulate a response with links containing a filter in the query and body
return {
"links": [
{
"rel": "self",
"href": "http://example.com/search?filter=foo&filter-lang=cql2-text",
},
{
"rel": "post",
"body": {"filter": "foo", "filter-lang": "cql2-json"},
},
]
}

return app


def test_rewrite_links_with_filter(app_with_middleware):
# Patch cql2.Expr to simulate to_text and to_json
with patch(
"stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr"
) as MockExpr:
mock_expr = MagicMock()
mock_expr.to_text.return_value = "bar"
mock_expr.to_json.return_value = {"foo": "bar"}
MockExpr.return_value = mock_expr

client = TestClient(app_with_middleware)
response = client.get("/test?filter=foo")
assert response.status_code == 200
data = response.json()
# The filter in the href should be rewritten
assert any(
"filter=bar" in link["href"] for link in data["links"] if "href" in link
)
# The filter in the body should be rewritten
assert any(
link.get("body", {}).get("filter") == {"foo": "bar"}
for link in data["links"]
)


def test_remove_filter_from_links(app_with_middleware):
# Patch cql2.Expr to return None (no filter)
with patch(
"stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr"
) as MockExpr:
MockExpr.return_value = None
client = TestClient(app_with_middleware)
response = client.get("/test")
assert response.status_code == 200
data = response.json()
# The filter should be removed from href and body
for link in data["links"]:
if "href" in link:
assert "filter=" not in link["href"]
if "body" in link:
assert "filter" not in link["body"]
assert "filter-lang" not in link["body"]


def test_passthrough_when_no_filter_state(app_with_middleware):
# Simulate no filter in request.state
with patch(
"stac_auth_proxy.middleware.Cql2RewriteLinksFilterMiddleware.Expr"
) as MockExpr:
MockExpr.return_value = None
client = TestClient(app_with_middleware)
response = client.get("/test")
assert response.status_code == 200
data = response.json()
# Should be unchanged (filter removed)
for link in data["links"]:
if "href" in link:
assert "filter=" not in link["href"]
if "body" in link:
assert "filter" not in link["body"]
assert "filter-lang" not in link["body"]


def test_non_json_response(app_with_middleware):
# Add a route that returns plain text
app = app_with_middleware

@app.get("/plain")
async def plain():
return Response(content="not json", media_type="text/plain")

client = TestClient(app)
response = client.get("/plain")
assert response.status_code == 200
assert response.text == "not json"
Loading