Skip to content

Commit 5a42885

Browse files
committed
Feature: Add TokenBucket policy
1 parent 2cefb7a commit 5a42885

File tree

3 files changed

+157
-7
lines changed

3 files changed

+157
-7
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ async def _():
204204
- [x] 固定窗口
205205
- [x] 滑动窗口
206206
- [x] 漏桶
207-
- [ ] 令牌桶
207+
- [x] 令牌桶
208208
- [x] reject 依赖注入回调函数(试验性支持)
209209
- [ ] 重置用量
210210
- [x] 本地持久化状态

nonebot_plugin_limiter/cooldown.py

Lines changed: 124 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def _increase_action():
324324
class LeakyBucketUsage:
325325
last_update_time: datetime
326326
capacity: int
327-
available: int
327+
used: int
328328

329329

330330
_LeakyBucketCooldownDict: dict[str, dict[str, LeakyBucketUsage]] = {}
@@ -417,13 +417,13 @@ async def _limiter_dependency(
417417

418418
# Update bucket available capacity
419419
leaked_size = int((now - usage.last_update_time).total_seconds()) * leak_speed
420-
usage.available = max(leaked_size + usage.available, usage.capacity)
420+
usage.used = max(usage.used - leaked_size, 0)
421421
usage.last_update_time = now
422422

423423
def _increase_action():
424-
usage.available -= pour_size
424+
usage.used += pour_size
425425

426-
if usage.available >= pour_size:
426+
if usage.used < pour_size:
427427
if set_increaser:
428428
inject_increaser(state, _increase_action)
429429
else:
@@ -436,4 +436,123 @@ def _increase_action():
436436

437437
return Depends(_limiter_dependency)
438438

439-
# endregion
439+
# endregion
440+
441+
# region: TokenBucket
442+
@dataclass
443+
class TokenBucketUsage:
444+
last_update_time: datetime
445+
capacity: int
446+
available: int
447+
448+
449+
_TokenBucketCooldownDict: dict[str, dict[str, TokenBucketUsage]] = {}
450+
451+
452+
def TokenBucketCooldown(
453+
entity: CooldownEntity | _DependentCallable[str],
454+
capacity: int,
455+
add_speed: int,
456+
*,
457+
consume_size: int | _DependentCallable[int] = 10,
458+
reject: None | SupportMsgType | _DependentCallable[Any] = None,
459+
set_increaser: bool = False,
460+
name: None | str = None,
461+
):
462+
"""
463+
**令牌桶算法限制器**
464+
465+
用于控制给定资源内能处理的最大请求量。
466+
467+
参数:
468+
entity (CooldownEntity | _DependentCallable[str]):
469+
设置需要进行速率限制的对象。
470+
- 可传入 `CooldownEntity` 对象,如 `UserScope`, `GroupScope` 等。
471+
- 可传入返回值为 `str` 的函数,自定义限制对象的**唯一 ID**,支持依赖注入。
472+
473+
capacity (int):
474+
设置令牌桶的最大容量。
475+
476+
add_speed (int):
477+
设置令牌桶每秒添加 token 的数量。
478+
479+
consume_size (int, _DependentCallable[int]):
480+
可选,设置每次添加任务时需要消耗的 token 数量。默认为 10。
481+
- 可传入返回值为 `int` 的函数,自定义消耗数量,支持依赖注入。
482+
483+
reject (None | SupportMsgType | _DependentCallable):
484+
可选,当超出限制时的响应行为。默认为 `None`。
485+
- 若为 `str` 或消息对象,将作为限制使用时的提示消息发送给用户。
486+
- 若为依赖注入函数,将会在拒绝时进行调用。
487+
488+
set_increaser (bool):
489+
可选,是否获取限制器的增加器。默认为 False。
490+
- 当启用该选项时,限制器默认的自增将会关闭,需要在事件处理时依赖获取 Increaser 并手动操作增加。
491+
492+
name (None | str):
493+
可选,设置当前限制器的使用统计集合。默认为 `None` ,即私有集合。
494+
- 当传入 `str` ,将创建或加入一个同名公共集合,可用于与其他命令的限制器共享使用统计。
495+
496+
示例:
497+
```python
498+
from nonebot.permission import SUPERUSER
499+
from nonebot_plugin_limiter.entity import UserScope
500+
501+
# 收到一个请求,处理消化这个请求需要 10 个 token ,最多同时处理两个请求
502+
@matcher.handle(parameterless=[
503+
TokenBucketCooldown(
504+
UserScope(permission=SUPERUSER),
505+
20,
506+
1,
507+
consume_size = 10,
508+
reject="操作过于频繁,请稍后再试。"
509+
)
510+
])
511+
async def handler(...): ...
512+
```
513+
"""
514+
515+
if isinstance(name, str):
516+
if name not in _TokenBucketCooldownDict.keys():
517+
_TokenBucketCooldownDict[name] = {}
518+
bucket = _TokenBucketCooldownDict[name]
519+
else:
520+
bucket: dict[str, TokenBucketUsage] = {}
521+
522+
async def _limiter_dependency(
523+
state: T_State,
524+
entity_id: str = Depends(_entity_id_dep_wrapper(entity)),
525+
consume_size: int = Depends(_limit_dep_wrapper(consume_size)),
526+
reject_cb: Callable[..., Awaitable[Any]] = Depends(_reject_dep_wrapper(reject))
527+
) -> None:
528+
if entity_id == BYPASS_ENTITY:
529+
return
530+
531+
now = datetime.now(tz=_tz)
532+
533+
if entity_id not in bucket:
534+
bucket[entity_id] = TokenBucketUsage(now, capacity, 0)
535+
usage = bucket[entity_id]
536+
537+
# Update bucket token count
538+
resume_size = int((now - usage.last_update_time).total_seconds()) * add_speed
539+
usage.available = min(resume_size + usage.available, usage.capacity)
540+
usage.last_update_time = now
541+
542+
def _increase_action():
543+
usage.available -= consume_size
544+
545+
if usage.available >= consume_size:
546+
if set_increaser:
547+
inject_increaser(state, _increase_action)
548+
else:
549+
_increase_action()
550+
return # Didn't exceed
551+
552+
# Exceeded
553+
await reject_cb()
554+
raise FinishedException()
555+
556+
return Depends(_limiter_dependency)
557+
558+
# endregion

nonebot_plugin_limiter/persist.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
FixWindowUsage,
1616
LeakyBucketUsage,
1717
SlidingWindowUsage,
18+
TokenBucketUsage,
1819
_FixWindowCooldownDict,
1920
_LeakyBucketCooldownDict,
2021
_SlidingWindowCooldownDict,
22+
_TokenBucketCooldownDict,
2123
_tz,
2224
)
2325

@@ -40,10 +42,17 @@ class SlidingWindowSet(BaseModel):
4042
class LeakyBucketSet(BaseModel):
4143
last_update_time: int
4244
capacity: int
43-
available: int
45+
used: int
4446

4547
leaky_bucket: dict[str, dict[str, LeakyBucketSet]] | None = None
4648

49+
class TokenBucketSet(BaseModel):
50+
last_update_time: int
51+
capacity: int
52+
available: int
53+
54+
token_bucket: dict[str, dict[str, TokenBucketSet]] | None = None
55+
4756
def load_usage_data() -> None:
4857
"""加载本地存储的用量数据"""
4958

@@ -83,6 +92,18 @@ def load_usage_data() -> None:
8392
bucket = _LeakyBucketCooldownDict[name]
8493
for _id, usage in usage_set.items():
8594
bucket[_id] = LeakyBucketUsage(
95+
last_update_time = datetime.fromtimestamp(usage.last_update_time, tz=_tz),
96+
capacity=usage.capacity,
97+
used=usage.used
98+
)
99+
100+
if data.token_bucket is not None:
101+
for name, usage_set in data.token_bucket.items():
102+
if name not in _TokenBucketCooldownDict:
103+
_TokenBucketCooldownDict[name] = {}
104+
bucket = _TokenBucketCooldownDict[name]
105+
for _id, usage in usage_set.items():
106+
bucket[_id] = TokenBucketUsage(
86107
last_update_time = datetime.fromtimestamp(usage.last_update_time, tz=_tz),
87108
capacity=usage.capacity,
88109
available=usage.available
@@ -115,6 +136,16 @@ def save_usage_data() -> None:
115136

116137
for name, usage_set in _LeakyBucketCooldownDict.items():
117138
j["leaky_bucket"][name] = {
139+
_id: {
140+
"last_update_time": int(usage.last_update_time.timestamp()),
141+
"capacity": usage.capacity,
142+
"used": usage.used
143+
}
144+
for _id, usage in usage_set.items()
145+
}
146+
147+
for name, usage_set in _TokenBucketCooldownDict.items():
148+
j["token_bucket"][name] = {
118149
_id: {
119150
"last_update_time": int(usage.last_update_time.timestamp()),
120151
"capacity": usage.capacity,

0 commit comments

Comments
 (0)