Skip to content

Commit ac23e26

Browse files
authored
Buildout filter for item read (#45)
1 parent 9d8599f commit ac23e26

File tree

8 files changed

+317
-171
lines changed

8 files changed

+317
-171
lines changed

README.md

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,21 +191,21 @@ If enabled, filters are intended to be applied to the following endpoints:
191191
- **Action:** Read Item
192192
- **Applied Filter:** `ITEMS_FILTER`
193193
- **Strategy:** Append body with generated CQL2 query.
194-
- `GET /collections/{collection_id}`
195-
- **Supported:** ❌[^23]
196-
- **Action:** Read Collection
197-
- **Applied Filter:** `COLLECTIONS_FILTER`
198-
- **Strategy:** Append query params with generated CQL2 query.
199194
- `GET /collections/{collection_id}/items`
200195
- **Supported:** ✅
201196
- **Action:** Read Item
202197
- **Applied Filter:** `ITEMS_FILTER`
203198
- **Strategy:** Append query params with generated CQL2 query.
204199
- `GET /collections/{collection_id}/items/{item_id}`
205-
- **Supported:** ❌[^25]
200+
- **Supported:**
206201
- **Action:** Read Item
207202
- **Applied Filter:** `ITEMS_FILTER`
208203
- **Strategy:** Validate response against CQL2 query.
204+
- `GET /collections/{collection_id}`
205+
- **Supported:** ❌[^23]
206+
- **Action:** Read Collection
207+
- **Applied Filter:** `COLLECTIONS_FILTER`
208+
- **Strategy:** Append query params with generated CQL2 query.
209209
- `POST /collections/`
210210
- **Supported:** ❌[^22]
211211
- **Action:** Create Collection
@@ -257,6 +257,5 @@ sequenceDiagram
257257
[^21]: https://github.com/developmentseed/stac-auth-proxy/issues/21
258258
[^22]: https://github.com/developmentseed/stac-auth-proxy/issues/22
259259
[^23]: https://github.com/developmentseed/stac-auth-proxy/issues/23
260-
[^25]: https://github.com/developmentseed/stac-auth-proxy/issues/25
261260
[^30]: https://github.com/developmentseed/stac-auth-proxy/issues/30
262261
[^37]: https://github.com/developmentseed/stac-auth-proxy/issues/37

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ classifiers = [
88
dependencies = [
99
"authlib>=1.3.2",
1010
"brotli>=1.1.0",
11-
"cql2>=0.3.5",
11+
"cql2>=0.3.6",
1212
"fastapi>=0.115.5",
1313
"httpx[http2]>=0.28.0",
1414
"jinja2>=3.1.4",

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
"""Middleware to apply CQL2 filters."""
22

33
import json
4+
import re
45
from dataclasses import dataclass
56
from logging import getLogger
7+
from typing import Optional
68

9+
from cql2 import Expr
10+
from starlette.datastructures import MutableHeaders
711
from starlette.requests import Request
812
from starlette.types import ASGIApp, Message, Receive, Scope, Send
913

@@ -28,12 +32,88 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2832
request = Request(scope)
2933

3034
if request.method == "GET":
31-
cql2_filter = getattr(request.state, self.state_key, None)
35+
cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
3236
if cql2_filter:
3337
scope["query_string"] = filters.append_qs_filter(
3438
request.url.query, cql2_filter
3539
)
36-
return await self.app(scope, receive, send)
40+
41+
initial_message = None
42+
body = b""
43+
44+
async def validate_response(message: Message) -> None:
45+
nonlocal initial_message
46+
nonlocal body
47+
headers = MutableHeaders(scope=initial_message)
48+
if message["type"] == "http.response.start":
49+
initial_message = message
50+
return
51+
52+
if message["type"] == "http.response.body":
53+
assert initial_message, "Initial message not set"
54+
assert cql2_filter, "Cql2Filter not set"
55+
56+
body += message["body"]
57+
if message.get("more_body"):
58+
return
59+
60+
try:
61+
body = json.loads(body)
62+
except json.JSONDecodeError:
63+
logger.warning("Failed to parse response body as JSON")
64+
not_found_body = json.dumps({"message": "Not found"}).encode(
65+
"utf-8"
66+
)
67+
headers["content-length"] = str(len(not_found_body))
68+
initial_message["status"] = 502
69+
await send(initial_message)
70+
await send(
71+
{
72+
"type": "http.response.body",
73+
"body": not_found_body,
74+
"more_body": False,
75+
}
76+
)
77+
return
78+
79+
logger.debug(
80+
"Applying %s filter to %s", cql2_filter.to_text(), body
81+
)
82+
if cql2_filter.matches(body):
83+
await send(initial_message)
84+
await send(
85+
{
86+
"type": "http.response.body",
87+
"body": json.dumps(body).encode("utf-8"),
88+
"more_body": False,
89+
}
90+
)
91+
else:
92+
not_found_body = json.dumps({"message": "Not found"}).encode(
93+
"utf-8"
94+
)
95+
headers["content-length"] = str(len(not_found_body))
96+
initial_message["status"] = 404
97+
await send(initial_message)
98+
await send(
99+
{
100+
"type": "http.response.body",
101+
"body": not_found_body,
102+
"more_body": False,
103+
}
104+
)
105+
106+
return message
107+
108+
should_validate_response = cql2_filter and re.match(
109+
r"^/collections/([^/]+)/items/([^/]+)$", request.url.path
110+
)
111+
112+
return await self.app(
113+
scope,
114+
receive,
115+
validate_response if should_validate_response else send,
116+
)
37117

38118
elif request.method in ["POST", "PUT", "PATCH"]:
39119

@@ -55,6 +135,10 @@ async def receive_and_apply_filter() -> Message:
55135
message["body"] = json.dumps(new_body).encode("utf-8")
56136
return message
57137

58-
return await self.app(scope, receive_and_apply_filter, send)
138+
return await self.app(
139+
scope,
140+
receive_and_apply_filter,
141+
send,
142+
)
59143

60144
return await self.app(scope, receive, send)

src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Middleware to build the Cql2Filter."""
22

33
import json
4+
import re
45
from dataclasses import dataclass
56
from typing import Callable, Optional
67

78
from cql2 import Expr
89
from starlette.requests import Request
910
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1011

11-
from ..utils import filters, requests
12+
from ..utils import requests
1213

1314

1415
@dataclass(frozen=True)
@@ -78,11 +79,10 @@ async def receive_build_filter() -> Message:
7879
def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
7980
"""Get the CQL2 filter builder for the given path."""
8081
endpoint_filters = [
81-
(filters.is_collection_endpoint, self.collections_filter),
82-
(filters.is_item_endpoint, self.items_filter),
83-
(filters.is_search_endpoint, self.items_filter),
82+
(r"^/collections(/[^/]+)?$", self.collections_filter),
83+
(r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)", self.items_filter),
8484
]
85-
for check, builder in endpoint_filters:
86-
if check(path):
85+
for expr, builder in endpoint_filters:
86+
if re.match(expr, path):
8787
return builder
8888
return None

src/stac_auth_proxy/utils/filters.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Utility functions."""
22

33
import json
4-
import re
54
from typing import Optional
65
from urllib.parse import parse_qs
76

@@ -32,23 +31,6 @@ def append_body_filter(
3231
}
3332

3433

35-
def is_collection_endpoint(path: str) -> bool:
36-
"""Check if the path is a collection endpoint."""
37-
# TODO: Expand this to cover all cases where a collection filter should be applied
38-
return path == "/collections"
39-
40-
41-
def is_item_endpoint(path: str) -> bool:
42-
"""Check if the path is an item endpoint."""
43-
# TODO: Expand this to cover all cases where an item filter should be applied
44-
return bool(re.compile(r"^(/collections/([^/]+)/items$|/search)").match(path))
45-
46-
47-
def is_search_endpoint(path: str) -> bool:
48-
"""Check if the path is a search endpoint."""
49-
return path == "/search"
50-
51-
5234
def dict_to_query_string(params: dict) -> str:
5335
"""
5436
Convert a dictionary to a query string. Dict values are converted to JSON strings,

tests/conftest.py

Lines changed: 86 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import socket
55
import threading
6+
from functools import partial
67
from typing import Any, AsyncGenerator
78
from unittest.mock import DEFAULT, AsyncMock, MagicMock, patch
89

@@ -65,60 +66,104 @@ def build_token(payload: dict[str, Any], key=None) -> str:
6566

6667
@pytest.fixture(scope="session")
6768
def source_api():
68-
"""Create upstream API for testing purposes."""
69+
"""
70+
Create upstream API for testing purposes.
71+
72+
You can customize the response for each endpoint by passing a dict of responses:
73+
{
74+
"path": {
75+
"method": response_body
76+
}
77+
}
78+
"""
6979
app = FastAPI(docs_url="/api.html", openapi_url="/api")
7080

7181
app.add_middleware(CompressionMiddleware, minimum_size=0, compression_level=1)
7282

73-
for path, methods in {
74-
"/": [
75-
"GET",
76-
],
77-
"/conformance": [
78-
"GET",
79-
],
80-
"/queryables": [
81-
"GET",
82-
],
83-
"/search": [
84-
"GET",
85-
"POST",
86-
],
87-
"/collections": [
88-
"GET",
89-
"POST",
90-
],
91-
"/collections/{collection_id}": [
92-
"GET",
93-
"PUT",
94-
"PATCH",
95-
"DELETE",
96-
],
97-
"/collections/{collection_id}/items": [
98-
"GET",
99-
"POST",
100-
],
101-
"/collections/{collection_id}/items/{item_id}": [
102-
"GET",
103-
"PUT",
104-
"PATCH",
105-
"DELETE",
106-
],
107-
"/collections/{collection_id}/bulk_items": [
108-
"POST",
109-
],
110-
}.items():
83+
# Default responses for each endpoint
84+
default_responses = {
85+
"/": {"GET": {"id": "Response from GET@"}},
86+
"/conformance": {"GET": {"conformsTo": ["http://example.com/conformance"]}},
87+
"/queryables": {"GET": {"queryables": {}}},
88+
"/search": {
89+
"GET": {"type": "FeatureCollection", "features": []},
90+
"POST": {"type": "FeatureCollection", "features": []},
91+
},
92+
"/collections": {
93+
"GET": {"collections": []},
94+
"POST": {"id": "Response from POST@"},
95+
},
96+
"/collections/{collection_id}": {
97+
"GET": {"id": "Response from GET@"},
98+
"PUT": {"id": "Response from PUT@"},
99+
"PATCH": {"id": "Response from PATCH@"},
100+
"DELETE": {"id": "Response from DELETE@"},
101+
},
102+
"/collections/{collection_id}/items": {
103+
"GET": {"type": "FeatureCollection", "features": []},
104+
"POST": {"id": "Response from POST@"},
105+
},
106+
"/collections/{collection_id}/items/{item_id}": {
107+
"GET": {"id": "Response from GET@"},
108+
"PUT": {"id": "Response from PUT@"},
109+
"PATCH": {"id": "Response from PATCH@"},
110+
"DELETE": {"id": "Response from DELETE@"},
111+
},
112+
"/collections/{collection_id}/bulk_items": {
113+
"POST": {"id": "Response from POST@"},
114+
},
115+
}
116+
117+
# Store responses in app state
118+
app.state.default_responses = default_responses
119+
120+
def get_response(path: str, method: str) -> dict:
121+
"""Get response for a given path and method."""
122+
return app.state.default_responses.get(path, {}).get(
123+
method, {"id": f"Response from {method}@{path}"}
124+
)
125+
126+
for path, methods in default_responses.items():
111127
for method in methods:
112-
# NOTE: declare routes per method separately to avoid warning of "Duplicate Operation ID ... for function <lambda>"
113128
app.add_api_route(
114129
path,
115-
lambda: {"id": f"Response from {method}@{path}"},
130+
partial(get_response, path, method),
116131
methods=[method],
117132
)
118133

119134
return app
120135

121136

137+
@pytest.fixture
138+
def source_api_responses(source_api):
139+
"""
140+
Fixture to override source API responses for specific tests.
141+
142+
Usage:
143+
def test_something(source_api_responses):
144+
# Override responses for specific endpoints
145+
source_api_responses["/collections"]["GET"] = {"collections": [{"id": "test"}]}
146+
source_api_responses["/search"]["POST"] = {"type": "FeatureCollection", "features": [{"id": "test"}]}
147+
148+
# Your test code here
149+
"""
150+
# Get the default responses from the source_api fixture
151+
default_responses = source_api.state.default_responses
152+
153+
# Create a new dict that can be modified by tests
154+
responses = {}
155+
for path, methods in default_responses.items():
156+
responses[path] = methods.copy()
157+
158+
# Store the responses in the app state for the get_response function to use
159+
source_api.state.default_responses = responses
160+
161+
yield responses
162+
163+
# Restore the original responses after the test
164+
source_api.state.default_responses = default_responses
165+
166+
122167
@pytest.fixture(scope="session")
123168
def free_port():
124169
"""Get a free port."""

0 commit comments

Comments
 (0)