Skip to content

Commit 86f6d17

Browse files
authored
refactor(cql2-middleware): breakup cql2 middleware into smaller components (#66)
1 parent 4740cf4 commit 86f6d17

File tree

7 files changed

+305
-214
lines changed

7 files changed

+305
-214
lines changed

src/stac_auth_proxy/app.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
from .handlers import HealthzHandler, ReverseProxyHandler, SwaggerUI
1717
from .middleware import (
1818
AddProcessTimeHeaderMiddleware,
19-
ApplyCql2FilterMiddleware,
2019
AuthenticationExtensionMiddleware,
21-
BuildCql2FilterMiddleware,
20+
Cql2ApplyFilterBodyMiddleware,
21+
Cql2ApplyFilterQueryStringMiddleware,
22+
Cql2BuildFilterMiddleware,
23+
Cql2ValidateResponseBodyMiddleware,
2224
EnforceAuthMiddleware,
2325
OpenApiMiddleware,
2426
ProcessLinksMiddleware,
@@ -132,11 +134,11 @@ async def lifespan(app: FastAPI):
132134
)
133135

134136
if settings.items_filter or settings.collections_filter:
137+
app.add_middleware(Cql2ValidateResponseBodyMiddleware)
138+
app.add_middleware(Cql2ApplyFilterBodyMiddleware)
139+
app.add_middleware(Cql2ApplyFilterQueryStringMiddleware)
135140
app.add_middleware(
136-
ApplyCql2FilterMiddleware,
137-
)
138-
app.add_middleware(
139-
BuildCql2FilterMiddleware,
141+
Cql2BuildFilterMiddleware,
140142
items_filter=settings.items_filter() if settings.items_filter else None,
141143
collections_filter=(
142144
settings.collections_filter() if settings.collections_filter else None

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

Lines changed: 0 additions & 202 deletions
This file was deleted.
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Middleware to augment the request body with a CQL2 filter for POST/PUT/PATCH requests."""
2+
3+
import json
4+
from dataclasses import dataclass
5+
from logging import getLogger
6+
from typing import Optional
7+
8+
from cql2 import Expr
9+
from starlette.requests import Request
10+
from starlette.types import ASGIApp, Receive, Scope, Send
11+
12+
from ..utils import filters
13+
from ..utils.middleware import required_conformance
14+
15+
logger = getLogger(__name__)
16+
17+
18+
@required_conformance(
19+
r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
20+
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
21+
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
22+
)
23+
@dataclass(frozen=True)
24+
class Cql2ApplyFilterBodyMiddleware:
25+
"""Middleware to augment the request body with a CQL2 filter for POST/PUT/PATCH requests."""
26+
27+
app: ASGIApp
28+
state_key: str = "cql2_filter"
29+
30+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
31+
"""Apply the CQL2 filter to the request body."""
32+
if scope["type"] != "http":
33+
return await self.app(scope, receive, send)
34+
35+
request = Request(scope)
36+
cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
37+
if not cql2_filter:
38+
return await self.app(scope, receive, send)
39+
40+
if request.method not in ["POST", "PUT", "PATCH"]:
41+
return await self.app(scope, receive, send)
42+
43+
body = b""
44+
more_body = True
45+
while more_body:
46+
message = await receive()
47+
if message["type"] == "http.request":
48+
body += message.get("body", b"")
49+
more_body = message.get("more_body", False)
50+
51+
try:
52+
body_json = json.loads(body) if body else {}
53+
except json.JSONDecodeError:
54+
logger.warning("Failed to parse request body as JSON")
55+
from starlette.responses import JSONResponse
56+
57+
response = JSONResponse(
58+
{
59+
"code": "ParseError",
60+
"description": "Request body must be valid JSON.",
61+
},
62+
status_code=400,
63+
)
64+
await response(scope, receive, send)
65+
return
66+
67+
if not isinstance(body_json, dict):
68+
logger.warning("Request body must be a JSON object")
69+
from starlette.responses import JSONResponse
70+
71+
response = JSONResponse(
72+
{
73+
"code": "TypeError",
74+
"description": "Request body must be a JSON object.",
75+
},
76+
status_code=400,
77+
)
78+
await response(scope, receive, send)
79+
return
80+
81+
new_body = json.dumps(
82+
filters.append_body_filter(body_json, cql2_filter)
83+
).encode("utf-8")
84+
85+
# Patch content-length in the headers
86+
headers = dict(scope["headers"])
87+
headers[b"content-length"] = str(len(new_body)).encode("latin1")
88+
scope = dict(scope)
89+
scope["headers"] = list(headers.items())
90+
91+
async def new_receive():
92+
return {
93+
"type": "http.request",
94+
"body": new_body,
95+
"more_body": False,
96+
}
97+
98+
await self.app(scope, new_receive, send)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Middleware to inject CQL2 filters into the query string for GET/list endpoints."""
2+
3+
import re
4+
from dataclasses import dataclass
5+
from logging import getLogger
6+
from typing import Optional
7+
8+
from cql2 import Expr
9+
from starlette.requests import Request
10+
from starlette.types import ASGIApp, Receive, Scope, Send
11+
12+
from ..utils import filters
13+
from ..utils.middleware import required_conformance
14+
15+
logger = getLogger(__name__)
16+
17+
18+
@required_conformance(
19+
r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
20+
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
21+
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
22+
)
23+
@dataclass(frozen=True)
24+
class Cql2ApplyFilterQueryStringMiddleware:
25+
"""Middleware to inject CQL2 filters into the query string for GET/list endpoints."""
26+
27+
app: ASGIApp
28+
state_key: str = "cql2_filter"
29+
30+
single_record_endpoints = [
31+
r"^/collections/([^/]+)/items/([^/]+)$",
32+
r"^/collections/([^/]+)$",
33+
]
34+
35+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
36+
"""Apply the CQL2 filter to the query string."""
37+
if scope["type"] != "http":
38+
return await self.app(scope, receive, send)
39+
40+
request = Request(scope)
41+
cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
42+
if not cql2_filter:
43+
return await self.app(scope, receive, send)
44+
45+
# Only handle GET requests that are not single-record endpoints
46+
if request.method != "GET":
47+
return await self.app(scope, receive, send)
48+
if any(
49+
re.match(expr, request.url.path) for expr in self.single_record_endpoints
50+
):
51+
return await self.app(scope, receive, send)
52+
53+
# Inject filter into query string
54+
scope = dict(scope)
55+
scope["query_string"] = filters.append_qs_filter(request.url.query, cql2_filter)
56+
return await self.app(scope, receive, send)

src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py renamed to src/stac_auth_proxy/middleware/Cql2BuildFilterMiddleware.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
2323
)
2424
@dataclass(frozen=True)
25-
class BuildCql2FilterMiddleware:
25+
class Cql2BuildFilterMiddleware:
2626
"""Middleware to build the Cql2Filter."""
2727

2828
app: ASGIApp

0 commit comments

Comments
 (0)