77
88import aioredis
99from starlette .requests import Request
10- from starlette .responses import Response , JSONResponse
10+ from starlette .responses import Response , JSONResponse , StreamingResponse
1111from starlette .routing import Match
1212
1313from . import tokens
@@ -150,6 +150,23 @@ async def get_key(self, zone: str, request: Request) -> str:
150150 key = ratelimit_id_key (request )
151151 return f"{ zone } %{ key } "
152152
153+ async def _transform_and_log (self , request : Request , response : Response ) -> Response :
154+ if isinstance (response , StreamingResponse ):
155+ resp_ = Response (status_code = response .status_code , background = response .background , media_type = cast (str , response .media_type ))
156+ resp_ ._headers = response ._headers
157+ body = b""
158+ async for chunk in response .body_iterator :
159+ if not isinstance (chunk , bytes ):
160+ chunk = chunk .encode (response .charset )
161+
162+ body += chunk
163+
164+ resp_ .body = body
165+ response = resp_
166+
167+ await self .app .state .db .put_log (request , response )
168+ return response
169+
153170 async def get_bucket (self , zone : str , key : str , request : Request ) -> BaseLimitBucket :
154171 if key in self ._keys :
155172 return self ._keys [key ]
@@ -209,7 +226,8 @@ async def middleware(self, request: Request, call_next: _CT) -> Response:
209226
210227 resp = await call_next (request )
211228 resp .headers .update (headers )
212- return resp
229+ return await self ._transform_and_log (request , resp )
230+
213231
214232 else :
215233 headers = {
@@ -225,8 +243,7 @@ async def middleware(self, request: Request, call_next: _CT) -> Response:
225243 }
226244 resp = await call_next (request )
227245 resp .headers .update (headers )
228- self .app .loop .create_task (self .app .state .db .put_log (request , resp ))
229- return resp
246+ return await self ._transform_and_log (request , resp )
230247
231248
232249def parse_ratelimit (limit : str ) -> tuple [int , int ]:
0 commit comments