|
| 1 | +import hashlib |
| 2 | +import typing as t |
| 3 | + |
| 4 | +from ellar.core import GuardCanActivate, IExecutionContext, Response |
| 5 | +from ellar.core.connection import HTTPConnection |
| 6 | +from ellar.di import injectable |
| 7 | +from ellar.helper import get_name |
| 8 | +from ellar.services import Reflector |
| 9 | + |
| 10 | +from .constants import THROTTLER_LIMIT, THROTTLER_SKIP, THROTTLER_TTL |
| 11 | +from .exception import ThrottledException |
| 12 | +from .interfaces import IThrottlerStorage |
| 13 | + |
| 14 | + |
@injectable()
class ThrottlerGuard(GuardCanActivate):
    """Rate-limiting guard backed by an ``IThrottlerStorage`` service.

    On each request it resolves the effective ``limit``/``ttl`` from route- or
    class-level metadata (``THROTTLER_LIMIT`` / ``THROTTLER_TTL``), increments a
    per-client hit counter in storage, and either raises ``ThrottledException``
    or annotates the response with ``X-RateLimit-*`` headers.
    """

    # Prefix for the rate-limit response headers (-Limit, -Remaining, -Reset).
    header_prefix = "X-RateLimit"
    # Fallback values used when no THROTTLER_LIMIT / THROTTLER_TTL metadata is
    # present on the handler or its class. Subclasses may override these.
    default_limit: int = 60
    default_ttl: int = 60

    def __init__(
        self, storage_service: IThrottlerStorage, reflector: Reflector
    ) -> None:
        # Storage backend that tracks hit counts per key with expiry.
        self.storage_service = storage_service
        # Reflector used to read throttle metadata off handlers/classes.
        self.reflector = reflector

    async def can_activate(self, context: IExecutionContext) -> bool:
        """Entry point called by the framework; returns True to allow the request.

        Raises ``ThrottledException`` (via ``handle_request``) when the client
        has exceeded its limit.
        """
        handler = context.get_handler()
        class_ref = context.get_class()

        # Routes/classes decorated with THROTTLER_SKIP bypass throttling entirely.
        if self.reflector.get_all_and_override(THROTTLER_SKIP, handler, class_ref):
            return True

        # Route-level metadata overrides class-level metadata.
        route_or_class_limit = self.reflector.get_all_and_override(
            THROTTLER_LIMIT, handler, class_ref
        )
        route_or_class_ttl = self.reflector.get_all_and_override(
            THROTTLER_TTL, handler, class_ref
        )

        # Fall back to the class-level defaults when no metadata is set.
        limit = route_or_class_limit or self.default_limit
        ttl = route_or_class_ttl or self.default_ttl
        return await self.handle_request(context, limit, ttl)

    @classmethod
    def get_request_response(
        cls, context: IExecutionContext
    ) -> t.Tuple[HTTPConnection, Response]:
        """Return the (connection, response) pair for the current HTTP context."""
        connection_host = context.switch_to_http_connection()
        return connection_host.get_client(), connection_host.get_response()

    def get_tracker(self, connection: HTTPConnection) -> str:
        """Return the identifier used to track a client — the client host/IP.

        Override to throttle on a different dimension (e.g. user id, API key).
        """
        # NOTE(review): `assert` is stripped under `python -O`; assumes the
        # connection always carries client info — confirm for the ASGI servers
        # in use before relying on this check.
        assert connection.client
        return connection.client.host

    def generate_key(self, context: IExecutionContext, suffix: str) -> str:
        """Build a storage key unique to (controller, handler, tracker)."""
        prefix = f"{get_name(context.get_class())}-{get_name(context.get_handler())}"
        # md5 is acceptable here: the digest is a cache/bucket key, not a
        # security-sensitive hash.
        return hashlib.md5(f"{prefix}-{suffix}".encode("utf8")).hexdigest()

    async def handle_request(
        self, context: IExecutionContext, limit: int, ttl: int
    ) -> bool:
        """Record a hit for the caller and enforce ``limit`` within ``ttl`` seconds.

        :param limit: maximum number of hits allowed within the window.
        :param ttl: window length in seconds.
        :raises ThrottledException: when the caller exceeds ``limit``.
        :return: True when the request is allowed.
        """
        connection, response = self.get_request_response(context)
        # TODO: Return early if the current user agent should be ignored.

        tracker = self.get_tracker(connection)
        key = self.generate_key(context, tracker)
        result = await self.storage_service.increment(key, ttl)

        # Reject once the counter passes the limit; `wait` tells the client
        # how long until the window resets.
        if result.total_hits > limit:
            raise ThrottledException(wait=result.time_to_expire)

        # Advertise the rate-limit state on every successful response.
        response.headers[f"{self.header_prefix}-Limit"] = str(limit)
        response.headers[f"{self.header_prefix}-Remaining"] = str(
            max(0, limit - result.total_hits)
        )
        response.headers[f"{self.header_prefix}-Reset"] = str(result.time_to_expire)

        return True
0 commit comments