Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,21 @@ If enabled, filters are intended to be applied to the following endpoints:
- **Action:** Read Item
- **Applied Filter:** `ITEMS_FILTER`
- **Strategy:** Append body with generated CQL2 query.
- `GET /collections/{collection_id}`
- **Supported:** ❌[^23]
- **Action:** Read Collection
- **Applied Filter:** `COLLECTIONS_FILTER`
- **Strategy:** Append query params with generated CQL2 query.
- `GET /collections/{collection_id}/items`
- **Supported:** ✅
- **Action:** Read Item
- **Applied Filter:** `ITEMS_FILTER`
- **Strategy:** Append query params with generated CQL2 query.
- `GET /collections/{collection_id}/items/{item_id}`
- **Supported:** ❌[^25]
- **Supported:**
- **Action:** Read Item
- **Applied Filter:** `ITEMS_FILTER`
- **Strategy:** Validate response against CQL2 query.
- `GET /collections/{collection_id}`
- **Supported:** ❌[^23]
- **Action:** Read Collection
- **Applied Filter:** `COLLECTIONS_FILTER`
- **Strategy:** Append query params with generated CQL2 query.
- `POST /collections/`
- **Supported:** ❌[^22]
- **Action:** Create Collection
Expand Down Expand Up @@ -257,6 +257,5 @@ sequenceDiagram
[^21]: https://github.com/developmentseed/stac-auth-proxy/issues/21
[^22]: https://github.com/developmentseed/stac-auth-proxy/issues/22
[^23]: https://github.com/developmentseed/stac-auth-proxy/issues/23
[^25]: https://github.com/developmentseed/stac-auth-proxy/issues/25
[^30]: https://github.com/developmentseed/stac-auth-proxy/issues/30
[^37]: https://github.com/developmentseed/stac-auth-proxy/issues/37
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ classifiers = [
dependencies = [
"authlib>=1.3.2",
"brotli>=1.1.0",
"cql2>=0.3.5",
"cql2>=0.3.6",
"fastapi>=0.115.5",
"httpx[http2]>=0.28.0",
"jinja2>=3.1.4",
Expand Down
90 changes: 87 additions & 3 deletions src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Middleware to apply CQL2 filters."""

import json
import re
from dataclasses import dataclass
from logging import getLogger
from typing import Optional

from cql2 import Expr
from starlette.datastructures import MutableHeaders
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send

Expand All @@ -28,12 +32,88 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope)

if request.method == "GET":
cql2_filter = getattr(request.state, self.state_key, None)
cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
if cql2_filter:
scope["query_string"] = filters.append_qs_filter(
request.url.query, cql2_filter
)
return await self.app(scope, receive, send)

initial_message = None
body = b""

async def validate_response(message: Message) -> None:
nonlocal initial_message
nonlocal body
headers = MutableHeaders(scope=initial_message)
if message["type"] == "http.response.start":
initial_message = message
return

if message["type"] == "http.response.body":
assert initial_message, "Initial message not set"
assert cql2_filter, "Cql2Filter not set"

body += message["body"]
if message.get("more_body"):
return

try:
body = json.loads(body)
except json.JSONDecodeError:
logger.warning("Failed to parse response body as JSON")
not_found_body = json.dumps({"message": "Not found"}).encode(
"utf-8"
)
headers["content-length"] = str(len(not_found_body))
initial_message["status"] = 502
await send(initial_message)
await send(
{
"type": "http.response.body",
"body": not_found_body,
"more_body": False,
}
)
return

logger.debug(
"Applying %s filter to %s", cql2_filter.to_text(), body
)
if cql2_filter.matches(body):
await send(initial_message)
await send(
{
"type": "http.response.body",
"body": json.dumps(body).encode("utf-8"),
"more_body": False,
}
)
else:
not_found_body = json.dumps({"message": "Not found"}).encode(
"utf-8"
)
headers["content-length"] = str(len(not_found_body))
initial_message["status"] = 404
await send(initial_message)
await send(
{
"type": "http.response.body",
"body": not_found_body,
"more_body": False,
}
)

return message

should_validate_response = cql2_filter and re.match(
r"^/collections/([^/]+)/items/([^/]+)$", request.url.path
)

return await self.app(
scope,
receive,
validate_response if should_validate_response else send,
)

elif request.method in ["POST", "PUT", "PATCH"]:

Expand All @@ -55,6 +135,10 @@ async def receive_and_apply_filter() -> Message:
message["body"] = json.dumps(new_body).encode("utf-8")
return message

return await self.app(scope, receive_and_apply_filter, send)
return await self.app(
scope,
receive_and_apply_filter,
send,
)

return await self.app(scope, receive, send)
12 changes: 6 additions & 6 deletions src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Middleware to build the Cql2Filter."""

import json
import re
from dataclasses import dataclass
from typing import Callable, Optional

from cql2 import Expr
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from ..utils import filters, requests
from ..utils import requests


@dataclass(frozen=True)
Expand Down Expand Up @@ -78,11 +79,10 @@ async def receive_build_filter() -> Message:
def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
"""Get the CQL2 filter builder for the given path."""
endpoint_filters = [
(filters.is_collection_endpoint, self.collections_filter),
(filters.is_item_endpoint, self.items_filter),
(filters.is_search_endpoint, self.items_filter),
(r"^/collections(/[^/]+)?$", self.collections_filter),
(r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)", self.items_filter),
]
for check, builder in endpoint_filters:
if check(path):
for expr, builder in endpoint_filters:
if re.match(expr, path):
return builder
return None
18 changes: 0 additions & 18 deletions src/stac_auth_proxy/utils/filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Utility functions."""

import json
import re
from typing import Optional
from urllib.parse import parse_qs

Expand Down Expand Up @@ -32,23 +31,6 @@ def append_body_filter(
}


def is_collection_endpoint(path: str) -> bool:
"""Check if the path is a collection endpoint."""
# TODO: Expand this to cover all cases where a collection filter should be applied
return path == "/collections"


def is_item_endpoint(path: str) -> bool:
"""Check if the path is an item endpoint."""
# TODO: Expand this to cover all cases where an item filter should be applied
return bool(re.compile(r"^(/collections/([^/]+)/items$|/search)").match(path))


def is_search_endpoint(path: str) -> bool:
"""Check if the path is a search endpoint."""
return path == "/search"


def dict_to_query_string(params: dict) -> str:
"""
Convert a dictionary to a query string. Dict values are converted to JSON strings,
Expand Down
127 changes: 86 additions & 41 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import socket
import threading
from functools import partial
from typing import Any, AsyncGenerator
from unittest.mock import DEFAULT, AsyncMock, MagicMock, patch

Expand Down Expand Up @@ -65,60 +66,104 @@ def build_token(payload: dict[str, Any], key=None) -> str:

@pytest.fixture(scope="session")
def source_api():
"""Create upstream API for testing purposes."""
"""
Create upstream API for testing purposes.

You can customize the response for each endpoint by passing a dict of responses:
{
"path": {
"method": response_body
}
}
"""
app = FastAPI(docs_url="/api.html", openapi_url="/api")

app.add_middleware(CompressionMiddleware, minimum_size=0, compression_level=1)

for path, methods in {
"/": [
"GET",
],
"/conformance": [
"GET",
],
"/queryables": [
"GET",
],
"/search": [
"GET",
"POST",
],
"/collections": [
"GET",
"POST",
],
"/collections/{collection_id}": [
"GET",
"PUT",
"PATCH",
"DELETE",
],
"/collections/{collection_id}/items": [
"GET",
"POST",
],
"/collections/{collection_id}/items/{item_id}": [
"GET",
"PUT",
"PATCH",
"DELETE",
],
"/collections/{collection_id}/bulk_items": [
"POST",
],
}.items():
# Default responses for each endpoint
default_responses = {
"/": {"GET": {"id": "Response from GET@"}},
"/conformance": {"GET": {"conformsTo": ["http://example.com/conformance"]}},
"/queryables": {"GET": {"queryables": {}}},
"/search": {
"GET": {"type": "FeatureCollection", "features": []},
"POST": {"type": "FeatureCollection", "features": []},
},
"/collections": {
"GET": {"collections": []},
"POST": {"id": "Response from POST@"},
},
"/collections/{collection_id}": {
"GET": {"id": "Response from GET@"},
"PUT": {"id": "Response from PUT@"},
"PATCH": {"id": "Response from PATCH@"},
"DELETE": {"id": "Response from DELETE@"},
},
"/collections/{collection_id}/items": {
"GET": {"type": "FeatureCollection", "features": []},
"POST": {"id": "Response from POST@"},
},
"/collections/{collection_id}/items/{item_id}": {
"GET": {"id": "Response from GET@"},
"PUT": {"id": "Response from PUT@"},
"PATCH": {"id": "Response from PATCH@"},
"DELETE": {"id": "Response from DELETE@"},
},
"/collections/{collection_id}/bulk_items": {
"POST": {"id": "Response from POST@"},
},
}

# Store responses in app state
app.state.default_responses = default_responses

def get_response(path: str, method: str) -> dict:
"""Get response for a given path and method."""
return app.state.default_responses.get(path, {}).get(
method, {"id": f"Response from {method}@{path}"}
)

for path, methods in default_responses.items():
for method in methods:
# NOTE: declare routes per method separately to avoid warning of "Duplicate Operation ID ... for function <lambda>"
app.add_api_route(
path,
lambda: {"id": f"Response from {method}@{path}"},
partial(get_response, path, method),
methods=[method],
)

return app


@pytest.fixture
def source_api_responses(source_api):
"""
Fixture to override source API responses for specific tests.

Usage:
def test_something(source_api_responses):
# Override responses for specific endpoints
source_api_responses["/collections"]["GET"] = {"collections": [{"id": "test"}]}
source_api_responses["/search"]["POST"] = {"type": "FeatureCollection", "features": [{"id": "test"}]}

# Your test code here
"""
# Get the default responses from the source_api fixture
default_responses = source_api.state.default_responses

# Create a new dict that can be modified by tests
responses = {}
for path, methods in default_responses.items():
responses[path] = methods.copy()

# Store the responses in the app state for the get_response function to use
source_api.state.default_responses = responses

yield responses

# Restore the original responses after the test
source_api.state.default_responses = default_responses


@pytest.fixture(scope="session")
def free_port():
"""Get a free port."""
Expand Down
Loading