Skip to content

Commit f11ad14

Browse files
committed
first commit
1 parent 8995643 commit f11ad14

File tree

13 files changed

+348
-27
lines changed

13 files changed

+348
-27
lines changed

.github/workflows/publish.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,3 @@ jobs:
2323
FLIT_USERNAME: ${{ secrets.FLIT_USERNAME }}
2424
FLIT_PASSWORD: ${{ secrets.FLIT_PASSWORD }}
2525
run: flit publish
26-
- name: Deploy Documentation
27-
run: make doc-deploy

.github/workflows/test_full.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ jobs:
3838
- name: Install Dependencies
3939
run: flit install --symlink
4040
- name: Black
41-
run: black --check ellar tests
41+
run: black --check ellar_throttler tests
4242
- name: isort
43-
run: isort --check ellar tests
43+
run: isort --check ellar_throttler tests
4444
- name: Flake8
45-
run: flake8 ellar tests
45+
run: flake8 ellar_throttler tests
4646
- name: mypy
47-
run: mypy ellar
47+
run: mypy ellar_throttler

LICENSE

Lines changed: 0 additions & 21 deletions
This file was deleted.

ellar_throttler/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
THROTTLER_LIMIT = "THROTTLER:LIMIT"
2+
THROTTLER_TTL = "THROTTLER:TTL"
3+
THROTTLER_OPTIONS = "THROTTLER:MODULE_OPTIONS"
4+
THROTTLER_SKIP = "THROTTLER:SKIP"

ellar_throttler/decorators.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import typing as t
2+
3+
from ellar.reflect import reflect
4+
5+
from .constants import THROTTLER_LIMIT, THROTTLER_SKIP, THROTTLER_TTL
6+
7+
8+
def _set_throttler_metadata(target: t.Callable, limit: int, ttl: int) -> None:
9+
reflect.define_metadata(THROTTLER_TTL, ttl, target)
10+
reflect.define_metadata(THROTTLER_LIMIT, limit, target)
11+
12+
13+
def throttle(*, limit: int = 20, ttl: int = 60) -> t.Callable:
14+
"""
15+
Adds metadata to the target which will be handled by the ThrottlerGuard to
16+
handle incoming requests based on the given metadata.
17+
18+
Usage:
19+
Use on controllers or individual route functions
20+
21+
@throttle(limit=20, ttl=300)
22+
class ControllerSample:
23+
@throttle(limit=20, ttl=300)
24+
async def index(self):
25+
...
26+
27+
:param limit: Guard Type or Instance
28+
:param ttl: Guard Type or Instance
29+
:return: Callable
30+
"""
31+
32+
def decorator(func: t.Callable) -> t.Callable:
33+
_set_throttler_metadata(func, limit=limit, ttl=ttl)
34+
return func
35+
36+
return decorator
37+
38+
39+
def skip_throttle(skip: bool = True) -> t.Callable:
40+
"""
41+
Adds metadata to the target which will be handled by the ThrottlerGuard
42+
whether to skip throttling for this context.
43+
44+
Usage:
45+
Use on controllers or individual route functions
46+
47+
@throttle(limit=20, ttl=300)
48+
class ControllerSample:
49+
@skip_throttle()
50+
async def index(self):
51+
...
52+
53+
@skip_throttle(false)
54+
async def create(self):
55+
...
56+
57+
:param skip:
58+
:return:
59+
"""
60+
61+
def decorator(func: t.Callable) -> t.Callable:
62+
reflect.define_metadata(THROTTLER_SKIP, skip, func)
63+
return func
64+
65+
return decorator

ellar_throttler/exception.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import math
2+
import typing as t
3+
4+
from ellar.core.exceptions import APIException
5+
from starlette import status
6+
7+
8+
class ThrottledException(APIException):
9+
status_code = status.HTTP_429_TOO_MANY_REQUESTS
10+
default_detail = "Request was throttled."
11+
extra_detail_singular = "Expected available in {wait} second."
12+
extra_detail_plural = "Expected available in {wait} seconds."
13+
14+
def __init__(self, wait: float = None, detail: t.Any = None) -> None:
15+
if detail is None:
16+
detail = self.default_detail
17+
if wait is not None:
18+
wait = math.ceil(wait)
19+
detail = " ".join(
20+
(
21+
detail,
22+
self.extra_detail_singular.format(wait=wait)
23+
if wait < 1
24+
else self.extra_detail_plural.format(wait=wait),
25+
)
26+
)
27+
headers = {"Retry-After": "%d" % float(wait or 0.0)}
28+
super().__init__(detail=detail, headers=headers)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .throttler_storage import IThrottlerStorage
2+
3+
__all__ = [
4+
"IThrottlerStorage",
5+
]
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import typing as t
2+
from abc import ABC, abstractmethod
3+
4+
from ..throttler_storage_options import ThrottlerStorageOption
5+
from ..throttler_storage_record import ThrottlerStorageRecord
6+
7+
8+
class IThrottlerStorage(ABC):
9+
@property
10+
@abstractmethod
11+
def storage(self) -> t.Dict[str, ThrottlerStorageOption]:
12+
"""
13+
The internal storage with all the request records.
14+
The key is a hashed key based on the current context and IP.
15+
:return:
16+
"""
17+
18+
@abstractmethod
19+
async def increment(self, key: str, ttl: int) -> ThrottlerStorageRecord:
20+
"""
21+
Increment the amount of requests for a given record. The record will
22+
automatically be removed from the storage once its TTL has been reached.
23+
:param key:
24+
:param ttl:
25+
:return:
26+
"""

ellar_throttler/module.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from ellar.common import Module
2+
from ellar.core import ModuleBase
3+
from ellar.di import ProviderConfig
4+
5+
from ellar_throttler.interfaces import IThrottlerStorage
6+
from ellar_throttler.throttler_service import CacheThrottlerStorageService
7+
8+
9+
@Module(
10+
providers=[
11+
ProviderConfig(IThrottlerStorage, use_class=CacheThrottlerStorageService)
12+
]
13+
)
14+
class ThrottlerModule(ModuleBase):
15+
pass

ellar_throttler/throttler_guard.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import hashlib
2+
import typing as t
3+
4+
from ellar.core import GuardCanActivate, IExecutionContext, Response
5+
from ellar.core.connection import HTTPConnection
6+
from ellar.di import injectable
7+
from ellar.helper import get_name
8+
from ellar.services import Reflector
9+
10+
from .constants import THROTTLER_LIMIT, THROTTLER_SKIP, THROTTLER_TTL
11+
from .exception import ThrottledException
12+
from .interfaces import IThrottlerStorage
13+
14+
15+
@injectable()
16+
class ThrottlerGuard(GuardCanActivate):
17+
header_prefix = "X-RateLimit"
18+
19+
def __init__(
20+
self, storage_service: IThrottlerStorage, reflector: Reflector
21+
) -> None:
22+
self.storage_service = storage_service
23+
self.reflector = reflector
24+
25+
async def can_activate(self, context: IExecutionContext) -> bool:
26+
handler = context.get_handler()
27+
class_ref = context.get_class()
28+
29+
# Return early if the current route should be skipped.
30+
# or self.options.skipIf?.(context)
31+
if self.reflector.get_all_and_override(THROTTLER_SKIP, handler, class_ref):
32+
return True
33+
34+
# Return early when we have no limit or ttl data.
35+
route_or_class_limit = self.reflector.get_all_and_override(
36+
THROTTLER_LIMIT, handler, class_ref
37+
)
38+
route_or_class_ttl = self.reflector.get_all_and_override(
39+
THROTTLER_TTL, handler, class_ref
40+
)
41+
42+
# Check if specific limits are set at class or route level, otherwise use global options.
43+
limit = route_or_class_limit or 60 # or this.options.limit
44+
ttl = route_or_class_ttl or 60 # or this.options.ttl
45+
return await self.handle_request(context, limit, ttl)
46+
47+
@classmethod
48+
def get_request_response(
49+
cls, context: IExecutionContext
50+
) -> t.Tuple[HTTPConnection, Response]:
51+
connection_host = context.switch_to_http_connection()
52+
return connection_host.get_client(), connection_host.get_response()
53+
54+
def get_tracker(self, connection: HTTPConnection) -> str:
55+
assert connection.client
56+
return connection.client.host
57+
58+
def generate_key(self, context: IExecutionContext, suffix: str) -> str:
59+
prefix = f"{get_name(context.get_class())}-{get_name(context.get_handler())}"
60+
return hashlib.md5(f"{prefix}-{suffix}".encode("utf8")).hexdigest()
61+
62+
async def handle_request(
63+
self, context: IExecutionContext, limit: int, ttl: int
64+
) -> bool:
65+
connection, response = self.get_request_response(context)
66+
# TODO: Return early if the current user agent should be ignored.
67+
68+
tracker = self.get_tracker(connection)
69+
key = self.generate_key(context, tracker)
70+
result = await self.storage_service.increment(key, ttl)
71+
72+
# Throw an error when the user reached their limit.
73+
if result.total_hits > limit:
74+
raise ThrottledException(wait=result.time_to_expire)
75+
76+
response.headers[f"{self.header_prefix}-Limit"] = str(limit)
77+
response.headers[f"{self.header_prefix}-Remaining"] = str(
78+
max(0, limit - result.total_hits)
79+
)
80+
response.headers[f"{self.header_prefix}-Reset"] = str(result.time_to_expire)
81+
82+
return True

0 commit comments

Comments
 (0)