
Commit 599c403

David Robertson and clokep authored
Allow rate limiters to passively record actions they cannot limit (#13253)
Co-authored-by: Patrick Cloke <[email protected]>
1 parent 0eb7e69 commit 599c403

File tree

3 files changed: +157 -12 lines changed

changelog.d/13253.misc
synapse/api/ratelimiting.py
tests/api/test_ratelimiting.py

changelog.d/13253.misc

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Preparatory work for a per-room rate limiter on joins.

synapse/api/ratelimiting.py

Lines changed: 82 additions & 12 deletions
@@ -27,6 +27,33 @@ class Ratelimiter:
     """
     Ratelimit actions marked by arbitrary keys.
 
+    (Note that the source code speaks of "actions" and "burst_count" rather than
+    "tokens" and a "bucket_size".)
+
+    This is a "leaky bucket as a meter". For each key to be tracked there is a bucket
+    containing some number 0 <= T <= `burst_count` of tokens corresponding to previously
+    permitted requests for that key. Each bucket starts empty, and gradually leaks
+    tokens at a rate of `rate_hz`.
+
+    Upon an incoming request, we must determine:
+    - the key that this request falls under (which bucket to inspect), and
+    - the cost C of this request in tokens.
+    Then, if there is room in the bucket for C tokens (T + C <= `burst_count`),
+    the request is permitted and C tokens are added to the bucket.
+    Otherwise the request is denied, and the bucket continues to hold T tokens.
+
+    This means that the limiter enforces an average request frequency of `rate_hz`,
+    while accumulating a buffer of up to `burst_count` requests which can be consumed
+    instantaneously.
+
+    The tricky bit is the leaking. We do not want to have a periodic process which
+    leaks every bucket! Instead, we track
+    - the time point when the bucket was last completely empty, and
+    - how many tokens have been added to the bucket since then.
+    Then for each incoming request, we can calculate how many tokens have leaked
+    since this time point, and use that to decide if we should accept or reject the
+    request.
+
     Args:
         clock: A homeserver clock, for retrieving the current time
         rate_hz: The long term number of actions that can be performed in a second.
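
A minimal, standalone sketch of the leak calculation this new docstring describes. The function and parameter names below are illustrative only, not part of Synapse's API; the real Ratelimiter stores its state differently.

def would_permit(
    tokens_added: float,
    time_last_empty_s: float,
    time_now_s: float,
    rate_hz: float,
    burst_count: float,
    cost: float = 1.0,
) -> bool:
    """Decide whether a request of the given cost fits in the bucket."""
    # Tokens leak continuously at rate_hz since the bucket was last empty...
    leaked = (time_now_s - time_last_empty_s) * rate_hz
    # ...so the current fill level T is whatever was added minus what has
    # leaked, clamped at zero.
    current = max(0.0, tokens_added - leaked)
    # Permit the request only if `cost` more tokens still fit under burst_count.
    return current + cost <= burst_count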
@@ -41,14 +68,30 @@ def __init__(
         self.burst_count = burst_count
         self.store = store
 
-        # A ordered dictionary keeping track of actions, when they were last
-        # performed and how often. Each entry is a mapping from a key of arbitrary type
-        # to a tuple representing:
-        #   * How many times an action has occurred since a point in time
-        #   * The point in time
-        #   * The rate_hz of this particular entry. This can vary per request
+        # An ordered dictionary representing the token buckets tracked by this rate
+        # limiter. Each entry maps a key of arbitrary type to a tuple representing:
+        #   * The number of tokens currently in the bucket,
+        #   * The time point when the bucket was last completely empty, and
+        #   * The rate_hz (leak rate) of this particular bucket.
         self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
 
+    def _get_key(
+        self, requester: Optional[Requester], key: Optional[Hashable]
+    ) -> Hashable:
+        """Use the requester's MXID as a fallback key if no key is provided."""
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+        return key
+
+    def _get_action_counts(
+        self, key: Hashable, time_now_s: float
+    ) -> Tuple[float, float, float]:
+        """Retrieve the action counts, with a fallback representing an empty bucket."""
+        return self.actions.get(key, (0.0, time_now_s, 0.0))
+
     async def can_do_action(
         self,
         requester: Optional[Requester],
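
As a hedged illustration of the fallback rule `_get_key` documents (no explicit key means the requester's MXID becomes the bucket key), restated as a standalone function: `fallback_key` and its string-based `user_id` parameter are illustrative, not Synapse code.

from typing import Hashable, Optional

def fallback_key(user_id: Optional[str], key: Optional[Hashable]) -> Hashable:
    # An explicit key always wins.
    if key is not None:
        return key
    # Otherwise fall back to the requester's user ID, if there is one.
    if user_id is None:
        raise ValueError("Must supply at least one of `requester` or `key`")
    return user_id

assert fallback_key("@alice:example.com", None) == "@alice:example.com"
assert fallback_key(None, ("join", "!room:example.com")) == ("join", "!room:example.com")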
@@ -88,11 +131,7 @@ async def can_do_action(
                 * The reactor timestamp for when the action can be performed next.
                     -1 if rate_hz is less than or equal to zero
         """
-        if key is None:
-            if not requester:
-                raise ValueError("Must supply at least one of `requester` or `key`")
-
-            key = requester.user.to_string()
+        key = self._get_key(requester, key)
 
         if requester:
             # Disable rate limiting of users belonging to any AS that is configured
@@ -121,7 +160,7 @@ async def can_do_action(
         self._prune_message_counts(time_now_s)
 
         # Check if there is an existing count entry for this key
-        action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+        action_count, time_start, _ = self._get_action_counts(key, time_now_s)
 
         # Check whether performing another action is allowed
         time_delta = time_now_s - time_start
@@ -164,6 +203,37 @@ async def can_do_action(
 
         return allowed, time_allowed
 
+    def record_action(
+        self,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
+        n_actions: int = 1,
+        _time_now_s: Optional[float] = None,
+    ) -> None:
+        """Record that an action(s) took place, even if they violate the rate limit.
+
+        This is useful for tracking the frequency of events that happen across
+        federation which we still want to impose local rate limits on. For instance, if
+        we are alice.com monitoring a particular room, we cannot prevent bob.com
+        from joining users to that room. However, we can track the number of recent
+        joins in the room and refuse to serve new joins ourselves if there have been too
+        many in the room across both homeservers.
+
+        Args:
+            requester: The requester that is doing the action, if any.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
+            n_actions: The number of times the user wants to do this action. If the user
+                cannot do all of the actions, the user's action count is not incremented
+                at all.
+            _time_now_s: The current time. Optional, defaults to the current time according
+                to self.clock. Only used by tests.
+        """
+        key = self._get_key(requester, key)
+        time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
+        action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
+        self.actions[key] = (action_count + n_actions, time_start, rate_hz)
+
     def _prune_message_counts(self, time_now_s: float) -> None:
         """Remove message count entries that have not exceeded their defined
         rate_hz limit
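
To make the usage pattern in `record_action`'s docstring concrete, here is a hedged, self-contained toy that mirrors (but is not) Synapse's Ratelimiter: joins performed by a remote homeserver, which we cannot block, are recorded passively, and local joins are gated on the same bucket.

class ToyJoinLimiter:
    """Leaky-bucket-as-a-meter over a single key, per the docstring above."""

    def __init__(self, rate_hz: float, burst_count: float) -> None:
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        self.added = 0.0        # tokens added since the bucket was last empty
        self.last_empty = 0.0   # time (in seconds) the bucket was last empty

    def _level(self, now: float) -> float:
        # Current fill level: tokens added minus tokens leaked since last empty.
        level = self.added - (now - self.last_empty) * self.rate_hz
        if level <= 0.0:
            # The bucket has drained completely: reset the tracking point.
            self.added = 0.0
            self.last_empty = now
            return 0.0
        return level

    def record_action(self, now: float, n_actions: int = 1) -> None:
        # Passively count joins we could not block (e.g. performed by a remote
        # homeserver), even if that overfills the bucket.
        self._level(now)
        self.added += n_actions

    def can_do_action(self, now: float) -> bool:
        # Gate a join we *can* block (a local one) on the shared bucket.
        if self._level(now) + 1 <= self.burst_count:
            self.added += 1
            return True
        return False

limiter = ToyJoinLimiter(rate_hz=0.1, burst_count=3)
limiter.record_action(now=0.0, n_actions=3)  # three joins observed over federation
assert not limiter.can_do_action(now=0.0)    # a local join is refused immediately
assert limiter.can_do_action(now=10.0)       # one token has leaked after 10 seconds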

tests/api/test_ratelimiting.py

Lines changed: 74 additions & 0 deletions
@@ -314,3 +314,77 @@ def consume_at(time: float) -> bool:
 
         # Check that we get rate limited after using that token.
         self.assertFalse(consume_at(11.1))
+
+    def test_record_action_which_doesnt_fill_bucket(self) -> None:
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+        )
+
+        # Observe two actions, leaving room in the bucket for one more.
+        limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
+
+        # We should be able to take a new action now.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertTrue(success)
+
+        # ... but not two.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertFalse(success)
+
+    def test_record_action_which_fills_bucket(self) -> None:
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+        )
+
+        # Observe three actions, filling up the bucket.
+        limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)
+
+        # We should be unable to take a new action now.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertFalse(success)
+
+        # If we wait 10 seconds to leak a token, we should be able to take one action...
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+        )
+        self.assertTrue(success)
+
+        # ... but not two.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+        )
+        self.assertFalse(success)
+
+    def test_record_action_which_overfills_bucket(self) -> None:
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+        )
+
+        # Observe four actions, exceeding the bucket.
+        limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)
+
+        # We should be prevented from taking a new action now.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertFalse(success)
+
+        # If we wait 10 seconds to leak a token, we should be unable to take an action
+        # because the bucket is still full.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+        )
+        self.assertFalse(success)
+
+        # But after another 10 seconds we leak a second token, giving us room for
+        # action.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
+        )
+        self.assertTrue(success)
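
The timings in these tests all follow from the same arithmetic: at rate_hz=0.1 one token leaks every 10 seconds, and burst_count=3 is the bucket size. A quick standalone check of the expectations above (the helper below is illustrative, not part of the test suite):

rate_hz, burst_count = 0.1, 3

def tokens_after(added: float, seconds: float) -> float:
    # Fill level after `seconds` of leaking, starting from `added` tokens.
    return max(0.0, added - seconds * rate_hz)

assert tokens_after(3, 0) + 1 > burst_count    # full bucket: refused at t=0
assert tokens_after(3, 10) + 1 <= burst_count  # one token leaked: allowed at t=10
assert tokens_after(4, 10) + 1 > burst_count   # overfilled bucket is still full at t=10
assert tokens_after(4, 20) + 1 <= burst_count  # room again at t=20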
