Skip to content

Commit c842c58

Browse files
authored
When joining a remote room, limit the number of events we concurrently check signatures/hashes for (matrix-org#10117)
If we check hundreds of thousands of events at once, the memory overhead can easily reach 500+ MB.
1 parent a0101fc commit c842c58

File tree

5 files changed

+202
-256
lines changed

5 files changed

+202
-256
lines changed

changelog.d/10117.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Significantly reduce memory usage of joining large remote rooms.

synapse/crypto/keyring.py

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -233,41 +233,19 @@ def verify_json_objects_for_server(
233233
for server_name, json_object, validity_time in server_and_json
234234
]
235235

236-
def verify_events_for_server(
237-
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
238-
) -> List[defer.Deferred]:
239-
"""Bulk verification of signatures on events.
240-
241-
Args:
242-
server_and_events:
243-
Iterable of `(server_name, event, validity_time)` tuples.
244-
245-
`server_name` is which server we are verifying the signature for
246-
on the event.
247-
248-
`event` is the event that we'll verify the signatures of for
249-
the given `server_name`.
250-
251-
`validity_time` is a timestamp at which the signing key must be
252-
valid.
253-
254-
Returns:
255-
List<Deferred[None]>: for each input triplet, a deferred indicating success
256-
or failure to verify each event's signature for the given
257-
server_name. The deferreds run their callbacks in the sentinel
258-
logcontext.
259-
"""
260-
return [
261-
run_in_background(
262-
self.process_request,
263-
VerifyJsonRequest.from_event(
264-
server_name,
265-
event,
266-
validity_time,
267-
),
236+
async def verify_event_for_server(
237+
self,
238+
server_name: str,
239+
event: EventBase,
240+
validity_time: int,
241+
) -> None:
242+
await self.process_request(
243+
VerifyJsonRequest.from_event(
244+
server_name,
245+
event,
246+
validity_time,
268247
)
269-
for server_name, event, validity_time in server_and_events
270-
]
248+
)
271249

272250
async def process_request(self, verify_request: VerifyJsonRequest) -> None:
273251
"""Processes the `VerifyJsonRequest`. Raises if the object is not signed

synapse/federation/federation_base.py

Lines changed: 80 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@
1414
# limitations under the License.
1515
import logging
1616
from collections import namedtuple
17-
from typing import Iterable, List
18-
19-
from twisted.internet import defer
20-
from twisted.internet.defer import Deferred, DeferredList
21-
from twisted.python.failure import Failure
2217

2318
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
2419
from synapse.api.errors import Codes, SynapseError
@@ -28,11 +23,6 @@
2823
from synapse.events import EventBase, make_event_from_dict
2924
from synapse.events.utils import prune_event, validate_canonicaljson
3025
from synapse.http.servlet import assert_params_in_dict
31-
from synapse.logging.context import (
32-
PreserveLoggingContext,
33-
current_context,
34-
make_deferred_yieldable,
35-
)
3626
from synapse.types import JsonDict, get_domain_from_id
3727

3828
logger = logging.getLogger(__name__)
@@ -48,112 +38,82 @@ def __init__(self, hs):
4838
self.store = hs.get_datastore()
4939
self._clock = hs.get_clock()
5040

51-
def _check_sigs_and_hash(
41+
async def _check_sigs_and_hash(
5242
self, room_version: RoomVersion, pdu: EventBase
53-
) -> Deferred:
54-
return make_deferred_yieldable(
55-
self._check_sigs_and_hashes(room_version, [pdu])[0]
56-
)
57-
58-
def _check_sigs_and_hashes(
59-
self, room_version: RoomVersion, pdus: List[EventBase]
60-
) -> List[Deferred]:
61-
"""Checks that each of the received events is correctly signed by the
62-
sending server.
43+
) -> EventBase:
44+
"""Checks that event is correctly signed by the sending server.
6345
6446
Args:
65-
room_version: The room version of the PDUs
66-
pdus: the events to be checked
47+
room_version: The room version of the PDU
48+
pdu: the event to be checked
6749
6850
Returns:
69-
For each input event, a deferred which:
70-
* returns the original event if the checks pass
71-
* returns a redacted version of the event (if the signature
51+
* the original event if the checks pass
52+
* a redacted version of the event (if the signature
7253
matched but the hash did not)
73-
* throws a SynapseError if the signature check failed.
74-
The deferreds run their callbacks in the sentinel
75-
"""
76-
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
77-
78-
ctx = current_context()
79-
80-
@defer.inlineCallbacks
81-
def callback(_, pdu: EventBase):
82-
with PreserveLoggingContext(ctx):
83-
if not check_event_content_hash(pdu):
84-
# let's try to distinguish between failures because the event was
85-
# redacted (which are somewhat expected) vs actual ball-tampering
86-
# incidents.
87-
#
88-
# This is just a heuristic, so we just assume that if the keys are
89-
# about the same between the redacted and received events, then the
90-
# received event was probably a redacted copy (but we then use our
91-
# *actual* redacted copy to be on the safe side.)
92-
redacted_event = prune_event(pdu)
93-
if set(redacted_event.keys()) == set(pdu.keys()) and set(
94-
redacted_event.content.keys()
95-
) == set(pdu.content.keys()):
96-
logger.info(
97-
"Event %s seems to have been redacted; using our redacted "
98-
"copy",
99-
pdu.event_id,
100-
)
101-
else:
102-
logger.warning(
103-
"Event %s content has been tampered, redacting",
104-
pdu.event_id,
105-
)
106-
return redacted_event
107-
108-
result = yield defer.ensureDeferred(
109-
self.spam_checker.check_event_for_spam(pdu)
54+
* throws a SynapseError if the signature check failed."""
55+
try:
56+
await _check_sigs_on_pdu(self.keyring, room_version, pdu)
57+
except SynapseError as e:
58+
logger.warning(
59+
"Signature check failed for %s: %s",
60+
pdu.event_id,
61+
e,
62+
)
63+
raise
64+
65+
if not check_event_content_hash(pdu):
66+
# let's try to distinguish between failures because the event was
67+
# redacted (which are somewhat expected) vs actual ball-tampering
68+
# incidents.
69+
#
70+
# This is just a heuristic, so we just assume that if the keys are
71+
# about the same between the redacted and received events, then the
72+
# received event was probably a redacted copy (but we then use our
73+
# *actual* redacted copy to be on the safe side.)
74+
redacted_event = prune_event(pdu)
75+
if set(redacted_event.keys()) == set(pdu.keys()) and set(
76+
redacted_event.content.keys()
77+
) == set(pdu.content.keys()):
78+
logger.info(
79+
"Event %s seems to have been redacted; using our redacted copy",
80+
pdu.event_id,
11081
)
111-
112-
if result:
113-
logger.warning(
114-
"Event contains spam, redacting %s: %s",
115-
pdu.event_id,
116-
pdu.get_pdu_json(),
117-
)
118-
return prune_event(pdu)
119-
120-
return pdu
121-
122-
def errback(failure: Failure, pdu: EventBase):
123-
failure.trap(SynapseError)
124-
with PreserveLoggingContext(ctx):
82+
else:
12583
logger.warning(
126-
"Signature check failed for %s: %s",
84+
"Event %s content has been tampered, redacting",
12785
pdu.event_id,
128-
failure.getErrorMessage(),
12986
)
130-
return failure
87+
return redacted_event
13188

132-
for deferred, pdu in zip(deferreds, pdus):
133-
deferred.addCallbacks(
134-
callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
89+
result = await self.spam_checker.check_event_for_spam(pdu)
90+
91+
if result:
92+
logger.warning(
93+
"Event contains spam, redacting %s: %s",
94+
pdu.event_id,
95+
pdu.get_pdu_json(),
13596
)
97+
return prune_event(pdu)
13698

137-
return deferreds
99+
return pdu
138100

139101

140102
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
141103
pass
142104

143105

144-
def _check_sigs_on_pdus(
145-
keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
146-
) -> List[Deferred]:
106+
async def _check_sigs_on_pdu(
107+
keyring: Keyring, room_version: RoomVersion, pdu: EventBase
108+
) -> None:
147109
"""Check that the given events are correctly signed
148110
111+
Raise a SynapseError if the event wasn't correctly signed.
112+
149113
Args:
150114
keyring: keyring object to do the checks
151115
room_version: the room version of the PDUs
152116
pdus: the events to be checked
153-
154-
Returns:
155-
A Deferred for each event in pdus, which will either succeed if
156-
the signatures are valid, or fail (with a SynapseError) if not.
157117
"""
158118

159119
# we want to check that the event is signed by:
@@ -177,90 +137,47 @@ def _check_sigs_on_pdus(
177137
# let's start by getting the domain for each pdu, and flattening the event back
178138
# to JSON.
179139

180-
pdus_to_check = [
181-
PduToCheckSig(
182-
pdu=p,
183-
sender_domain=get_domain_from_id(p.sender),
184-
deferreds=[],
185-
)
186-
for p in pdus
187-
]
188-
189140
# First we check that the sender event is signed by the sender's domain
190141
# (except if its a 3pid invite, in which case it may be sent by any server)
191-
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
192-
193-
more_deferreds = keyring.verify_events_for_server(
194-
[
195-
(
196-
p.sender_domain,
197-
p.pdu,
198-
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
142+
if not _is_invite_via_3pid(pdu):
143+
try:
144+
await keyring.verify_event_for_server(
145+
get_domain_from_id(pdu.sender),
146+
pdu,
147+
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
199148
)
200-
for p in pdus_to_check_sender
201-
]
202-
)
203-
204-
def sender_err(e, pdu_to_check):
205-
errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
206-
pdu_to_check.pdu.event_id,
207-
pdu_to_check.sender_domain,
208-
e.getErrorMessage(),
209-
)
210-
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
211-
212-
for p, d in zip(pdus_to_check_sender, more_deferreds):
213-
d.addErrback(sender_err, p)
214-
p.deferreds.append(d)
149+
except Exception as e:
150+
errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
151+
pdu.event_id,
152+
get_domain_from_id(pdu.sender),
153+
e,
154+
)
155+
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
215156

216157
# now let's look for events where the sender's domain is different to the
217158
# event id's domain (normally only the case for joins/leaves), and add additional
218159
# checks. Only do this if the room version has a concept of event ID domain
219160
# (ie, the room version uses old-style non-hash event IDs).
220-
if room_version.event_format == EventFormatVersions.V1:
221-
pdus_to_check_event_id = [
222-
p
223-
for p in pdus_to_check
224-
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
225-
]
226-
227-
more_deferreds = keyring.verify_events_for_server(
228-
[
229-
(
230-
get_domain_from_id(p.pdu.event_id),
231-
p.pdu,
232-
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
233-
)
234-
for p in pdus_to_check_event_id
235-
]
236-
)
237-
238-
def event_err(e, pdu_to_check):
161+
if room_version.event_format == EventFormatVersions.V1 and get_domain_from_id(
162+
pdu.event_id
163+
) != get_domain_from_id(pdu.sender):
164+
try:
165+
await keyring.verify_event_for_server(
166+
get_domain_from_id(pdu.event_id),
167+
pdu,
168+
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
169+
)
170+
except Exception as e:
239171
errmsg = (
240-
"event id %s: unable to verify signature for event id domain: %s"
241-
% (pdu_to_check.pdu.event_id, e.getErrorMessage())
172+
"event id %s: unable to verify signature for event id domain %s: %s"
173+
% (
174+
pdu.event_id,
175+
get_domain_from_id(pdu.event_id),
176+
e,
177+
)
242178
)
243179
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
244180

245-
for p, d in zip(pdus_to_check_event_id, more_deferreds):
246-
d.addErrback(event_err, p)
247-
p.deferreds.append(d)
248-
249-
# replace lists of deferreds with single Deferreds
250-
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
251-
252-
253-
def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
254-
"""Given a list of deferreds, either return the single deferred,
255-
combine into a DeferredList, or return an already resolved deferred.
256-
"""
257-
if len(deferreds) > 1:
258-
return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
259-
elif len(deferreds) == 1:
260-
return deferreds[0]
261-
else:
262-
return defer.succeed(None)
263-
264181

265182
def _is_invite_via_3pid(event: EventBase) -> bool:
266183
return (

0 commit comments

Comments
 (0)