Skip to content

Commit c537803

Browse files
committed
Working validation
1 parent d0386d0 commit c537803

File tree

6 files changed

+205
-126
lines changed

6 files changed

+205
-126
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ classifiers = [
88
dependencies = [
99
"authlib>=1.3.2",
1010
"brotli>=1.1.0",
11-
"cql2>=0.3.5",
11+
"cql2>=0.3.6",
1212
"fastapi>=0.115.5",
1313
"httpx[http2]>=0.28.0",
1414
"jinja2>=3.1.4",

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
"""Middleware to apply CQL2 filters."""
22

33
import json
4+
import re
45
from dataclasses import dataclass
56
from logging import getLogger
7+
from typing import Optional
68

9+
from cql2 import Expr
10+
from starlette.datastructures import MutableHeaders
711
from starlette.requests import Request
812
from starlette.types import ASGIApp, Message, Receive, Scope, Send
913

@@ -28,12 +32,87 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2832
request = Request(scope)
2933

3034
if request.method == "GET":
31-
cql2_filter = getattr(request.state, self.state_key, None)
35+
cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
3236
if cql2_filter:
3337
scope["query_string"] = filters.append_qs_filter(
3438
request.url.query, cql2_filter
3539
)
36-
return await self.app(scope, receive, send)
40+
41+
initial_message = None
42+
body = b""
43+
44+
async def validate_response(message: Message) -> None:
45+
nonlocal initial_message
46+
nonlocal body
47+
headers = MutableHeaders(scope=initial_message)
48+
if message["type"] == "http.response.start":
49+
initial_message = message
50+
return
51+
52+
if message["type"] == "http.response.body":
53+
assert initial_message, "Initial message not set"
54+
assert cql2_filter, "Cql2Filter not set"
55+
body += message["body"]
56+
if message.get("more_body"):
57+
return
58+
59+
try:
60+
body = json.loads(body)
61+
except json.JSONDecodeError:
62+
logger.warning("Failed to parse response body as JSON")
63+
not_found_body = json.dumps({"message": "Not found"}).encode(
64+
"utf-8"
65+
)
66+
headers["content-length"] = str(len(not_found_body))
67+
initial_message["status"] = 502
68+
await send(initial_message)
69+
await send(
70+
{
71+
"type": "http.response.body",
72+
"body": not_found_body,
73+
"more_body": False,
74+
}
75+
)
76+
return
77+
78+
logger.debug(
79+
"Applying %s filter to %s", cql2_filter.to_text(), body
80+
)
81+
if cql2_filter.matches(body):
82+
await send(initial_message)
83+
await send(
84+
{
85+
"type": "http.response.body",
86+
"body": json.dumps(body).encode("utf-8"),
87+
"more_body": False,
88+
}
89+
)
90+
else:
91+
not_found_body = json.dumps({"message": "Not found"}).encode(
92+
"utf-8"
93+
)
94+
headers["content-length"] = str(len(not_found_body))
95+
initial_message["status"] = 404
96+
await send(initial_message)
97+
await send(
98+
{
99+
"type": "http.response.body",
100+
"body": not_found_body,
101+
"more_body": False,
102+
}
103+
)
104+
105+
return message
106+
107+
should_validate_response = cql2_filter and re.match(
108+
r"^/collections/([^/]+)/items/([^/]+)$", request.url.path
109+
)
110+
111+
return await self.app(
112+
scope,
113+
receive,
114+
validate_response if should_validate_response else send,
115+
)
37116

38117
elif request.method in ["POST", "PUT", "PATCH"]:
39118

@@ -55,6 +134,10 @@ async def receive_and_apply_filter() -> Message:
55134
message["body"] = json.dumps(new_body).encode("utf-8")
56135
return message
57136

58-
return await self.app(scope, receive_and_apply_filter, send)
137+
return await self.app(
138+
scope,
139+
receive_and_apply_filter,
140+
send,
141+
)
59142

60143
return await self.app(scope, receive, send)

src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Middleware to build the Cql2Filter."""
22

33
import json
4+
import re
45
from dataclasses import dataclass
56
from typing import Callable, Optional
67

78
from cql2 import Expr
89
from starlette.requests import Request
910
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1011

11-
from ..utils import filters, requests
12+
from ..utils import requests
1213

1314

1415
@dataclass(frozen=True)
@@ -78,11 +79,10 @@ async def receive_build_filter() -> Message:
7879
def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
7980
"""Get the CQL2 filter builder for the given path."""
8081
endpoint_filters = [
81-
(filters.is_collection_endpoint, self.collections_filter),
82-
(filters.is_item_endpoint, self.items_filter),
83-
(filters.is_search_endpoint, self.items_filter),
82+
(r"^/collections(/[^/]+)?$", self.collections_filter),
83+
(r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)", self.items_filter),
8484
]
85-
for check, builder in endpoint_filters:
86-
if check(path):
85+
for expr, builder in endpoint_filters:
86+
if re.match(expr, path):
8787
return builder
8888
return None

src/stac_auth_proxy/utils/filters.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Utility functions."""
22

33
import json
4-
import re
54
from typing import Optional
65
from urllib.parse import parse_qs
76

@@ -32,23 +31,6 @@ def append_body_filter(
3231
}
3332

3433

35-
def is_collection_endpoint(path: str) -> bool:
36-
"""Check if the path is a collection endpoint."""
37-
# TODO: Expand this to cover all cases where a collection filter should be applied
38-
return path == "/collections"
39-
40-
41-
def is_item_endpoint(path: str) -> bool:
42-
"""Check if the path is an item endpoint."""
43-
# TODO: Expand this to cover all cases where an item filter should be applied
44-
return bool(re.compile(r"^(/collections/([^/]+)/items$|/search)").match(path))
45-
46-
47-
def is_search_endpoint(path: str) -> bool:
48-
"""Check if the path is a search endpoint."""
49-
return path == "/search"
50-
51-
5234
def dict_to_query_string(params: dict) -> str:
5335
"""
5436
Convert a dictionary to a query string. Dict values are converted to JSON strings,

tests/test_filters_jinja2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,11 @@ def test_item_get(
307307
"properties": {"private": True},
308308
}
309309
response = client.get("/collections/foo/items/bar")
310-
expected_status = 404 if is_authenticated else 200
310+
expected_status = 200 if is_authenticated else 404
311+
expected_body = (
312+
{"id": "bar", "properties": {"private": True}}
313+
if is_authenticated
314+
else {"message": "Not found"}
315+
)
311316
assert response.status_code == expected_status
312-
if is_authenticated:
313-
assert response.json() == {"id": "bar", "properties": {"private": True}}
317+
assert response.json() == expected_body

0 commit comments

Comments
 (0)