Skip to content
This repository was archived by the owner on Apr 12, 2024. It is now read-only.

Commit 14b5b48

Browse files
authored
Fix ratelimiting for federation /send requests. (#8342)
c.f. #8295 for rationale
1 parent ad055ea commit 14b5b48

File tree

4 files changed

+54
-17
lines changed

4 files changed

+54
-17
lines changed

changelog.d/8342.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix ratelimitng of federation `/send` requests.

synapse/federation/federation_server.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,16 @@ def __init__(self, hs):
9797
self.state = hs.get_state_handler()
9898

9999
self.device_handler = hs.get_device_handler()
100+
self._federation_ratelimiter = hs.get_federation_ratelimiter()
100101

101102
self._server_linearizer = Linearizer("fed_server")
102103
self._transaction_linearizer = Linearizer("fed_txn_handler")
103104

105+
# We cache results for transaction with the same ID
106+
self._transaction_resp_cache = ResponseCache(
107+
hs, "fed_txn_handler", timeout_ms=30000
108+
)
109+
104110
self.transaction_actions = TransactionActions(self.store)
105111

106112
self.registry = hs.get_federation_registry()
@@ -135,22 +141,44 @@ async def on_incoming_transaction(
135141
request_time = self._clock.time_msec()
136142

137143
transaction = Transaction(**transaction_data)
144+
transaction_id = transaction.transaction_id # type: ignore
138145

139-
if not transaction.transaction_id: # type: ignore
146+
if not transaction_id:
140147
raise Exception("Transaction missing transaction_id")
141148

142-
logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore
149+
logger.debug("[%s] Got transaction", transaction_id)
143150

144-
# use a linearizer to ensure that we don't process the same transaction
145-
# multiple times in parallel.
146-
with (
147-
await self._transaction_linearizer.queue(
148-
(origin, transaction.transaction_id) # type: ignore
149-
)
150-
):
151-
result = await self._handle_incoming_transaction(
152-
origin, transaction, request_time
153-
)
151+
# We wrap in a ResponseCache so that we de-duplicate retried
152+
# transactions.
153+
return await self._transaction_resp_cache.wrap(
154+
(origin, transaction_id),
155+
self._on_incoming_transaction_inner,
156+
origin,
157+
transaction,
158+
request_time,
159+
)
160+
161+
async def _on_incoming_transaction_inner(
162+
self, origin: str, transaction: Transaction, request_time: int
163+
) -> Tuple[int, Dict[str, Any]]:
164+
# Use a linearizer to ensure that transactions from a remote are
165+
# processed in order.
166+
with await self._transaction_linearizer.queue(origin):
167+
# We rate limit here *after* we've queued up the incoming requests,
168+
# so that we don't fill up the ratelimiter with blocked requests.
169+
#
170+
# This is important as the ratelimiter allows N concurrent requests
171+
# at a time, and only starts ratelimiting if there are more requests
172+
# than that being processed at a time. If we queued up requests in
173+
# the linearizer/response cache *after* the ratelimiting then those
174+
# queued up requests would count as part of the allowed limit of N
175+
# concurrent requests.
176+
with self._federation_ratelimiter.ratelimit(origin) as d:
177+
await d
178+
179+
result = await self._handle_incoming_transaction(
180+
origin, transaction, request_time
181+
)
154182

155183
return result
156184

synapse/federation/transport/server.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
)
4646
from synapse.server import HomeServer
4747
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
48-
from synapse.util.ratelimitutils import FederationRateLimiter
4948
from synapse.util.versionstring import get_version_string
5049

5150
logger = logging.getLogger(__name__)
@@ -72,9 +71,7 @@ def __init__(self, hs, servlet_groups=None):
7271
super(TransportLayerServer, self).__init__(hs, canonical_json=False)
7372

7473
self.authenticator = Authenticator(hs)
75-
self.ratelimiter = FederationRateLimiter(
76-
self.clock, config=hs.config.rc_federation
77-
)
74+
self.ratelimiter = hs.get_federation_ratelimiter()
7875

7976
self.register_servlets()
8077

@@ -272,6 +269,8 @@ class BaseFederationServlet:
272269

273270
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
274271

272+
RATELIMIT = True # Whether to rate limit requests or not
273+
275274
def __init__(self, handler, authenticator, ratelimiter, server_name):
276275
self.handler = handler
277276
self.authenticator = authenticator
@@ -335,7 +334,7 @@ async def new_func(request, *args, **kwargs):
335334
)
336335

337336
with scope:
338-
if origin:
337+
if origin and self.RATELIMIT:
339338
with ratelimiter.ratelimit(origin) as d:
340339
await d
341340
if request._disconnected:
@@ -372,6 +371,10 @@ def register(self, server):
372371
class FederationSendServlet(BaseFederationServlet):
373372
PATH = "/send/(?P<transaction_id>[^/]*)/?"
374373

374+
# We ratelimit manually in the handler as we queue up the requests and we
375+
# don't want to fill up the ratelimiter with blocked requests.
376+
RATELIMIT = False
377+
375378
def __init__(self, handler, server_name, **kwargs):
376379
super(FederationSendServlet, self).__init__(
377380
handler, server_name=server_name, **kwargs

synapse/server.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
from synapse.types import DomainSpecificString
115115
from synapse.util import Clock
116116
from synapse.util.distributor import Distributor
117+
from synapse.util.ratelimitutils import FederationRateLimiter
117118
from synapse.util.stringutils import random_string
118119

119120
logger = logging.getLogger(__name__)
@@ -642,6 +643,10 @@ def get_replication_data_handler(self) -> ReplicationDataHandler:
642643
def get_replication_streams(self) -> Dict[str, Stream]:
643644
return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
644645

646+
@cache_in_self
647+
def get_federation_ratelimiter(self) -> FederationRateLimiter:
648+
return FederationRateLimiter(self.clock, config=self.config.rc_federation)
649+
645650
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
646651
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
647652

0 commit comments

Comments
 (0)