Skip to content

Commit 7cee884

Browse files
committed
In progress
1 parent 2ab85c3 commit 7cee884

File tree

5 files changed

+74
-66
lines changed

5 files changed

+74
-66
lines changed

src/stac_auth_proxy/app.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
OpenApiMiddleware,
1515
AddProcessTimeHeaderMiddleware,
1616
EnforceAuthMiddleware,
17+
BuildCql2FilterMiddleware,
18+
ApplyCql2FilterMiddleware,
1719
)
1820

1921
from .config import Settings
@@ -48,6 +50,13 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
4850
default_public=settings.default_public,
4951
)
5052

53+
app.add_middleware(ApplyCql2FilterMiddleware)
54+
app.add_middleware(
55+
BuildCql2FilterMiddleware,
56+
# collections_filter=settings.collections_filter,
57+
items_filter=settings.items_filter(),
58+
)
59+
5160
if settings.debug:
5261
app.add_api_route(
5362
"/_debug",

src/stac_auth_proxy/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ class Settings(BaseSettings):
5555
}
5656
items_filter: Optional[ClassInput] = None
5757
items_filter_endpoints: Optional[EndpointMethods] = {
58-
"/search": ["POST"],
59-
"/collections/{collection_id}/items": ["GET", "POST"],
58+
r"^/search$": ["POST"],
59+
r"^/collections/([^/]+)/items$": ["GET", "POST"],
6060
}
6161

6262
model_config = SettingsConfigDict(env_prefix="STAC_AUTH_PROXY_")

src/stac_auth_proxy/filters/template.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
"""Generate CQL2 filter expressions via Jinja2 templating."""
22

3+
from typing import Any
34
from dataclasses import dataclass, field
45

56
from cql2 import Expr
67
from fastapi import Request
78
from jinja2 import BaseLoader, Environment
89

9-
from ..utils.requests import extract_variables
10-
1110

1211
@dataclass
1312
class Template:
@@ -20,24 +19,9 @@ def __post_init__(self):
2019
"""Initialize the Jinja2 environment."""
2120
self.env = Environment(loader=BaseLoader).from_string(self.template_str)
2221

23-
async def __call__(self, request: Request) -> Expr:
22+
async def __call__(self, context: dict[str, Any]) -> Expr:
2423
"""Render a CQL2 filter expression with the request and auth token."""
2524
# TODO: How to handle the case where auth_token is null?
26-
context = {
27-
"req": {
28-
"path": request.url.path,
29-
"method": request.method,
30-
"query_params": dict(request.query_params),
31-
"path_params": extract_variables(request.url.path),
32-
"headers": dict(request.headers),
33-
"body": (
34-
await request.json()
35-
if request.headers.get("content-type") == "application/json"
36-
else (await request.body()).decode()
37-
),
38-
},
39-
"token": request.state.user,
40-
}
4125
cql2_str = self.env.render(**context).strip()
4226
cql2_expr = Expr(cql2_str)
4327
cql2_expr.validate()
Lines changed: 57 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from logging import getLogger
12
import json
23
from dataclasses import dataclass
34
from typing import Annotated, Callable, Optional
@@ -9,6 +10,7 @@
910
from ..config import EndpointMethods
1011
from ..utils import di, filters, requests
1112

13+
logger = getLogger(__name__)
1214

1315
FILTER_STATE_KEY = "cql2_filter"
1416

@@ -27,17 +29,38 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2729
if scope["type"] != "http":
2830
return await self.app(scope, receive, send)
2931

30-
request = Request(scope)
31-
filter_builder = self._get_filter(request.url.path)
32-
if filter_builder:
33-
cql2_filter = await di.call_with_injected_dependencies(
34-
func=filter_builder,
35-
request=request,
36-
)
37-
cql2_filter.validate()
38-
scope["state"][FILTER_STATE_KEY] = cql2_filter
32+
total_body = b""
33+
34+
async def receive_build_filter() -> Message:
35+
nonlocal total_body
3936

40-
return await self.app(scope, receive, send)
37+
message = await receive()
38+
total_body += message.get("body", b"")
39+
40+
if not message.get("more_body"):
41+
request = Request(scope)
42+
filter_builder = self._get_filter(request.url.path)
43+
if filter_builder:
44+
cql2_filter = await filter_builder(
45+
{
46+
"req": {
47+
"path": request.url.path,
48+
"method": request.method,
49+
"query_params": dict(request.query_params),
50+
"path_params": requests.extract_variables(
51+
request.url.path
52+
),
53+
"headers": dict(request.headers),
54+
"body": json.loads(total_body),
55+
},
56+
**request.state._state,
57+
}
58+
)
59+
cql2_filter.validate()
60+
scope["state"][FILTER_STATE_KEY] = cql2_filter
61+
return message
62+
63+
return await self.app(scope, receive_build_filter, send)
4164

4265
def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
4366
"""Get the CQL2 filter builder for the given path."""
@@ -55,51 +78,40 @@ def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
5578

5679
@dataclass(frozen=True)
5780
class ApplyCql2FilterMiddleware:
58-
"""Middleware to add the OpenAPI spec to the response."""
81+
"""Middleware to apply the Cql2Filter to the request."""
5982

6083
app: ASGIApp
6184

6285
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6386
"""Add the Cql2Filter to the request."""
64-
request = Request(scope)
65-
cql2_filter = request.state.get(FILTER_STATE_KEY)
66-
67-
if scope["type"] != "http" or not cql2_filter:
87+
if scope["type"] != "http":
6888
return await self.app(scope, receive, send)
6989

70-
# Apply filter if applicable
90+
async def apply_filter() -> Message:
91+
message = await receive()
92+
request = Request(scope)
93+
cql2_filter = getattr(request.state, FILTER_STATE_KEY, None)
94+
if not cql2_filter:
95+
logger.debug("No cql2 filter found on message.")
96+
return message
7197

72-
total_body = b""
73-
74-
async def receive_with_filter(message: Message):
75-
query = request.url.query
76-
77-
# TODO: How do we handle querystrings in middleware?
7898
if request.method == "GET":
7999
query = filters.insert_qs_filter(qs=query, filter=cql2_filter)
80-
81-
if message["type"] == "http.response.body":
82-
nonlocal total_body
83-
total_body += message["body"]
84-
if message["more_body"]:
85-
return await receive({**message, "body": b""})
86-
87-
# TODO: Only on search, not on create or update...
88-
if request.method in ["POST", "PUT"]:
89-
return await receive(
90-
{
91-
"type": "http.response.body",
92-
"body": requests.dict_to_bytes(
93-
filters.append_body_filter(
94-
json.loads(total_body), cql2_filter
95-
)
96-
),
97-
"more_body": False,
98-
}
100+
# Get the original query string
101+
original_qs = scope["query_string"].decode("utf-8")
102+
# Apply the filter to query string
103+
new_qs = filters.append_qs_filter(original_qs, cql2_filter)
104+
# Update the scope with new query string
105+
# scope["query_string"] = new_qs.encode("utf-8")
106+
elif request.method in ["POST", "PUT", "PATCH"]:
107+
# TODO: Apply the filter to the body
108+
message["body"] = json.dumps(
109+
filters.append_body_filter(
110+
body=json.loads(message.get("body", "{}")),
111+
filter=cql2_filter,
99112
)
113+
).encode("utf-8")
100114

101-
return await receive(message)
102-
103-
await receive(message)
115+
return message
104116

105-
return await self.app(scope, receive_with_filter, send)
117+
return await self.app(scope, apply_filter, send)

src/stac_auth_proxy/middleware/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
from .UpdateOpenApiMiddleware import OpenApiMiddleware
44
from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware
55
from .EnforceAuthMiddleware import EnforceAuthMiddleware
6+
from .Cql2FilterMiddleware import BuildCql2FilterMiddleware, ApplyCql2FilterMiddleware
67

78
__all__ = [
8-
UpdateOpenApiMiddleware,
9+
OpenApiMiddleware,
910
AddProcessTimeHeaderMiddleware,
1011
EnforceAuthMiddleware,
12+
BuildCql2FilterMiddleware,
13+
ApplyCql2FilterMiddleware,
1114
]

0 commit comments

Comments
 (0)