Skip to content

Commit 60262a9

Browse files
committed
starlette sucks
1 parent c77424c commit 60262a9

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

mystbin/backend/utils/db.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -939,16 +939,18 @@ async def put_log(
939939
INSERT INTO logs VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
940940
"""
941941
try:
942-
body = await request.body()
942+
body = request._body
943943
except:
944944
body = None
945+
945946
try:
946947
resp = str(response.body)
947948
except AttributeError:
948949
resp = None
949950
await self._do_query(
950951
query,
951952
request.headers.get("X-Forwarded-For", request.client.host),
953+
request.state.user and request.state.user.id,
952954
datetime.datetime.utcnow(),
953955
request.headers.get("CF-RAY"),
954956
request.headers.get("CF-IPCOUNTRY"),

mystbin/backend/utils/ratelimits.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import aioredis
99
from starlette.requests import Request
10-
from starlette.responses import Response, JSONResponse
10+
from starlette.responses import Response, JSONResponse, StreamingResponse
1111
from starlette.routing import Match
1212

1313
from . import tokens
@@ -150,6 +150,23 @@ async def get_key(self, zone: str, request: Request) -> str:
150150
key = ratelimit_id_key(request)
151151
return f"{zone}%{key}"
152152

153+
async def _transform_and_log(self, request: Request, response: Response) -> Response:
154+
if isinstance(response, StreamingResponse):
155+
resp_ = Response(status_code=response.status_code, background=response.background, media_type=cast(str, response.media_type))
156+
resp_._headers = response._headers
157+
body = b""
158+
async for chunk in response.body_iterator:
159+
if not isinstance(chunk, bytes):
160+
chunk = chunk.encode(response.charset)
161+
162+
body += chunk
163+
164+
resp_.body = body
165+
response = resp_
166+
167+
await self.app.state.db.put_log(request, response)
168+
return response
169+
153170
async def get_bucket(self, zone: str, key: str, request: Request) -> BaseLimitBucket:
154171
if key in self._keys:
155172
return self._keys[key]
@@ -209,7 +226,8 @@ async def middleware(self, request: Request, call_next: _CT) -> Response:
209226

210227
resp = await call_next(request)
211228
resp.headers.update(headers)
212-
return resp
229+
return await self._transform_and_log(request, resp)
230+
213231

214232
else:
215233
headers = {
@@ -225,8 +243,7 @@ async def middleware(self, request: Request, call_next: _CT) -> Response:
225243
}
226244
resp = await call_next(request)
227245
resp.headers.update(headers)
228-
self.app.loop.create_task(self.app.state.db.put_log(request, resp))
229-
return resp
246+
return await self._transform_and_log(request, resp)
230247

231248

232249
def parse_ratelimit(limit: str) -> tuple[int, int]:

0 commit comments

Comments
 (0)