Skip to content

Commit 46db50c

Browse files
committed
Fixed ThrottlerModule configuration
1 parent f11ad14 commit 46db50c

File tree

4 files changed

+72
-16
lines changed

4 files changed

+72
-16
lines changed

ellar_throttler/module.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,45 @@
1+
import typing as t
2+
13
from ellar.common import Module
2-
from ellar.core import ModuleBase
4+
from ellar.core import IExecutionContext
5+
from ellar.core.modules import DynamicModule, IModuleConfigure, ModuleBase
36
from ellar.di import ProviderConfig
47

58
from ellar_throttler.interfaces import IThrottlerStorage
6-
from ellar_throttler.throttler_service import CacheThrottlerStorageService
9+
from ellar_throttler.throttler_module_options import ThrottlerModuleOptions
10+
from ellar_throttler.throttler_service import ThrottlerStorageService
11+
712

13+
@Module()
14+
class ThrottlerModule(ModuleBase, IModuleConfigure):
15+
@classmethod
16+
def module_configure(
17+
cls,
18+
ttl: int,
19+
limit: int,
20+
storage: t.Union[t.Type, t.Any] = None,
21+
skip_if: t.Callable[[IExecutionContext], bool] = None,
22+
) -> DynamicModule:
23+
if storage and isinstance(storage, IThrottlerStorage):
24+
_provider = ProviderConfig(IThrottlerStorage, use_value=storage)
25+
elif storage:
26+
_provider = ProviderConfig(IThrottlerStorage, use_class=storage)
27+
else:
28+
_provider = ProviderConfig(
29+
IThrottlerStorage, use_class=ThrottlerStorageService
30+
)
831

9-
@Module(
10-
providers=[
11-
ProviderConfig(IThrottlerStorage, use_class=CacheThrottlerStorageService)
12-
]
13-
)
14-
class ThrottlerModule(ModuleBase):
15-
pass
32+
return DynamicModule(
33+
cls,
34+
providers=[
35+
_provider,
36+
ProviderConfig(
37+
ThrottlerModuleOptions,
38+
use_value=ThrottlerModuleOptions(
39+
limit=limit,
40+
ttl=ttl,
41+
skip_if=skip_if, # type:ignore[arg-type]
42+
),
43+
),
44+
],
45+
)

ellar_throttler/throttler_guard.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,35 @@
1010
from .constants import THROTTLER_LIMIT, THROTTLER_SKIP, THROTTLER_TTL
1111
from .exception import ThrottledException
1212
from .interfaces import IThrottlerStorage
13+
from .throttler_module_options import ThrottlerModuleOptions
1314

1415

1516
@injectable()
1617
class ThrottlerGuard(GuardCanActivate):
1718
header_prefix = "X-RateLimit"
1819

1920
def __init__(
20-
self, storage_service: IThrottlerStorage, reflector: Reflector
21+
self,
22+
storage_service: IThrottlerStorage,
23+
reflector: Reflector,
24+
options: ThrottlerModuleOptions,
2125
) -> None:
2226
self.storage_service = storage_service
2327
self.reflector = reflector
28+
self.options = options
2429

2530
async def can_activate(self, context: IExecutionContext) -> bool:
2631
handler = context.get_handler()
2732
class_ref = context.get_class()
2833

2934
# Return early if the current route should be skipped.
3035
# or self.options.skipIf?.(context)
31-
if self.reflector.get_all_and_override(THROTTLER_SKIP, handler, class_ref):
36+
37+
if (
38+
self.reflector.get_all_and_override(THROTTLER_SKIP, handler, class_ref)
39+
or self.options.skip_if
40+
and self.options.skip_if(context)
41+
):
3242
return True
3343

3444
# Return early when we have no limit or ttl data.
@@ -40,8 +50,8 @@ async def can_activate(self, context: IExecutionContext) -> bool:
4050
)
4151

4252
# Check if specific limits are set at class or route level, otherwise use global options.
43-
limit = route_or_class_limit or 60 # or this.options.limit
44-
ttl = route_or_class_ttl or 60 # or this.options.ttl
53+
limit = route_or_class_limit or self.options.limit
54+
ttl = route_or_class_ttl or self.options.ttl
4555
return await self.handle_request(context, limit, ttl)
4656

4757
@classmethod
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import dataclasses
2+
import typing as t
3+
4+
from ellar.core import IExecutionContext
5+
6+
7+
@dataclasses.dataclass
8+
class ThrottlerModuleOptions:
9+
# The amount of requests that are allowed within the ttl's time window.
10+
limit: int
11+
12+
# The amount of seconds of how many requests are allowed within this time.
13+
ttl: int
14+
15+
# A factory method to determine if throttling should be skipped.
16+
# This can be based on the incoming context.
17+
skip_if: t.Callable[[IExecutionContext], bool] = None # type:ignore[assignment]

ellar_throttler/throttler_service.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,8 @@ def storage(self) -> t.Any:
8080
return self._cache_service.get_backend()
8181

8282
async def get_expiration_time(self, key: str) -> int:
83-
result = await self._cache_service.get_async(f"{key}-ttl") or 60
84-
assert result, "Value can not be none"
85-
return math.floor(result - time.time())
83+
result = await self._cache_service.get_async(f"{key}-ttl")
84+
return math.floor(result - time.time()) if result else -1
8685

8786
async def increment(self, key: str, ttl: int) -> ThrottlerStorageRecord:
8887
if not await self._cache_service.has_key_async(key):

0 commit comments

Comments
 (0)