Skip to content

Commit 0b8c4f0

Browse files
committed
refactor: enhance CQL2 filter middleware
- Removed outdated conformance URLs from ApplyCql2FilterMiddleware. - Refactored error response handling to include error codes in BuildCql2FilterMiddleware. - Improved logging and validation processes for response bodies in Cql2ResponseBodyValidator. - Added required conformance checks in BuildCql2FilterMiddleware based on filter functions.
1 parent f20f259 commit 0b8c4f0

File tree

2 files changed

+76
-29
lines changed

2 files changed

+76
-29
lines changed

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
2222
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
2323
r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
24-
r"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
25-
r"https://api.stacspec.org/v1\.\d+\.\d+(?:-[\w\.]+)?/item-search#filter",
2624
)
2725
@dataclass(frozen=True)
2826
class ApplyCql2FilterMiddleware:
@@ -31,6 +29,11 @@ class ApplyCql2FilterMiddleware:
3129
app: ASGIApp
3230
state_key: str = "cql2_filter"
3331

32+
single_record_endpoints = [
33+
r"^/collections/([^/]+)/items/([^/]+)$",
34+
r"^/collections/([^/]+)$",
35+
]
36+
3437
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3538
"""Add the Cql2Filter to the request."""
3639
if scope["type"] != "http":
@@ -52,11 +55,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5255
return await req_body_handler(scope, receive, send)
5356

5457
# Handle single record requests (ie non-filterable endpoints)
55-
single_record_endpoints = [
56-
r"^/collections/([^/]+)/items/([^/]+)$",
57-
r"^/collections/([^/]+)$",
58-
]
59-
if any(re.match(expr, request.url.path) for expr in single_record_endpoints):
58+
if any(
59+
re.match(expr, request.url.path) for expr in self.single_record_endpoints
60+
):
6061
res_body_validator = Cql2ResponseBodyValidator(
6162
app=self.app,
6263
cql2_filter=cql2_filter,
@@ -130,18 +131,22 @@ async def __call__(self, scope: Scope, send: Send, receive: Receive) -> None:
130131
body = b""
131132
initial_message: Optional[Message] = None
132133

133-
async def _send_error_response(status: int, message: str) -> None:
134+
async def _send_error_response(status: int, code: str, message: str) -> None:
134135
"""Send an error response with the given status and message."""
135136
assert initial_message, "Initial message not set"
136-
error_body = json.dumps({"message": message}).encode("utf-8")
137+
response_dict = {
138+
"code": code,
139+
"description": message,
140+
}
141+
response_bytes = json.dumps(response_dict).encode("utf-8")
137142
headers = MutableHeaders(scope=initial_message)
138-
headers["content-length"] = str(len(error_body))
143+
headers["content-length"] = str(len(response_bytes))
139144
initial_message["status"] = status
140145
await send(initial_message)
141146
await send(
142147
{
143148
"type": "http.response.body",
144-
"body": error_body,
149+
"body": response_bytes,
145150
"more_body": False,
146151
}
147152
)
@@ -150,40 +155,48 @@ async def buffered_send(message: Message) -> None:
150155
"""Process a response message and apply filtering if needed."""
151156
nonlocal body
152157
nonlocal initial_message
158+
initial_message = initial_message or message
159+
# NOTE: to avoid data-leak, we process 404s so their responses are the same as rejected 200s
160+
should_process = initial_message["status"] in [200, 404]
161+
162+
if not should_process:
163+
return await send(message)
153164

154165
if message["type"] == "http.response.start":
155-
initial_message = message
166+
# Hold off on sending response headers until we've validated the response body
156167
return
157168

158-
assert initial_message, "Initial message not set"
159-
160169
body += message["body"]
161170
if message.get("more_body"):
162171
return
163172

164173
try:
165174
body_json = json.loads(body)
166175
except json.JSONDecodeError:
167-
logger.warning("Failed to parse response body as JSON")
168-
await _send_error_response(502, "Not found")
176+
msg = "Failed to parse response body as JSON"
177+
logger.warning(msg)
178+
await _send_error_response(status=502, code="ParseError", message=msg)
169179
return
170180

171-
logger.debug(
172-
"Applying %s filter to %s", self.cql2_filter.to_text(), body_json
173-
)
174181
try:
175-
if self.cql2_filter.matches(body_json):
176-
await send(initial_message)
177-
return await send(
178-
{
179-
"type": "http.response.body",
180-
"body": json.dumps(body_json).encode("utf-8"),
181-
"more_body": False,
182-
}
183-
)
182+
cql2_matches = self.cql2_filter.matches(body_json)
184183
except Exception as e:
184+
cql2_matches = False
185185
logger.warning("Failed to apply filter: %s", e)
186186

187-
return await _send_error_response(404, "Not found")
187+
if cql2_matches:
188+
logger.debug("Response matches filter, returning record")
189+
await send(initial_message)
190+
return await send(
191+
{
192+
"type": "http.response.body",
193+
"body": json.dumps(body_json).encode("utf-8"),
194+
"more_body": False,
195+
}
196+
)
197+
logger.debug("Response did not match filter, returning 404")
198+
return await _send_error_response(
199+
status=404, code="NotFoundError", message="Record not found."
200+
)
188201

189202
return await self.app(scope, receive, buffered_send)

src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,16 @@
1111
from starlette.types import ASGIApp, Receive, Scope, Send
1212

1313
from ..utils import requests
14+
from ..utils.middleware import required_conformance
1415

1516
logger = logging.getLogger(__name__)
1617

1718

19+
@required_conformance(
20+
"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
21+
"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
22+
"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
23+
)
1824
@dataclass(frozen=True)
1925
class BuildCql2FilterMiddleware:
2026
"""Middleware to build the Cql2Filter."""
@@ -29,6 +35,34 @@ class BuildCql2FilterMiddleware:
2935
items_filter: Optional[Callable] = None
3036
items_filter_path: str = r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)"
3137

38+
def __post_init__(self):
39+
"""Set required conformances based on the filter functions."""
40+
required_conformances = set()
41+
if self.collections_filter:
42+
logger.debug("Appending required conformance for collections filter")
43+
required_conformances.update(
44+
[
45+
"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/filter",
46+
"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
47+
r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/item-search#filter",
48+
"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter",
49+
]
50+
)
51+
if self.items_filter:
52+
logger.debug("Appending required conformance for items filter")
53+
required_conformances.update(
54+
[
55+
"https://api.stacspec.org/v1.0.0/core",
56+
r"https://api.stacspec.org/v1\.0\.0(?:-[\w\.]+)?/collection-search#filter",
57+
"http://www.opengis.net/spec/ogcapi-common-2/1.0/conf/simple-query",
58+
]
59+
)
60+
61+
# Must set required conformances on class
62+
self.__class__.__required_conformances__ = required_conformances.union(
63+
getattr(self.__class__, "__required_conformances__", [])
64+
)
65+
3266
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3367
"""Build the CQL2 filter, place on the request state."""
3468
if scope["type"] != "http":

0 commit comments

Comments
 (0)