Skip to content

Commit 2ab85c3

Browse files
committed
CQL2 Middleware: in progress
1 parent 6a13f22 commit 2ab85c3

File tree

4 files changed

+127
-69
lines changed

4 files changed

+127
-69
lines changed

src/stac_auth_proxy/app.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
5656
)
5757

5858
# Tooling
59-
proxy_handler = ReverseProxyHandler(
60-
upstream=str(settings.upstream_url),
61-
# TODO: Refactor filter tooling into middleare
62-
collections_filter=settings.collections_filter,
63-
items_filter=settings.items_filter,
64-
)
59+
proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
6560
# Catchall for remainder of the endpoints
6661
app.add_api_route(
6762
"/{path:path}",

src/stac_auth_proxy/handlers/reverse_proxy.py

Lines changed: 10 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
11
"""Tooling to manage the reverse proxying of requests to an upstream STAC API."""
22

3-
import json
43
import logging
54
import time
6-
from dataclasses import dataclass
7-
from typing import Annotated, Callable, Optional
5+
from dataclasses import dataclass, field
86

97
import httpx
10-
from cql2 import Expr
11-
from fastapi import Depends, Request
8+
from fastapi import Request
129
from starlette.background import BackgroundTask
1310
from starlette.datastructures import MutableHeaders
1411
from starlette.responses import StreamingResponse
1512

16-
from ..utils import di, filters
1713

1814
logger = logging.getLogger(__name__)
1915

@@ -24,79 +20,35 @@ class ReverseProxyHandler:
2420

2521
upstream: str
2622
client: httpx.AsyncClient = None
27-
28-
# Filters
29-
collections_filter: Optional[Callable] = None
30-
items_filter: Optional[Callable] = None
23+
timeout: httpx.Timeout = field(default_factory=lambda: httpx.Timeout(timeout=15.0))
3124

3225
def __post_init__(self):
3326
"""Initialize the HTTP client."""
3427
self.client = self.client or httpx.AsyncClient(
3528
base_url=self.upstream,
36-
timeout=httpx.Timeout(timeout=15.0),
37-
)
38-
self.collections_filter = (
39-
self.collections_filter() if self.collections_filter else None
29+
timeout=self.timeout,
4030
)
41-
self.items_filter = self.items_filter() if self.items_filter else None
4231

43-
def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
44-
"""Get the CQL2 filter builder for the given path."""
45-
endpoint_filters = [
46-
# TODO: Use collections_filter_endpoints & items_filter_endpoints
47-
(filters.is_collection_endpoint, self.collections_filter),
48-
(filters.is_item_endpoint, self.items_filter),
49-
(filters.is_search_endpoint, self.items_filter),
50-
]
51-
for check, builder in endpoint_filters:
52-
if check(path):
53-
return builder
54-
return None
55-
56-
async def proxy_request(self, request: Request, *, stream=False) -> httpx.Response:
32+
async def proxy_request(self, request: Request) -> httpx.Response:
5733
"""Proxy a request to the upstream STAC API."""
5834
headers = MutableHeaders(request.headers)
5935
headers.setdefault("X-Forwarded-For", request.client.host)
6036
headers.setdefault("X-Forwarded-Host", request.url.hostname)
6137

62-
path = request.url.path
63-
query = request.url.query
64-
# TODO: Should we only do this conditionally based on stream?
65-
body = (await request.body()).decode()
66-
67-
# Apply filter if applicable
68-
filter_builder = self._get_filter(path)
69-
if filter_builder:
70-
cql_filter = await di.call_with_injected_dependencies(
71-
func=filter_builder,
72-
request=request,
73-
)
74-
cql_filter.validate()
75-
76-
if request.method == "GET":
77-
query = filters.insert_filter(qs=query, filter=cql_filter)
78-
elif request.method in ["POST", "PUT"]:
79-
body_dict = json.loads(body)
80-
body_filter = body_dict.get("filter")
81-
if body_filter:
82-
cql_filter = cql_filter + Expr(body_filter)
83-
body_dict["filter"] = cql_filter.to_json()
84-
body = json.dumps(body_dict)
85-
8638
# https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
8739
rp_req = self.client.build_request(
8840
request.method,
8941
url=httpx.URL(
90-
path=path,
91-
query=query.encode("utf-8"),
42+
path=request.url.path,
43+
query=request.url.query.encode("utf-8"),
9244
),
9345
headers=headers,
94-
content=body,
46+
content=request.stream(),
9547
)
9648
logger.debug(f"Proxying request to {rp_req.url}")
9749

9850
start_time = time.perf_counter()
99-
rp_resp = await self.client.send(rp_req, stream=stream)
51+
rp_resp = await self.client.send(rp_req, stream=True)
10052
proxy_time = time.perf_counter() - start_time
10153

10254
logger.debug(
@@ -107,11 +59,7 @@ async def proxy_request(self, request: Request, *, stream=False) -> httpx.Respon
10759

10860
async def stream(self, request: Request) -> StreamingResponse:
10961
"""Transparently proxy a request to the upstream STAC API."""
110-
rp_resp = await self.proxy_request(
111-
request,
112-
# collections_filter=collections_filter,
113-
stream=True,
114-
)
62+
rp_resp = await self.proxy_request(request)
11563
return StreamingResponse(
11664
rp_resp.aiter_raw(),
11765
status_code=rp_resp.status_code,
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import json
2+
from dataclasses import dataclass
3+
from typing import Annotated, Callable, Optional
4+
5+
from cql2 import Expr
6+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
7+
from starlette.requests import Request
8+
9+
from ..config import EndpointMethods
10+
from ..utils import di, filters, requests
11+
12+
13+
FILTER_STATE_KEY = "cql2_filter"
14+
15+
16+
@dataclass(frozen=True)
17+
class BuildCql2FilterMiddleware:
18+
"""Middleware to build the Cql2Filter."""
19+
20+
app: ASGIApp
21+
22+
# Filters
23+
collections_filter: Optional[Callable] = None
24+
items_filter: Optional[Callable] = None
25+
26+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
27+
if scope["type"] != "http":
28+
return await self.app(scope, receive, send)
29+
30+
request = Request(scope)
31+
filter_builder = self._get_filter(request.url.path)
32+
if filter_builder:
33+
cql2_filter = await di.call_with_injected_dependencies(
34+
func=filter_builder,
35+
request=request,
36+
)
37+
cql2_filter.validate()
38+
scope["state"][FILTER_STATE_KEY] = cql2_filter
39+
40+
return await self.app(scope, receive, send)
41+
42+
def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
43+
"""Get the CQL2 filter builder for the given path."""
44+
endpoint_filters = [
45+
# TODO: Use collections_filter_endpoints & items_filter_endpoints
46+
(filters.is_collection_endpoint, self.collections_filter),
47+
(filters.is_item_endpoint, self.items_filter),
48+
(filters.is_search_endpoint, self.items_filter),
49+
]
50+
for check, builder in endpoint_filters:
51+
if check(path):
52+
return builder
53+
return None
54+
55+
56+
@dataclass(frozen=True)
57+
class ApplyCql2FilterMiddleware:
58+
"""Middleware to add the OpenAPI spec to the response."""
59+
60+
app: ASGIApp
61+
62+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
63+
"""Add the Cql2Filter to the request."""
64+
request = Request(scope)
65+
cql2_filter = request.state.get(FILTER_STATE_KEY)
66+
67+
if scope["type"] != "http" or not cql2_filter:
68+
return await self.app(scope, receive, send)
69+
70+
# Apply filter if applicable
71+
72+
total_body = b""
73+
74+
async def receive_with_filter(message: Message):
75+
query = request.url.query
76+
77+
# TODO: How do we handle querystrings in middleware?
78+
if request.method == "GET":
79+
query = filters.insert_qs_filter(qs=query, filter=cql2_filter)
80+
81+
if message["type"] == "http.response.body":
82+
nonlocal total_body
83+
total_body += message["body"]
84+
if message["more_body"]:
85+
return await receive({**message, "body": b""})
86+
87+
# TODO: Only on search, not on create or update...
88+
if request.method in ["POST", "PUT"]:
89+
return await receive(
90+
{
91+
"type": "http.response.body",
92+
"body": requests.dict_to_bytes(
93+
filters.append_body_filter(
94+
json.loads(total_body), cql2_filter
95+
)
96+
),
97+
"more_body": False,
98+
}
99+
)
100+
101+
return await receive(message)
102+
103+
await receive(message)
104+
105+
return await self.app(scope, receive_with_filter, send)

src/stac_auth_proxy/utils/filters.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from cql2 import Expr
77

88

9-
def insert_filter(qs: str, filter: Expr) -> str:
9+
def append_qs_filter(qs: str, filter: Expr) -> str:
1010
"""Insert a filter expression into a query string. If a filter already exists, combine them."""
1111
qs_dict = parse_qs(qs)
1212

@@ -19,6 +19,16 @@ def insert_filter(qs: str, filter: Expr) -> str:
1919
return urlencode(qs_dict, doseq=True)
2020

2121

22+
def append_body_filter(body: dict, filter: Expr) -> dict:
23+
"""Insert a filter expression into a request body. If a filter already exists, combine them."""
24+
cur_filter = body.get("filter")
25+
if cur_filter:
26+
filter = filter + Expr(cur_filter)
27+
body["filter"] = filter.to_json()
28+
body["filter-lang"] = "cql2-json"
29+
return body
30+
31+
2232
def is_collection_endpoint(path: str) -> bool:
2333
"""Check if the path is a collection endpoint."""
2434
# TODO: Expand this to cover all cases where a collection filter should be applied

0 commit comments

Comments
 (0)