Skip to content

Commit 5a2cae0

Browse files
committed
Merge branch 'main' into fix/retain-proxy-headers-when-behind-proxy
2 parents b2eb13b + a942200 commit 5a2cae0

File tree

5 files changed

+450
-1
lines changed

5 files changed

+450
-1
lines changed

src/stac_auth_proxy/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Cql2ApplyFilterBodyMiddleware,
2121
Cql2ApplyFilterQueryStringMiddleware,
2222
Cql2BuildFilterMiddleware,
23+
Cql2RewriteLinksFilterMiddleware,
2324
Cql2ValidateResponseBodyMiddleware,
2425
EnforceAuthMiddleware,
2526
OpenApiMiddleware,
@@ -110,6 +111,7 @@ def configure_app(
110111
app.add_middleware(Cql2ValidateResponseBodyMiddleware)
111112
app.add_middleware(Cql2ApplyFilterBodyMiddleware)
112113
app.add_middleware(Cql2ApplyFilterQueryStringMiddleware)
114+
app.add_middleware(Cql2RewriteLinksFilterMiddleware)
113115
app.add_middleware(
114116
Cql2BuildFilterMiddleware,
115117
items_filter=settings.items_filter() if settings.items_filter else None,
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state."""
2+
3+
import json
4+
from dataclasses import dataclass
5+
from logging import getLogger
6+
from typing import Optional
7+
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
8+
9+
from cql2 import Expr
10+
from starlette.requests import Request
11+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
12+
13+
logger = getLogger(__name__)
14+
15+
16+
@dataclass(frozen=True)
17+
class Cql2RewriteLinksFilterMiddleware:
18+
"""ASGI middleware to rewrite 'filter' in .links of the JSON response, removing the filter from the request state."""
19+
20+
app: ASGIApp
21+
state_key: str = "cql2_filter"
22+
23+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
24+
"""Replace 'filter' in .links of the JSON response to state before we had applied the filter."""
25+
if scope["type"] != "http":
26+
return await self.app(scope, receive, send)
27+
28+
request = Request(scope)
29+
original_filter = request.query_params.get("filter")
30+
cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
31+
if cql2_filter is None:
32+
# No filter set, just pass through
33+
return await self.app(scope, receive, send)
34+
35+
# Intercept the response
36+
response_start = None
37+
body_chunks = []
38+
more_body = True
39+
40+
async def send_wrapper(message: Message):
41+
nonlocal response_start, body_chunks, more_body
42+
if message["type"] == "http.response.start":
43+
response_start = message
44+
elif message["type"] == "http.response.body":
45+
body_chunks.append(message.get("body", b""))
46+
more_body = message.get("more_body", False)
47+
if not more_body:
48+
await self._process_and_send_response(
49+
response_start, body_chunks, send, original_filter
50+
)
51+
else:
52+
await send(message)
53+
54+
await self.app(scope, receive, send_wrapper)
55+
56+
async def _process_and_send_response(
57+
self,
58+
response_start: Message,
59+
body_chunks: list[bytes],
60+
send: Send,
61+
original_filter: Optional[str],
62+
):
63+
body = b"".join(body_chunks)
64+
try:
65+
data = json.loads(body)
66+
except Exception:
67+
await send(response_start)
68+
await send({"type": "http.response.body", "body": body, "more_body": False})
69+
return
70+
71+
cql2_filter = Expr(original_filter) if original_filter else None
72+
links = data.get("links")
73+
if isinstance(links, list):
74+
for link in links:
75+
# Handle filter in query string
76+
if "href" in link:
77+
url = urlparse(link["href"])
78+
qs = parse_qs(url.query)
79+
if "filter" in qs:
80+
if cql2_filter:
81+
qs["filter"] = [cql2_filter.to_text()]
82+
else:
83+
qs.pop("filter", None)
84+
qs.pop("filter-lang", None)
85+
new_query = urlencode(qs, doseq=True)
86+
link["href"] = urlunparse(url._replace(query=new_query))
87+
88+
# Handle filter in body (for POST links)
89+
if "body" in link and isinstance(link["body"], dict):
90+
if "filter" in link["body"]:
91+
if cql2_filter:
92+
link["body"]["filter"] = cql2_filter.to_json()
93+
else:
94+
link["body"].pop("filter", None)
95+
link["body"].pop("filter-lang", None)
96+
97+
# Send the modified response
98+
new_body = json.dumps(data).encode("utf-8")
99+
100+
# Patch content-length
101+
headers = [
102+
(k, v) for k, v in response_start["headers"] if k != b"content-length"
103+
]
104+
headers.append((b"content-length", str(len(new_body)).encode("latin1")))
105+
response_start = dict(response_start)
106+
response_start["headers"] = headers
107+
await send(response_start)
108+
await send({"type": "http.response.body", "body": new_body, "more_body": False})

src/stac_auth_proxy/middleware/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .Cql2ApplyFilterBodyMiddleware import Cql2ApplyFilterBodyMiddleware
66
from .Cql2ApplyFilterQueryStringMiddleware import Cql2ApplyFilterQueryStringMiddleware
77
from .Cql2BuildFilterMiddleware import Cql2BuildFilterMiddleware
8+
from .Cql2RewriteLinksFilterMiddleware import Cql2RewriteLinksFilterMiddleware
89
from .Cql2ValidateResponseBodyMiddleware import Cql2ValidateResponseBodyMiddleware
910
from .EnforceAuthMiddleware import EnforceAuthMiddleware
1011
from .ProcessLinksMiddleware import ProcessLinksMiddleware
@@ -17,6 +18,7 @@
1718
"Cql2ApplyFilterBodyMiddleware",
1819
"Cql2ApplyFilterQueryStringMiddleware",
1920
"Cql2BuildFilterMiddleware",
21+
"Cql2RewriteLinksFilterMiddleware",
2022
"Cql2ValidateResponseBodyMiddleware",
2123
"EnforceAuthMiddleware",
2224
"OpenApiMiddleware",

0 commit comments

Comments
 (0)