Skip to content

Commit f8c7bea

Browse files
committed
In progress...
1 parent 0b67176 commit f8c7bea

File tree

2 files changed

+104
-90
lines changed

2 files changed

+104
-90
lines changed

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

Lines changed: 79 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
import json
44
import re
55
from dataclasses import dataclass, field
6-
from functools import partial
76
from logging import getLogger
8-
from typing import Callable, Optional
7+
from typing import Optional
98

109
from cql2 import Expr
1110
from starlette.datastructures import MutableHeaders, State
@@ -39,25 +38,60 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3938

4039
request = Request(scope)
4140

42-
get_cql2_filter: Callable[[], Optional[Expr]] = partial(
43-
getattr, request.state, self.state_key, None
44-
)
41+
cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
42+
43+
if not cql2_filter:
44+
return await self.app(scope, receive, send)
4545

4646
# Handle POST, PUT, PATCH
4747
if request.method in ["POST", "PUT", "PATCH"]:
48-
return await self.app(
49-
scope,
50-
Cql2RequestBodyAugmentor(
51-
receive=receive,
52-
state=request.state,
53-
get_cql2_filter=get_cql2_filter,
54-
),
55-
send,
48+
body = b""
49+
more_body = True
50+
receive_ = receive
51+
52+
async def buffered_receive():
53+
nonlocal body, more_body
54+
if more_body:
55+
message = await receive_()
56+
if message["type"] == "http.request":
57+
body += message.get("body", b"")
58+
more_body = message.get("more_body", False)
59+
return message
60+
return {"type": "http.request", "body": b"", "more_body": False}
61+
62+
while more_body:
63+
await buffered_receive()
64+
65+
# Modify body
66+
# modified_body = body + b"\nAppended content."
67+
try:
68+
body = json.loads(body)
69+
except json.JSONDecodeError as e:
70+
logger.warning("Failed to parse request body as JSON")
71+
# TODO: Return a 400 error
72+
raise e
73+
74+
new_body = json.dumps(filters.append_body_filter(body, cql2_filter)).encode(
75+
"utf-8"
5676
)
5777

58-
cql2_filter = get_cql2_filter()
59-
if not cql2_filter:
60-
return await self.app(scope, receive, send)
78+
# Override the receive function with a generator sending our modified body
79+
async def new_receive():
80+
nonlocal new_body
81+
chunk = new_body
82+
new_body = b""
83+
return {
84+
"type": "http.request",
85+
"body": chunk,
86+
"more_body": False,
87+
}
88+
89+
# Patch content-length in the headers
90+
headers = dict(scope["headers"])
91+
headers[b"content-length"] = str(len(new_body)).encode("latin1")
92+
scope["headers"] = list(headers.items())
93+
94+
await self.app(scope, new_receive, send)
6195

6296
if re.match(r"^/collections/([^/]+)/items/([^/]+)$", request.url.path):
6397
return await self.app(
@@ -76,27 +110,27 @@ class Cql2RequestBodyAugmentor:
76110

77111
receive: Receive
78112
state: State
79-
get_cql2_filter: Callable[[], Optional[Expr]]
113+
cql2_filter: Expr
114+
initial_message: Optional[Message] = field(init=False)
80115

81116
async def __call__(self) -> Message:
82117
"""Process a request body and augment with a CQL2 filter if available."""
83118
message = await self.receive()
84-
if message["type"] != "http.request":
85-
return message
119+
if message["type"] == "http.response.start":
120+
self.initial_message = message
121+
return
86122

87-
# NOTE: Can only get cql2 filter _after_ calling self.receive()
88-
cql2_filter = self.get_cql2_filter()
89-
if not cql2_filter:
123+
if not message.get("body"):
90124
return message
91125

92126
try:
93-
body = json.loads(message.get("body", b"{}"))
127+
body = json.loads(message.get("body"))
94128
except json.JSONDecodeError as e:
95129
logger.warning("Failed to parse request body as JSON")
96130
# TODO: Return a 400 error
97131
raise e
98132

99-
new_body = filters.append_body_filter(body, cql2_filter)
133+
new_body = filters.append_body_filter(body, self.cql2_filter)
100134
message["body"] = json.dumps(new_body).encode("utf-8")
101135
return message
102136

@@ -116,33 +150,30 @@ async def __call__(self, message: Message) -> None:
116150
self.initial_message = message
117151
return
118152

119-
if message["type"] == "http.response.body":
120-
assert self.initial_message, "Initial message not set"
153+
assert self.initial_message, "Initial message not set"
121154

122-
self.body += message["body"]
123-
if message.get("more_body"):
124-
return
155+
self.body += message["body"]
156+
if message.get("more_body"):
157+
return
125158

126-
try:
127-
body_json = json.loads(self.body)
128-
except json.JSONDecodeError:
129-
logger.warning("Failed to parse response body as JSON")
130-
await self._send_error_response(502, "Not found")
131-
return
132-
133-
logger.debug(
134-
"Applying %s filter to %s", self.cql2_filter.to_text(), body_json
159+
try:
160+
body_json = json.loads(self.body)
161+
except json.JSONDecodeError:
162+
logger.warning("Failed to parse response body as JSON")
163+
await self._send_error_response(502, "Not found")
164+
return
165+
166+
logger.debug("Applying %s filter to %s", self.cql2_filter.to_text(), body_json)
167+
if self.cql2_filter.matches(body_json):
168+
await self.send(self.initial_message)
169+
return await self.send(
170+
{
171+
"type": "http.response.body",
172+
"body": json.dumps(body_json).encode("utf-8"),
173+
"more_body": False,
174+
}
135175
)
136-
if self.cql2_filter.matches(body_json):
137-
await self.send(self.initial_message)
138-
return await self.send(
139-
{
140-
"type": "http.response.body",
141-
"body": json.dumps(body_json).encode("utf-8"),
142-
"more_body": False,
143-
}
144-
)
145-
return await self._send_error_response(404, "Not found")
176+
return await self._send_error_response(404, "Not found")
146177

147178
async def _send_error_response(self, status: int, message: str) -> None:
148179
"""Send an error response with the given status and message."""

src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
"""Middleware to build the Cql2Filter."""
22

3-
import json
3+
import logging
44
import re
55
from dataclasses import dataclass
66
from typing import Any, Awaitable, Callable, Optional
77

8-
from cql2 import Expr
8+
from cql2 import Expr, ValidationError
99
from starlette.requests import Request
10-
from starlette.types import ASGIApp, Message, Receive, Scope, Send
10+
from starlette.responses import Response
11+
from starlette.types import ASGIApp, Receive, Scope, Send
1112

1213
from ..utils import requests
1314

15+
logger = logging.getLogger(__name__)
16+
1417

1518
@dataclass(frozen=True)
1619
class BuildCql2FilterMiddleware:
@@ -35,47 +38,27 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3538
if not filter_builder:
3639
return await self.app(scope, receive, send)
3740

38-
async def set_filter(body: Optional[dict] = None) -> None:
39-
assert filter_builder is not None
40-
filter_expr = await filter_builder(
41-
{
42-
"req": {
43-
"path": request.url.path,
44-
"method": request.method,
45-
"query_params": dict(request.query_params),
46-
"path_params": requests.extract_variables(request.url.path),
47-
"headers": dict(request.headers),
48-
"body": body,
49-
},
50-
**scope["state"],
51-
}
52-
)
53-
cql2_filter = Expr(filter_expr)
41+
filter_expr = await filter_builder(
42+
{
43+
"req": {
44+
"path": request.url.path,
45+
"method": request.method,
46+
"query_params": dict(request.query_params),
47+
"path_params": requests.extract_variables(request.url.path),
48+
"headers": dict(request.headers),
49+
},
50+
**scope["state"],
51+
}
52+
)
53+
cql2_filter = Expr(filter_expr)
54+
try:
5455
cql2_filter.validate()
55-
setattr(request.state, self.state_key, cql2_filter)
56-
57-
# For GET requests, we can build the filter immediately
58-
if request.method == "GET":
59-
await set_filter()
60-
return await self.app(scope, receive, send)
61-
62-
total_body = b""
63-
64-
async def receive_build_filter() -> Message:
65-
"""
66-
Receive the body of the request and build the filter.
67-
NOTE: This is not called for GET requests.
68-
"""
69-
nonlocal total_body
70-
71-
message = await receive()
72-
total_body += message.get("body", b"")
73-
74-
if not message.get("more_body"):
75-
await set_filter(json.loads(total_body) if total_body else None)
76-
return message
56+
except ValidationError:
57+
logger.exception("Invalid CQL2 filter: %s", filter_expr)
58+
return await Response(status_code=502, content="Invalid CQL2 filter")
59+
setattr(request.state, self.state_key, cql2_filter)
7760

78-
return await self.app(scope, receive_build_filter, send)
61+
return await self.app(scope, receive, send)
7962

8063
def _get_filter(
8164
self, path: str

0 commit comments

Comments
 (0)