Skip to content

Commit 129d088

Browse files
committed
Finalize fix
1 parent 6462bed commit 129d088

File tree

1 file changed

+100
-112
lines changed

1 file changed

+100
-112
lines changed

src/stac_auth_proxy/middleware/ApplyCql2FilterMiddleware.py

Lines changed: 100 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
import json
44
import re
5-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass
66
from logging import getLogger
77
from typing import Optional
88

99
from cql2 import Expr
10-
from starlette.datastructures import MutableHeaders, State
10+
from starlette.datastructures import MutableHeaders
1111
from starlette.requests import Request
1212
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1313

@@ -45,60 +45,18 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
4545

4646
# Handle POST, PUT, PATCH
4747
if request.method in ["POST", "PUT", "PATCH"]:
48-
body = b""
49-
more_body = True
50-
receive_ = receive
51-
52-
async def buffered_receive():
53-
nonlocal body, more_body
54-
if more_body:
55-
message = await receive_()
56-
if message["type"] == "http.request":
57-
body += message.get("body", b"")
58-
more_body = message.get("more_body", False)
59-
return message
60-
return {"type": "http.request", "body": b"", "more_body": False}
61-
62-
while more_body:
63-
await buffered_receive()
64-
65-
# Modify body
66-
# modified_body = body + b"\nAppended content."
67-
try:
68-
body = json.loads(body)
69-
except json.JSONDecodeError as e:
70-
logger.warning("Failed to parse request body as JSON")
71-
# TODO: Return a 400 error
72-
raise e
73-
74-
new_body = json.dumps(filters.append_body_filter(body, cql2_filter)).encode(
75-
"utf-8"
48+
req_body_handler = Cql2RequestBodyAugmentor(
49+
app=self.app,
50+
cql2_filter=cql2_filter,
7651
)
77-
78-
# Override the receive function with a generator sending our modified body
79-
async def new_receive():
80-
nonlocal new_body
81-
chunk = new_body
82-
new_body = b""
83-
return {
84-
"type": "http.request",
85-
"body": chunk,
86-
"more_body": False,
87-
}
88-
89-
# Patch content-length in the headers
90-
headers = dict(scope["headers"])
91-
headers[b"content-length"] = str(len(new_body)).encode("latin1")
92-
scope["headers"] = list(headers.items())
93-
94-
await self.app(scope, new_receive, send)
52+
return await req_body_handler(scope, receive, send)
9553

9654
if re.match(r"^/collections/([^/]+)/items/([^/]+)$", request.url.path):
97-
return await self.app(
98-
scope,
99-
receive,
100-
Cql2ResponseBodyValidator(cql2_filter=cql2_filter, send=send),
55+
res_body_validator = Cql2ResponseBodyValidator(
56+
app=self.app,
57+
cql2_filter=cql2_filter,
10158
)
59+
return await res_body_validator(scope, send, receive)
10260

10361
scope["query_string"] = filters.append_qs_filter(request.url.query, cql2_filter)
10462
return await self.app(scope, receive, send)
@@ -108,85 +66,115 @@ async def new_receive():
10866
class Cql2RequestBodyAugmentor:
10967
"""Handler to augment the request body with a CQL2 filter."""
11068

111-
receive: Receive
112-
state: State
69+
app: ASGIApp
11370
cql2_filter: Expr
114-
initial_message: Optional[Message] = field(init=False)
115-
116-
async def __call__(self) -> Message:
117-
"""Process a request body and augment with a CQL2 filter if available."""
118-
message = await self.receive()
119-
if message["type"] == "http.response.start":
120-
self.initial_message = message
121-
return
122-
123-
if not message.get("body"):
124-
return message
12571

72+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
73+
"""Augment the request body with a CQL2 filter."""
74+
body = b""
75+
more_body = True
76+
77+
# Read the body
78+
while more_body:
79+
message = await receive()
80+
if message["type"] == "http.request":
81+
body += message.get("body", b"")
82+
more_body = message.get("more_body", False)
83+
84+
# Modify body
12685
try:
127-
body = json.loads(message.get("body"))
86+
body = json.loads(body)
12887
except json.JSONDecodeError as e:
12988
logger.warning("Failed to parse request body as JSON")
13089
# TODO: Return a 400 error
13190
raise e
13291

133-
new_body = filters.append_body_filter(body, self.cql2_filter)
134-
message["body"] = json.dumps(new_body).encode("utf-8")
135-
return message
92+
# Augment the body
93+
assert isinstance(body, dict), "Request body must be a JSON object"
94+
new_body = json.dumps(
95+
filters.append_body_filter(body, self.cql2_filter)
96+
).encode("utf-8")
97+
98+
# Patch content-length in the headers
99+
headers = dict(scope["headers"])
100+
headers[b"content-length"] = str(len(new_body)).encode("latin1")
101+
scope["headers"] = list(headers.items())
102+
103+
async def new_receive():
104+
return {
105+
"type": "http.request",
106+
"body": new_body,
107+
"more_body": False,
108+
}
109+
110+
await self.app(scope, new_receive, send)
136111

137112

138113
@dataclass
139114
class Cql2ResponseBodyValidator:
140115
"""Handler to validate response body with CQL2."""
141116

142-
send: Send
117+
app: ASGIApp
143118
cql2_filter: Expr
144-
initial_message: Optional[Message] = field(init=False)
145-
body: bytes = field(init=False, default_factory=bytes)
146119

147-
async def __call__(self, message: Message) -> None:
120+
async def __call__(self, scope: Scope, send: Send, receive: Receive) -> None:
148121
"""Process a response message and apply filtering if needed."""
149-
if message["type"] == "http.response.start":
150-
self.initial_message = message
151-
return
152-
153-
assert self.initial_message, "Initial message not set"
154-
155-
self.body += message["body"]
156-
if message.get("more_body"):
157-
return
158-
159-
try:
160-
body_json = json.loads(self.body)
161-
except json.JSONDecodeError:
162-
logger.warning("Failed to parse response body as JSON")
163-
await self._send_error_response(502, "Not found")
164-
return
165-
166-
logger.debug("Applying %s filter to %s", self.cql2_filter.to_text(), body_json)
167-
if self.cql2_filter.matches(body_json):
168-
await self.send(self.initial_message)
169-
return await self.send(
122+
if scope["type"] != "http":
123+
return await self.app(scope, send, receive)
124+
125+
body = b""
126+
initial_message: Optional[Message] = None
127+
128+
async def _send_error_response(status: int, message: str) -> None:
129+
"""Send an error response with the given status and message."""
130+
assert initial_message, "Initial message not set"
131+
error_body = json.dumps({"message": message}).encode("utf-8")
132+
headers = MutableHeaders(scope=initial_message)
133+
headers["content-length"] = str(len(error_body))
134+
initial_message["status"] = status
135+
await send(initial_message)
136+
await send(
170137
{
171138
"type": "http.response.body",
172-
"body": json.dumps(body_json).encode("utf-8"),
139+
"body": error_body,
173140
"more_body": False,
174141
}
175142
)
176-
return await self._send_error_response(404, "Not found")
177-
178-
async def _send_error_response(self, status: int, message: str) -> None:
179-
"""Send an error response with the given status and message."""
180-
assert self.initial_message, "Initial message not set"
181-
error_body = json.dumps({"message": message}).encode("utf-8")
182-
headers = MutableHeaders(scope=self.initial_message)
183-
headers["content-length"] = str(len(error_body))
184-
self.initial_message["status"] = status
185-
await self.send(self.initial_message)
186-
await self.send(
187-
{
188-
"type": "http.response.body",
189-
"body": error_body,
190-
"more_body": False,
191-
}
192-
)
143+
144+
async def buffered_send(message: Message) -> None:
145+
"""Process a response message and apply filtering if needed."""
146+
nonlocal body
147+
nonlocal initial_message
148+
149+
if message["type"] == "http.response.start":
150+
initial_message = message
151+
return
152+
153+
assert initial_message, "Initial message not set"
154+
155+
body += message["body"]
156+
if message.get("more_body"):
157+
return
158+
159+
try:
160+
body_json = json.loads(body)
161+
except json.JSONDecodeError:
162+
logger.warning("Failed to parse response body as JSON")
163+
await _send_error_response(502, "Not found")
164+
return
165+
166+
logger.debug(
167+
"Applying %s filter to %s", self.cql2_filter.to_text(), body_json
168+
)
169+
if self.cql2_filter.matches(body_json):
170+
await send(initial_message)
171+
return await send(
172+
{
173+
"type": "http.response.body",
174+
"body": json.dumps(body_json).encode("utf-8"),
175+
"more_body": False,
176+
}
177+
)
178+
return await _send_error_response(404, "Not found")
179+
180+
return await self.app(scope, receive, buffered_send)

0 commit comments

Comments
 (0)