Skip to content

Commit 8c9362e

Browse files
committed
Allow multiple limits per route and limiter fixes.
1 parent bd1b280 commit 8c9362e

File tree

4 files changed

+34
-11
lines changed

4 files changed

+34
-11
lines changed

starlette_plus/core.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,22 @@ def limit(
112112
exempt: ExemptCallable | None = None,
113113
) -> T_LimitDecorator:
114114
def decorator(coro: Callable[..., RouteCoro] | _Route) -> LimitDecorator:
115-
limits: RateLimitData = {"rate": rate, "per": per, "bucket": bucket, "priority": priority, "exempt": exempt}
115+
limits: RateLimitData = {
116+
"rate": rate,
117+
"per": per,
118+
"bucket": bucket,
119+
"priority": priority,
120+
"exempt": exempt,
121+
"is_global": False,
122+
}
116123

117124
if isinstance(coro, _Route):
118125
coro._limits.append(limits)
119126
else:
120-
setattr(coro, "__limits__", [limits])
127+
try:
128+
coro.__limits__.append(limits) # type: ignore
129+
except AttributeError:
130+
setattr(coro, "__limits__", [limits])
121131

122132
return coro
123133

starlette_plus/middleware/ratelimiter.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def __init__(
5151
self.app: ASGIApp = app
5252

5353
self._ignore_local: bool = ignore_localhost
54+
55+
for limit in global_limits:
56+
limit["is_global"] = True
57+
5458
self._global_limits: list[RateLimitData] = global_limits
5559

5660
self._store: Store = Store(redis=redis)
@@ -82,28 +86,35 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8286
route = r
8387
break
8488

85-
route_limits: list[RateLimitData] = sorted(getattr(route, "limits", []), key=lambda x: x["priority"])
89+
route_limits: list[RateLimitData] = sorted(getattr(route, "limits", []), key=lambda x: x.get("priority", 0))
90+
for data in route_limits:
91+
# Ensure routes are never treated as global limits...
92+
data["is_global"] = False
8693

8794
for limit in self._global_limits + route_limits:
8895
is_exempt: bool = False
89-
exempt: ExemptCallable | None = limit["exempt"]
96+
exempt: ExemptCallable | None = limit.get("exempt", None)
9097

9198
if exempt is not None:
9299
is_exempt: bool = await exempt(request)
93100

94101
if is_exempt:
95102
continue
96103

97-
bucket: BucketType = limit["bucket"]
104+
bucket: BucketType = limit.get("bucket", "ip")
98105
if bucket == "ip":
99106
if not request.client and not forwarded:
100107
logger.warning("Could not determine the IP address while ratelimiting! Ignoring...")
101108
return await self.app(scope, receive, send)
102109

103110
# forwarded or client.host will exist at this point...
104-
key: str = forwarded.split(",")[0] if forwarded else request.client.host # type: ignore
111+
ip: str = forwarded.split(",")[0] if forwarded else request.client.host # type: ignore
112+
if not limit.get("is_global", False) and route:
113+
key = f"{route.name}@{route.path}::{limit['rate']}.{limit['per']}.ip"
114+
else:
115+
key = ip
105116

106-
if self._ignore_local and key in ("127.0.0.1", "::1", "localhost", "0.0.0.0"):
117+
if self._ignore_local and ip in ("127.0.0.1", "::1", "localhost", "0.0.0.0"):
107118
return await self.app(scope, receive, send)
108119
else:
109120
key: str | None = await bucket(request)

starlette_plus/types_/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@
1212
See the License for the specific language governing permissions and
1313
limitations under the License.
1414
"""
15+
1516
from .limiter import RateLimitData as RateLimitData

starlette_plus/types_/limiter.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""
1515

1616
from collections.abc import Awaitable, Callable
17-
from typing import Literal, TypeAlias, TypedDict
17+
from typing import Literal, NotRequired, TypeAlias, TypedDict
1818

1919
from starlette.requests import Request
2020
from starlette.responses import Response
@@ -30,6 +30,7 @@
3030
class RateLimitData(TypedDict):
3131
rate: int
3232
per: float
33-
bucket: BucketType
34-
priority: int
35-
exempt: ExemptCallable | None
33+
bucket: NotRequired[BucketType]
34+
priority: NotRequired[int]
35+
exempt: NotRequired[ExemptCallable | None]
36+
is_global: NotRequired[bool]

0 commit comments

Comments
 (0)