Skip to content

Commit 89d6a3f

Browse files
committed
chore: separate CQL2 filter middelware into separate files
1 parent 2ad4c28 commit 89d6a3f

File tree

3 files changed

+66
-55
lines changed

3 files changed

+66
-55
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Middleware to apply CQL2 filters."""
2+
3+
import json
4+
from dataclasses import dataclass
5+
from logging import getLogger
6+
7+
from starlette.requests import Request
8+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
9+
10+
from ..utils import filters
11+
12+
logger = getLogger(__name__)
13+
14+
15+
@dataclass(frozen=True)
16+
class ApplyCql2FilterMiddleware:
17+
"""Middleware to apply the Cql2Filter to the request."""
18+
19+
app: ASGIApp
20+
21+
state_key: str = "cql2_filter"
22+
23+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
24+
"""Add the Cql2Filter to the request."""
25+
if scope["type"] != "http":
26+
return await self.app(scope, receive, send)
27+
28+
request = Request(scope)
29+
30+
if request.method == "GET":
31+
cql2_filter = getattr(request.state, self.state_key, None)
32+
if cql2_filter:
33+
# TODO: Differentiate between list/search and lookup
34+
scope["query_string"] = filters.append_qs_filter(
35+
request.url.query, cql2_filter
36+
)
37+
return await self.app(scope, receive, send)
38+
39+
elif request.method in ["POST", "PUT", "PATCH"]:
40+
41+
async def receive_and_apply_filter() -> Message:
42+
message = await receive()
43+
if message["type"] != "http.request":
44+
return message
45+
46+
cql2_filter = getattr(request.state, self.state_key, None)
47+
if cql2_filter:
48+
try:
49+
body = message.get("body", b"{}")
50+
except json.JSONDecodeError as e:
51+
logger.warning("Failed to parse request body as JSON")
52+
# TODO: Return a 400 error
53+
raise e
54+
55+
new_body = filters.append_body_filter(json.loads(body), cql2_filter)
56+
message["body"] = json.dumps(new_body).encode("utf-8")
57+
return message
58+
59+
return await self.app(scope, receive_and_apply_filter, send)
60+
61+
return await self.app(scope, receive, send)

src/stac_auth_proxy/middleware/Cql2FilterMiddleware.py renamed to src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
"""Middleware to build and apply CQL2 filters."""
1+
"""Middleware to build the Cql2Filter."""
22

33
import json
44
from dataclasses import dataclass
5-
from logging import getLogger
65
from typing import Callable, Optional
76

87
from cql2 import Expr
@@ -11,21 +10,19 @@
1110

1211
from ..utils import filters, requests
1312

14-
logger = getLogger(__name__)
15-
1613

1714
@dataclass(frozen=True)
1815
class BuildCql2FilterMiddleware:
1916
"""Middleware to build the Cql2Filter."""
2017

2118
app: ASGIApp
2219

20+
state_key: str = "cql2_filter"
21+
2322
# Filters
2423
collections_filter: Optional[Callable] = None
2524
items_filter: Optional[Callable] = None
2625

27-
state_key: str = "cql2_filter"
28-
2926
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3027
"""Build the CQL2 filter, place on the request state."""
3128
if scope["type"] != "http":
@@ -87,51 +84,3 @@ def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
8784
if check(path):
8885
return builder
8986
return None
90-
91-
92-
@dataclass(frozen=True)
93-
class ApplyCql2FilterMiddleware:
94-
"""Middleware to apply the Cql2Filter to the request."""
95-
96-
app: ASGIApp
97-
98-
state_key: str = "cql2_filter"
99-
100-
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
101-
"""Add the Cql2Filter to the request."""
102-
if scope["type"] != "http":
103-
return await self.app(scope, receive, send)
104-
105-
request = Request(scope)
106-
107-
if request.method == "GET":
108-
cql2_filter = getattr(request.state, self.state_key, None)
109-
if cql2_filter:
110-
scope["query_string"] = filters.append_qs_filter(
111-
request.url.query, cql2_filter
112-
)
113-
return await self.app(scope, receive, send)
114-
115-
elif request.method in ["POST", "PUT", "PATCH"]:
116-
117-
async def receive_and_apply_filter() -> Message:
118-
message = await receive()
119-
if message["type"] != "http.request":
120-
return message
121-
122-
cql2_filter = getattr(request.state, self.state_key, None)
123-
if cql2_filter:
124-
try:
125-
body = message.get("body", b"{}")
126-
except json.JSONDecodeError as e:
127-
logger.warning("Failed to parse request body as JSON")
128-
# TODO: Return a 400 error
129-
raise e
130-
131-
new_body = filters.append_body_filter(json.loads(body), cql2_filter)
132-
message["body"] = json.dumps(new_body).encode("utf-8")
133-
return message
134-
135-
return await self.app(scope, receive_and_apply_filter, send)
136-
137-
return await self.app(scope, receive, send)

src/stac_auth_proxy/middleware/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Custom middleware."""
22

33
from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware
4-
from .Cql2FilterMiddleware import ApplyCql2FilterMiddleware, BuildCql2FilterMiddleware
4+
from .ApplyCql2FilterMiddleware import ApplyCql2FilterMiddleware
5+
from .BuildCql2FilterMiddleware import BuildCql2FilterMiddleware
56
from .EnforceAuthMiddleware import EnforceAuthMiddleware
67
from .UpdateOpenApiMiddleware import OpenApiMiddleware
78

0 commit comments

Comments
 (0)