|
16 | 16 | import logging |
17 | 17 | from typing import TYPE_CHECKING, Any, Dict |
18 | 18 |
|
| 19 | +from synapse.api.constants import EduTypes |
19 | 20 | from synapse.api.errors import SynapseError |
| 21 | +from synapse.api.ratelimiting import Ratelimiter |
20 | 22 | from synapse.logging.context import run_in_background |
21 | 23 | from synapse.logging.opentracing import ( |
22 | 24 | get_active_span_text_map, |
|
25 | 27 | start_active_span, |
26 | 28 | ) |
27 | 29 | from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet |
28 | | -from synapse.types import JsonDict, UserID, get_domain_from_id |
| 30 | +from synapse.types import JsonDict, Requester, UserID, get_domain_from_id |
29 | 31 | from synapse.util import json_encoder |
30 | 32 | from synapse.util.stringutils import random_string |
31 | 33 |
|
@@ -78,6 +80,12 @@ def __init__(self, hs: "HomeServer"): |
78 | 80 | ReplicationUserDevicesResyncRestServlet.make_client(hs) |
79 | 81 | ) |
80 | 82 |
|
| 83 | + self._ratelimiter = Ratelimiter( |
| 84 | + clock=hs.get_clock(), |
| 85 | + rate_hz=hs.config.rc_key_requests.per_second, |
| 86 | + burst_count=hs.config.rc_key_requests.burst_count, |
| 87 | + ) |
| 88 | + |
81 | 89 | async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: |
82 | 90 | local_messages = {} |
83 | 91 | sender_user_id = content["sender"] |
@@ -168,15 +176,27 @@ async def _check_for_unknown_devices( |
168 | 176 |
|
169 | 177 | async def send_device_message( |
170 | 178 | self, |
171 | | - sender_user_id: str, |
| 179 | + requester: Requester, |
172 | 180 | message_type: str, |
173 | 181 | messages: Dict[str, Dict[str, JsonDict]], |
174 | 182 | ) -> None: |
| 183 | + sender_user_id = requester.user.to_string() |
| 184 | + |
175 | 185 | set_tag("number_of_messages", len(messages)) |
176 | 186 | set_tag("sender", sender_user_id) |
177 | 187 | local_messages = {} |
178 | 188 | remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] |
179 | 189 | for user_id, by_device in messages.items(): |
| 190 | + # Ratelimit local cross-user key requests by the sending device. |
| 191 | + if ( |
| 192 | + message_type == EduTypes.RoomKeyRequest |
| 193 | + and user_id != sender_user_id |
| 194 | + and self._ratelimiter.can_do_action( |
| 195 | + (sender_user_id, requester.device_id) |
| 196 | + ) |
| 197 | + ): |
| 198 | + continue |
| 199 | + |
180 | 200 | # we use UserID.from_string to catch invalid user ids |
181 | 201 | if self.is_mine(UserID.from_string(user_id)): |
182 | 202 | messages_by_device = { |
|
0 commit comments