Skip to content

Commit b6363c3

Browse files
committed
feat: enhance CQL2 filter support for collections
- Added support for a collections filter in the configuration and middleware. - Updated README to clarify content filtering based on request context. - Refactored middleware to handle both items and collections filters. - Improved error handling in filter application. - Updated tests to include scenarios for collections filtering.
1 parent 1ce8ed5 commit b6363c3

File tree

7 files changed

+182
-25
lines changed

7 files changed

+182
-25
lines changed

README.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ STAC Auth Proxy is a proxy API that mediates between the client and your interna
1010
## ✨Features✨
1111

1212
- **🔐 Authentication:** Apply [OpenID Connect (OIDC)](https://openid.net/developers/how-connect-works/) token validation and optional scope checks to specified endpoints and methods
13-
- **🛂 Content Filtering:** Use CQL2 filters via the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to tailor API responses based on user context
13+
- **🛂 Content Filtering:** Use CQL2 filters via the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to tailor API responses based on request context (e.g. user role)
1414
- **🤝 External Policy Integration:** Integrate with external systems (e.g. [Open Policy Agent (OPA)](https://www.openpolicyagent.org/)) to generate CQL2 filters dynamically from policy decisions
1515
- **🧩 Authentication Extension:** Add the [Authentication Extension](https://github.com/stac-extensions/authentication) to API responses to expose auth-related metadata
1616
- **📘 OpenAPI Augmentation:** Enhance the [OpenAPI spec](https://swagger.io/specification/) with security details to keep auto-generated docs and UIs (e.g., [Swagger UI](https://swagger.io/tools/swagger-ui/)) accurate
@@ -227,7 +227,7 @@ The system supports generating CQL2 filters based on request context to provide
227227
228228
#### Filters
229229

230-
If enabled, filters are intended to be applied to the following endpoints:
230+
If enabled, filters are applied to the following endpoints:
231231

232232
- `GET /search`
233233
- **Supported:**
@@ -250,12 +250,12 @@ If enabled, filters are intended to be applied to the following endpoints:
250250
- **Applied Filter:** `ITEMS_FILTER`
251251
- **Strategy:** Validate response against CQL2 query.
252252
- `GET /collections`
253-
- **Supported:** [^23]
253+
- **Supported:**
254254
- **Action:** Read Collection
255255
- **Applied Filter:** `COLLECTIONS_FILTER`
256256
- **Strategy:** Append query params with generated CQL2 query.
257257
- `GET /collections/{collection_id}`
258-
- **Supported:** [^23]
258+
- **Supported:**
259259
- **Action:** Read Collection
260260
- **Applied Filter:** `COLLECTIONS_FILTER`
261261
- **Strategy:** Validate response against CQL2 query.
@@ -411,6 +411,5 @@ class ApprovedCollectionsFilter:
411411
412412
[^21]: https://github.com/developmentseed/stac-auth-proxy/issues/21
413413
[^22]: https://github.com/developmentseed/stac-auth-proxy/issues/22
414-
[^23]: https://github.com/developmentseed/stac-auth-proxy/issues/23
415414
[^30]: https://github.com/developmentseed/stac-auth-proxy/issues/30
416415
[^37]: https://github.com/developmentseed/stac-auth-proxy/issues/37

src/stac_auth_proxy/app.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,16 @@ async def lifespan(app: FastAPI):
119119
auth_scheme_override=settings.openapi_auth_scheme_override,
120120
)
121121

122-
if settings.items_filter:
122+
if settings.items_filter or settings.collections_filter:
123123
app.add_middleware(
124124
ApplyCql2FilterMiddleware,
125125
)
126126
app.add_middleware(
127127
BuildCql2FilterMiddleware,
128-
items_filter=settings.items_filter(),
128+
items_filter=settings.items_filter() if settings.items_filter else None,
129+
collections_filter=(
130+
settings.collections_filter() if settings.collections_filter else None
131+
),
129132
)
130133

131134
app.add_middleware(

src/stac_auth_proxy/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class Settings(BaseSettings):
7171

7272
# Filters
7373
items_filter: Optional[ClassInput] = None
74+
collections_filter: Optional[ClassInput] = None
7475

7576
model_config = SettingsConfigDict(
7677
env_nested_delimiter="_",

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5151
)
5252
return await req_body_handler(scope, receive, send)
5353

54-
if re.match(r"^/collections/([^/]+)/items/([^/]+)$", request.url.path):
54+
# Handle single record requests (ie non-filterable endpoints)
55+
single_record_endpoints = [
56+
r"^/collections/([^/]+)/items/([^/]+)$",
57+
r"^/collections/([^/]+)$",
58+
]
59+
if any(re.match(expr, request.url.path) for expr in single_record_endpoints):
5560
res_body_validator = Cql2ResponseBodyValidator(
5661
app=self.app,
5762
cql2_filter=cql2_filter,
@@ -166,15 +171,19 @@ async def buffered_send(message: Message) -> None:
166171
logger.debug(
167172
"Applying %s filter to %s", self.cql2_filter.to_text(), body_json
168173
)
169-
if self.cql2_filter.matches(body_json):
170-
await send(initial_message)
171-
return await send(
172-
{
173-
"type": "http.response.body",
174-
"body": json.dumps(body_json).encode("utf-8"),
175-
"more_body": False,
176-
}
177-
)
174+
try:
175+
if self.cql2_filter.matches(body_json):
176+
await send(initial_message)
177+
return await send(
178+
{
179+
"type": "http.response.body",
180+
"body": json.dumps(body_json).encode("utf-8"),
181+
"more_body": False,
182+
}
183+
)
184+
except Exception as e:
185+
logger.warning("Failed to apply filter: %s", e)
186+
178187
return await _send_error_response(404, "Not found")
179188

180189
return await self.app(scope, receive, buffered_send)

src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ class BuildCql2FilterMiddleware:
2525

2626
# Filters
2727
collections_filter: Optional[Callable] = None
28+
collections_filter_path: str = r"^/collections(/[^/]+)?$"
2829
items_filter: Optional[Callable] = None
30+
items_filter_path: str = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"
2931

3032
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3133
"""Build the CQL2 filter, place on the request state."""
@@ -65,8 +67,8 @@ def _get_filter(
6567
) -> Optional[Callable[..., Awaitable[str | dict[str, Any]]]]:
6668
"""Get the CQL2 filter builder for the given path."""
6769
endpoint_filters = [
68-
(r"^/collections(/[^/]+)?$", self.collections_filter),
69-
(r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)", self.items_filter),
70+
(self.collections_filter_path, self.collections_filter),
71+
(self.items_filter_path, self.items_filter),
7072
]
7173
for expr, builder in endpoint_filters:
7274
if re.match(expr, path):

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def mock_env():
207207
@pytest.fixture
208208
async def mock_upstream() -> AsyncGenerator[MagicMock, None]:
209209
"""Mock the HTTPX send method. Useful when we want to inspect the request is sent to upstream API."""
210+
# NOTE: This fixture will interfere with the source_api_responses fixture
210211

211212
async def store_body(request, **kwargs):
212213
"""Exhaust and store the request body."""

tests/test_filters_jinja2.py

Lines changed: 148 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
)
122122

123123

124-
def _build_client(
124+
def _build_items_filter_client(
125125
*,
126126
src_api_server: str,
127127
template_expr: str,
@@ -162,7 +162,7 @@ async def test_search_post(
162162
token_builder,
163163
):
164164
"""Test that POST /search merges the upstream query with the templated filter."""
165-
response = _build_client(
165+
response = _build_items_filter_client(
166166
src_api_server=source_api_server,
167167
template_expr=filter_template_expr,
168168
is_authenticated=is_authenticated,
@@ -210,7 +210,7 @@ async def test_search_get(
210210
token_builder,
211211
):
212212
"""Test that GET /search merges the upstream query params with the templated filter."""
213-
client = _build_client(
213+
client = _build_items_filter_client(
214214
src_api_server=source_api_server,
215215
template_expr=filter_template_expr,
216216
is_authenticated=is_authenticated,
@@ -263,7 +263,7 @@ async def test_items_list(
263263
token_builder,
264264
):
265265
"""Test that GET /collections/foo/items merges query params with the templated filter."""
266-
client = _build_client(
266+
client = _build_items_filter_client(
267267
src_api_server=source_api_server,
268268
template_expr=filter_template_expr,
269269
is_authenticated=is_authenticated,
@@ -296,7 +296,7 @@ def test_item_get(
296296
source_api_server, is_authenticated, token_builder, source_api_responses
297297
):
298298
"""Test that GET /collections/foo/items/bar is rejected."""
299-
client = _build_client(
299+
client = _build_items_filter_client(
300300
src_api_server=source_api_server,
301301
template_expr="{{ '(properties.private = false)' if payload is none else true }}",
302302
is_authenticated=is_authenticated,
@@ -323,7 +323,7 @@ async def test_search_post_empty_body(
323323
token_builder,
324324
):
325325
"""Test that POST /search with empty body."""
326-
client = _build_client(
326+
client = _build_items_filter_client(
327327
src_api_server=source_api_server,
328328
template_expr="(properties.private = false)",
329329
is_authenticated=is_authenticated,
@@ -337,3 +337,145 @@ async def test_search_post_empty_body(
337337
)
338338

339339
assert response.status_code == 200
340+
341+
342+
COLLECTIONS_FILTER_CASES = [
343+
pytest.param(
344+
"(properties.private = false)",
345+
"(properties.private = false)",
346+
"(properties.private = false)",
347+
id="simple_collections_filter",
348+
),
349+
pytest.param(
350+
"{{ '(properties.private = false)' if payload is none else true }}",
351+
"true",
352+
"(properties.private = false)",
353+
id="templated_collections_filter",
354+
),
355+
]
356+
357+
COLLECTIONS_QUERIES = [
358+
pytest.param(
359+
{},
360+
id="collections_no_filter",
361+
),
362+
pytest.param(
363+
{
364+
"filter-lang": "cql2-text",
365+
"filter": "(properties.private = true)",
366+
},
367+
id="collections_with_filter",
368+
),
369+
]
370+
371+
372+
def _build_collections_filter_client(
373+
*,
374+
src_api_server: str,
375+
template_expr: str,
376+
is_authenticated: bool,
377+
token_builder,
378+
):
379+
"""Build a TestClient configured for either authenticated or anonymous usage."""
380+
app = app_factory(
381+
upstream_url=src_api_server,
382+
collections_filter={
383+
"cls": "stac_auth_proxy.filters:Template",
384+
"args": [template_expr.strip()],
385+
},
386+
default_public=True,
387+
)
388+
headers = (
389+
{"Authorization": f"Bearer {token_builder({'sub': 'test-user'})}"}
390+
if is_authenticated
391+
else {}
392+
)
393+
return TestClient(app, headers=headers)
394+
395+
396+
@pytest.mark.parametrize(
397+
"filter_template_expr, expected_auth_filter, expected_anon_filter",
398+
COLLECTIONS_FILTER_CASES,
399+
)
400+
@pytest.mark.parametrize("is_authenticated", [True, False], ids=["auth", "anon"])
401+
@pytest.mark.parametrize("input_query", COLLECTIONS_QUERIES)
402+
async def test_collections_list(
403+
mock_upstream,
404+
source_api_server,
405+
filter_template_expr,
406+
expected_auth_filter,
407+
expected_anon_filter,
408+
is_authenticated,
409+
input_query,
410+
token_builder,
411+
):
412+
"""Test that GET /collections merges query params with the templated filter."""
413+
client = _build_collections_filter_client(
414+
src_api_server=source_api_server,
415+
template_expr=filter_template_expr,
416+
is_authenticated=is_authenticated,
417+
token_builder=token_builder,
418+
)
419+
response = client.get("/collections", params=input_query)
420+
response.raise_for_status()
421+
422+
# For GET collections, we expect an empty body and appended querystring
423+
proxied_request = await get_upstream_request(mock_upstream)
424+
assert proxied_request.body == ""
425+
426+
# Determine the expected combined filter
427+
proxy_filter = cql2.Expr(
428+
expected_auth_filter if is_authenticated else expected_anon_filter
429+
)
430+
input_filter = input_query.get("filter")
431+
if input_filter:
432+
proxy_filter += cql2.Expr(input_filter)
433+
434+
filter_lang = input_query.get("filter-lang", "cql2-text")
435+
expected_output = {
436+
**input_query,
437+
"filter": (
438+
proxy_filter.to_text()
439+
if filter_lang == "cql2-text"
440+
else proxy_filter.to_json()
441+
),
442+
"filter-lang": filter_lang,
443+
}
444+
assert (
445+
proxied_request.query_params == expected_output
446+
), "Collections query should combine filter expressions."
447+
448+
449+
@pytest.mark.parametrize(
450+
"filter_template_expr, expected_auth_filter, expected_anon_filter",
451+
COLLECTIONS_FILTER_CASES,
452+
)
453+
@pytest.mark.parametrize("is_authenticated", [True, False], ids=["auth", "anon"])
454+
async def test_collection_get(
455+
source_api_server,
456+
filter_template_expr,
457+
expected_auth_filter,
458+
expected_anon_filter,
459+
is_authenticated,
460+
token_builder,
461+
source_api_responses,
462+
):
463+
"""Test that GET /collections/{collection_id} applies the templated filter."""
464+
client = _build_collections_filter_client(
465+
src_api_server=source_api_server,
466+
template_expr=filter_template_expr,
467+
is_authenticated=is_authenticated,
468+
token_builder=token_builder,
469+
)
470+
response_body = {
471+
"id": "foo",
472+
"properties": {"private": True},
473+
}
474+
source_api_responses["/collections/{collection_id}"]["GET"] = response_body
475+
response = client.get("/collections/foo")
476+
477+
# Expected applied filter
478+
proxy_filter = cql2.Expr(
479+
expected_auth_filter if is_authenticated else expected_anon_filter
480+
)
481+
assert response.status_code == 200 if proxy_filter.matches(response_body) else 404

0 commit comments

Comments
 (0)