Skip to content

Commit afa05c9

Browse files
committed
Take care to separate GET and POST middleware behavior
1 parent f8ae991 commit afa05c9

File tree

4 files changed

+91
-64
lines changed

4 files changed

+91
-64
lines changed

src/stac_auth_proxy/middleware/Cql2FilterMiddleware.py

Lines changed: 59 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,35 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2929
if scope["type"] != "http":
3030
return await self.app(scope, receive, send)
3131

32+
request = Request(scope)
33+
34+
filter_builder = self._get_filter(request.url.path)
35+
if not filter_builder:
36+
return await self.app(scope, receive, send)
37+
38+
async def set_filter(body: Optional[dict] = None) -> None:
39+
cql2_filter = await filter_builder(
40+
{
41+
"req": {
42+
"path": request.url.path,
43+
"method": request.method,
44+
"query_params": dict(request.query_params),
45+
"path_params": requests.extract_variables(request.url.path),
46+
"headers": dict(request.headers),
47+
"body": body,
48+
},
49+
**request.state._state,
50+
}
51+
)
52+
cql2_filter.validate()
53+
scope["state"][FILTER_STATE_KEY] = cql2_filter
54+
55+
# For GET requests, we can build the filter immediately
56+
# NOTE: It appears that FastAPI will not call receive function for GET requests
57+
if request.method == "GET":
58+
await set_filter()
59+
return await self.app(scope, receive, send)
60+
3261
total_body = b""
3362

3463
async def receive_build_filter() -> Message:
@@ -38,26 +67,7 @@ async def receive_build_filter() -> Message:
3867
total_body += message.get("body", b"")
3968

4069
if not message.get("more_body"):
41-
request = Request(scope)
42-
filter_builder = self._get_filter(request.url.path)
43-
if filter_builder:
44-
cql2_filter = await filter_builder(
45-
{
46-
"req": {
47-
"path": request.url.path,
48-
"method": request.method,
49-
"query_params": dict(request.query_params),
50-
"path_params": requests.extract_variables(
51-
request.url.path
52-
),
53-
"headers": dict(request.headers),
54-
"body": json.loads(total_body),
55-
},
56-
**request.state._state,
57-
}
58-
)
59-
cql2_filter.validate()
60-
scope["state"][FILTER_STATE_KEY] = cql2_filter
70+
await set_filter(json.loads(total_body) if total_body else None)
6171
return message
6272

6373
return await self.app(scope, receive_build_filter, send)
@@ -87,31 +97,35 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8797
if scope["type"] != "http":
8898
return await self.app(scope, receive, send)
8999

90-
async def apply_filter() -> Message:
91-
message = await receive()
92-
request = Request(scope)
93-
cql2_filter = getattr(request.state, FILTER_STATE_KEY, None)
94-
if not cql2_filter:
95-
logger.debug("No cql2 filter found on message.")
96-
return message
100+
request = Request(scope)
97101

98-
if request.method == "GET":
99-
query = filters.insert_qs_filter(qs=query, filter=cql2_filter)
100-
# Get the original query string
101-
original_qs = scope["query_string"].decode("utf-8")
102-
# Apply the filter to query string
103-
new_qs = filters.append_qs_filter(original_qs, cql2_filter)
104-
# Update the scope with new query string
105-
# scope["query_string"] = new_qs.encode("utf-8")
106-
elif request.method in ["POST", "PUT", "PATCH"]:
107-
# TODO: Apply the filter to the body
108-
message["body"] = json.dumps(
109-
filters.append_body_filter(
110-
body=json.loads(message.get("body", "{}")),
111-
filter=cql2_filter,
112-
)
113-
).encode("utf-8")
102+
if request.method == "GET":
103+
cql2_filter = scope["state"].get(FILTER_STATE_KEY)
104+
if cql2_filter:
105+
scope["query_string"] = filters.append_qs_filter(
106+
request.url.query, cql2_filter
107+
)
108+
return await self.app(scope, receive, send)
109+
elif request.method in ["POST", "PUT", "PATCH"]:
110+
# For methods with bodies, we need to wrap receive to modify the body
111+
async def receive_with_filter() -> Message:
112+
message = await receive()
113+
if message["type"] != "http.request":
114+
return message
115+
116+
cql2_filter = scope["state"].get(FILTER_STATE_KEY)
117+
if cql2_filter:
118+
try:
119+
body = message.get("body", b"{}")
120+
except json.JSONDecodeError as e:
121+
logger.warning("Failed to parse request body as JSON")
122+
# TODO: Return a 400 error
123+
raise e
124+
125+
new_body = filters.append_body_filter(json.loads(body), cql2_filter)
126+
message["body"] = json.dumps(new_body).encode("utf-8")
127+
return message
114128

115-
return message
129+
return await self.app(scope, receive_with_filter, send)
116130

117-
return await self.app(scope, apply_filter, send)
131+
return await self.app(scope, receive, send)

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import urllib.request
66

7-
from fastapi import HTTPException, Security, security, status, Request
7+
from fastapi import HTTPException, Security, status, Request
88
from pydantic import HttpUrl
99
from starlette.middleware.base import ASGIApp
1010
from starlette.responses import JSONResponse

src/stac_auth_proxy/utils/filters.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,42 @@
11
"""Utility functions."""
22

3+
import json
34
import re
5+
from typing import Optional
46
from urllib.parse import parse_qs, urlencode
57

68
from cql2 import Expr
79

10+
from .requests import dict_to_bytes
811

9-
def append_qs_filter(qs: str, filter: Expr) -> str:
10-
"""Insert a filter expression into a query string. If a filter already exists, combine them."""
11-
qs_dict = parse_qs(qs)
12-
13-
for qs_filter in qs_dict.get("filter", []):
14-
filter += Expr(qs_filter)
15-
16-
qs_dict["filter"] = filter.to_text()
17-
qs_dict["filter-lang"] = "cql2-text"
1812

19-
return urlencode(qs_dict, doseq=True)
20-
21-
22-
def append_body_filter(body: dict, filter: Expr) -> dict:
13+
def append_qs_filter(qs: str, filter: Expr, filter_lang: Optional[str] = None) -> bytes:
14+
"""Insert a filter expression into a query string. If a filter already exists, combine them."""
15+
qs_dict = {k: v[0] for k, v in parse_qs(qs).items()}
16+
new_qs_dict = append_body_filter(
17+
qs_dict, filter, filter_lang or qs_dict.get("filter-lang", "cql2-text")
18+
)
19+
return dict_to_bytes(
20+
urlencode(
21+
{
22+
k: json.dumps(v) if isinstance(v, (list, dict)) else v
23+
for k, v in new_qs_dict.items()
24+
}
25+
)
26+
)
27+
28+
29+
def append_body_filter(body: dict, filter: Expr, filter_lang: Optional[str]) -> dict:
2330
"""Insert a filter expression into a request body. If a filter already exists, combine them."""
2431
cur_filter = body.get("filter")
32+
filter_lang = filter_lang or body.get("filter-lang", "cql2-json")
2533
if cur_filter:
2634
filter = filter + Expr(cur_filter)
27-
body["filter"] = filter.to_json()
28-
body["filter-lang"] = "cql2-json"
29-
return body
35+
return {
36+
**body,
37+
"filter": filter.to_text() if filter_lang == "cql2-text" else filter.to_json(),
38+
"filter-lang": filter_lang,
39+
}
3040

3141

3242
def is_collection_endpoint(path: str) -> bool:

tests/test_filters_jinja2.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,11 @@ async def test_search_get(
241241

242242
expected_output = {
243243
**input_query,
244-
"filter": proxy_filter.to_text(),
245-
"filter-lang": "cql2-text",
244+
"filter": (
245+
proxy_filter.to_text()
246+
if input_query.get("filter-lang") == "cql2-text"
247+
else proxy_filter.to_json()
248+
),
246249
}
247250
assert (
248251
upstream_query == expected_output

0 commit comments

Comments
 (0)