@@ -27,6 +27,33 @@ class Ratelimiter:
2727    """ 
2828    Ratelimit actions marked by arbitrary keys. 
2929
30+     (Note that the source code speaks of "actions" and "burst_count" rather than 
31+     "tokens" and a "bucket_size".) 
32+ 
33+     This is a "leaky bucket as a meter". For each key to be tracked there is a bucket 
34+     containing some number 0 <= T <= `burst_count` of tokens corresponding to previously 
35+     permitted requests for that key. Each bucket starts empty, and gradually leaks 
36+     tokens at a rate of `rate_hz`. 
37+ 
38+     Upon an incoming request, we must determine: 
39+     - the key that this request falls under (which bucket to inspect), and 
40+     - the cost C of this request in tokens. 
41+     Then, if there is room in the bucket for C tokens (T + C <= `burst_count`), 
42+     the request is permitted and `cost` tokens are added to the bucket. 
43+     Otherwise the request is denied, and the bucket continues to hold T tokens. 
44+ 
45+     This means that the limiter enforces an average request frequency of `rate_hz`, 
46+     while accumulating a buffer of up to `burst_count` requests which can be consumed 
47+     instantaneously. 
48+ 
49+     The tricky bit is the leaking. We do not want to have a periodic process which 
50+     leaks every bucket! Instead, we track 
51+     - the time point when the bucket was last completely empty, and 
52+     - how many tokens have added to the bucket permitted since then. 
53+     Then for each incoming request, we can calculate how many tokens have leaked 
54+     since this time point, and use that to decide if we should accept or reject the 
55+     request. 
56+ 
3057    Args: 
3158        clock: A homeserver clock, for retrieving the current time 
3259        rate_hz: The long term number of actions that can be performed in a second. 
@@ -41,14 +68,30 @@ def __init__(
4168        self .burst_count  =  burst_count 
4269        self .store  =  store 
4370
44-         # A ordered dictionary keeping track of actions, when they were last 
45-         # performed and how often. Each entry is a mapping from a key of arbitrary type 
46-         # to a tuple representing: 
47-         #   * How many times an action has occurred since a point in time 
48-         #   * The point in time 
49-         #   * The rate_hz of this particular entry. This can vary per request 
71+         # An ordered dictionary representing the token buckets tracked by this rate 
72+         # limiter. Each entry maps a key of arbitrary type to a tuple representing: 
73+         #   * The number of tokens currently in the bucket, 
74+         #   * The time point when the bucket was last completely empty, and 
75+         #   * The rate_hz (leak rate) of this particular bucket. 
5076        self .actions : OrderedDict [Hashable , Tuple [float , float , float ]] =  OrderedDict ()
5177
78+     def  _get_key (
79+         self , requester : Optional [Requester ], key : Optional [Hashable ]
80+     ) ->  Hashable :
81+         """Use the requester's MXID as a fallback key if no key is provided.""" 
82+         if  key  is  None :
83+             if  not  requester :
84+                 raise  ValueError ("Must supply at least one of `requester` or `key`" )
85+ 
86+             key  =  requester .user .to_string ()
87+         return  key 
88+ 
89+     def  _get_action_counts (
90+         self , key : Hashable , time_now_s : float 
91+     ) ->  Tuple [float , float , float ]:
92+         """Retrieve the action counts, with a fallback representing an empty bucket.""" 
93+         return  self .actions .get (key , (0.0 , time_now_s , 0.0 ))
94+ 
5295    async  def  can_do_action (
5396        self ,
5497        requester : Optional [Requester ],
@@ -88,11 +131,7 @@ async def can_do_action(
88131                * The reactor timestamp for when the action can be performed next. 
89132                  -1 if rate_hz is less than or equal to zero 
90133        """ 
91-         if  key  is  None :
92-             if  not  requester :
93-                 raise  ValueError ("Must supply at least one of `requester` or `key`" )
94- 
95-             key  =  requester .user .to_string ()
134+         key  =  self ._get_key (requester , key )
96135
97136        if  requester :
98137            # Disable rate limiting of users belonging to any AS that is configured 
@@ -121,7 +160,7 @@ async def can_do_action(
121160        self ._prune_message_counts (time_now_s )
122161
123162        # Check if there is an existing count entry for this key 
124-         action_count , time_start , _  =  self .actions . get (key , ( 0.0 ,  time_now_s ,  0.0 ) )
163+         action_count , time_start , _  =  self ._get_action_counts (key , time_now_s )
125164
126165        # Check whether performing another action is allowed 
127166        time_delta  =  time_now_s  -  time_start 
@@ -164,6 +203,37 @@ async def can_do_action(
164203
165204        return  allowed , time_allowed 
166205
206+     def  record_action (
207+         self ,
208+         requester : Optional [Requester ],
209+         key : Optional [Hashable ] =  None ,
210+         n_actions : int  =  1 ,
211+         _time_now_s : Optional [float ] =  None ,
212+     ) ->  None :
213+         """Record that an action(s) took place, even if they violate the rate limit. 
214+ 
215+         This is useful for tracking the frequency of events that happen across 
216+         federation which we still want to impose local rate limits on. For instance, if 
217+         we are alice.com monitoring a particular room, we cannot prevent bob.com 
218+         from joining users to that room. However, we can track the number of recent 
219+         joins in the room and refuse to serve new joins ourselves if there have been too 
220+         many in the room across both homeservers. 
221+ 
222+         Args: 
223+             requester: The requester that is doing the action, if any. 
224+             key: An arbitrary key used to classify an action. Defaults to the 
225+                 requester's user ID. 
226+             n_actions: The number of times the user wants to do this action. If the user 
227+                 cannot do all of the actions, the user's action count is not incremented 
228+                 at all. 
229+             _time_now_s: The current time. Optional, defaults to the current time according 
230+                 to self.clock. Only used by tests. 
231+         """ 
232+         key  =  self ._get_key (requester , key )
233+         time_now_s  =  _time_now_s  if  _time_now_s  is  not   None  else  self .clock .time ()
234+         action_count , time_start , rate_hz  =  self ._get_action_counts (key , time_now_s )
235+         self .actions [key ] =  (action_count  +  n_actions , time_start , rate_hz )
236+ 
167237    def  _prune_message_counts (self , time_now_s : float ) ->  None :
168238        """Remove message count entries that have not exceeded their defined 
169239        rate_hz limit 
0 commit comments