Skip to content

Commit b602ba1

Browse files
reivilibreDavid Robertson
andauthored
Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates. (matrix-org#11730)
Co-authored-by: David Robertson <[email protected]>
1 parent 22abfca commit b602ba1

File tree

3 files changed

+190
-17
lines changed

3 files changed

+190
-17
lines changed

changelog.d/11730.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates.

synapse/storage/databases/main/devices.py

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ async def get_devices_by_auth_provider_session_id(
191191
@trace
192192
async def get_device_updates_by_remote(
193193
self, destination: str, from_stream_id: int, limit: int
194-
) -> Tuple[int, List[Tuple[str, dict]]]:
194+
) -> Tuple[int, List[Tuple[str, JsonDict]]]:
195195
"""Get a stream of device updates to send to the given remote server.
196196
197197
Args:
@@ -200,9 +200,10 @@ async def get_device_updates_by_remote(
200200
limit: Maximum number of device updates to return
201201
202202
Returns:
203-
A mapping from the current stream id (ie, the stream id of the last
204-
update included in the response), and the list of updates, where
205-
each update is a pair of EDU type and EDU contents.
203+
- The current stream id (i.e. the stream id of the last update included
204+
in the response); and
205+
- The list of updates, where each update is a pair of EDU type and
206+
EDU contents.
206207
"""
207208
now_stream_id = self.get_device_stream_token()
208209

@@ -221,6 +222,9 @@ async def get_device_updates_by_remote(
221222
limit,
222223
)
223224

225+
# We need to ensure `updates` doesn't grow too big.
226+
# Currently: `len(updates) <= limit`.
227+
224228
# Return an empty list if there are no updates
225229
if not updates:
226230
return now_stream_id, []
@@ -277,40 +281,88 @@ async def get_device_updates_by_remote(
277281
query_map = {}
278282
cross_signing_keys_by_user = {}
279283
for user_id, device_id, update_stream_id, update_context in updates:
280-
if (
284+
# Calculate the remaining length budget.
285+
# Note that, for now, each entry in `cross_signing_keys_by_user`
286+
# gives rise to two device updates in the result, so those cost twice
287+
# as much (and are the whole reason we need to separately calculate
288+
# the budget; we know len(updates) <= limit otherwise!)
289+
# N.B. len() on dicts is cheap since they store their size.
290+
remaining_length_budget = limit - (
291+
len(query_map) + 2 * len(cross_signing_keys_by_user)
292+
)
293+
assert remaining_length_budget >= 0
294+
295+
is_master_key_update = (
281296
user_id in master_key_by_user
282297
and device_id == master_key_by_user[user_id]["device_id"]
283-
):
284-
result = cross_signing_keys_by_user.setdefault(user_id, {})
285-
result["master_key"] = master_key_by_user[user_id]["key_info"]
286-
elif (
298+
)
299+
is_self_signing_key_update = (
287300
user_id in self_signing_key_by_user
288301
and device_id == self_signing_key_by_user[user_id]["device_id"]
302+
)
303+
304+
is_cross_signing_key_update = (
305+
is_master_key_update or is_self_signing_key_update
306+
)
307+
308+
if (
309+
is_cross_signing_key_update
310+
and user_id not in cross_signing_keys_by_user
289311
):
312+
# This will give rise to 2 device updates.
313+
# If we don't have the budget, stop here!
314+
if remaining_length_budget < 2:
315+
break
316+
317+
if is_master_key_update:
318+
result = cross_signing_keys_by_user.setdefault(user_id, {})
319+
result["master_key"] = master_key_by_user[user_id]["key_info"]
320+
elif is_self_signing_key_update:
290321
result = cross_signing_keys_by_user.setdefault(user_id, {})
291322
result["self_signing_key"] = self_signing_key_by_user[user_id][
292323
"key_info"
293324
]
294325
else:
295326
key = (user_id, device_id)
296327

328+
if key not in query_map and remaining_length_budget < 1:
329+
# We don't have space for a new entry
330+
break
331+
297332
previous_update_stream_id, _ = query_map.get(key, (0, None))
298333

299334
if update_stream_id > previous_update_stream_id:
335+
# FIXME If this overwrites an older update, this discards the
336+
# previous OpenTracing context.
337+
# It might make it harder to track down issues using OpenTracing.
338+
# If there's a good reason why it doesn't matter, a comment here
339+
# about that would not hurt.
300340
query_map[key] = (update_stream_id, update_context)
301341

342+
# As this update has been added to the response, advance the stream
343+
# position.
302344
last_processed_stream_id = update_stream_id
303345

346+
# In the worst case scenario, each update is for a distinct user and is
347+
# added either to the query_map or to cross_signing_keys_by_user,
348+
# but not both:
349+
# len(query_map) + len(cross_signing_keys_by_user) <= len(updates) here,
350+
# so len(query_map) + len(cross_signing_keys_by_user) <= limit.
351+
304352
results = await self._get_device_update_edus_by_remote(
305353
destination, from_stream_id, query_map
306354
)
307355

308-
# add the updated cross-signing keys to the results list
356+
# len(results) <= len(query_map) here,
357+
# so len(results) + len(cross_signing_keys_by_user) <= limit.
358+
359+
# Add the updated cross-signing keys to the results list
309360
for user_id, result in cross_signing_keys_by_user.items():
310361
result["user_id"] = user_id
311362
results.append(("m.signing_key_update", result))
312363
# also send the unstable version
313364
# FIXME: remove this when enough servers have upgraded
365+
# and remove the length budgeting above.
314366
results.append(("org.matrix.signing_key_update", result))
315367

316368
return last_processed_stream_id, results
@@ -322,7 +374,7 @@ def _get_device_updates_by_remote_txn(
322374
from_stream_id: int,
323375
now_stream_id: int,
324376
limit: int,
325-
):
377+
) -> List[Tuple[str, str, int, Optional[str]]]:
326378
"""Return device update information for a given remote destination
327379
328380
Args:
@@ -333,7 +385,11 @@ def _get_device_updates_by_remote_txn(
333385
limit: Maximum number of device updates to return
334386
335387
Returns:
336-
List: List of device updates
388+
List: List of device update tuples:
389+
- user_id
390+
- device_id
391+
- stream_id
392+
- opentracing_context
337393
"""
338394
# get the list of device updates that need to be sent
339395
sql = """
@@ -357,15 +413,21 @@ async def _get_device_update_edus_by_remote(
357413
Args:
358414
destination: The host the device updates are intended for
359415
from_stream_id: The minimum stream_id to filter updates by, exclusive
360-
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
361-
user_id/device_id to update stream_id and the relevant json-encoded
362-
opentracing context
416+
query_map: Dictionary mapping (user_id, device_id) to
417+
(update stream_id, the relevant json-encoded opentracing context)
363418
364419
Returns:
365-
List of objects representing an device update EDU
420+
List of objects representing a device update EDU.
421+
422+
Postconditions:
423+
The returned list has a length not exceeding that of the query_map:
424+
len(result) <= len(query_map)
366425
"""
367426
devices = (
368427
await self.get_e2e_device_keys_and_signatures(
428+
# Because these are (user_id, device_id) tuples with all
429+
# device_ids not being None, the returned list's length will not
430+
# exceed that of query_map.
369431
query_map.keys(),
370432
include_all_devices=True,
371433
include_deleted_devices=True,

tests/storage/test_devices.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_get_device_updates_by_remote_can_limit_properly(self):
125125
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
126126
)
127127

128-
# Get all device updates ever meant for this remote
128+
# Get device updates meant for this remote
129129
next_stream_id, device_updates = self.get_success(
130130
self.store.get_device_updates_by_remote("somehost", -1, limit=3)
131131
)
@@ -155,6 +155,116 @@ def test_get_device_updates_by_remote_can_limit_properly(self):
155155
# Check the newly-added device_ids are contained within these updates
156156
self._check_devices_in_updates(device_ids, device_updates)
157157

158+
# Check there are no more device updates left.
159+
_, device_updates = self.get_success(
160+
self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
161+
)
162+
self.assertEqual(device_updates, [])
163+
164+
def test_get_device_updates_by_remote_cross_signing_key_updates(
165+
self,
166+
) -> None:
167+
"""
168+
Tests that `get_device_updates_by_remote` limits the length of the return value
169+
properly when cross-signing key updates are present.
170+
Current behaviour is that the cross-signing key updates will always come in pairs,
171+
even if that means leaving an earlier batch one EDU short of the limit.
172+
"""
173+
174+
assert self.hs.is_mine_id(
175+
"@user_id:test"
176+
), "Test not valid: this MXID should be considered local"
177+
178+
self.get_success(
179+
self.store.set_e2e_cross_signing_key(
180+
"@user_id:test",
181+
"master",
182+
{
183+
"keys": {
184+
"ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA="
185+
},
186+
"signatures": {
187+
"@user_id:test": {
188+
"ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA="
189+
}
190+
},
191+
},
192+
)
193+
)
194+
self.get_success(
195+
self.store.set_e2e_cross_signing_key(
196+
"@user_id:test",
197+
"self_signing",
198+
{
199+
"keys": {
200+
"ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA="
201+
},
202+
"signatures": {
203+
"@user_id:test": {
204+
"ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA="
205+
}
206+
},
207+
},
208+
)
209+
)
210+
211+
# Add some device updates with sequential `stream_id`s
212+
# Note that the public cross-signing keys occupy the same space as device IDs,
213+
# so also notify that those have updated.
214+
device_ids = [
215+
"device_id1",
216+
"device_id2",
217+
"fakeMaster",
218+
"fakeSelfSigning",
219+
]
220+
221+
self.get_success(
222+
self.store.add_device_change_to_streams(
223+
"@user_id:test", device_ids, ["somehost"]
224+
)
225+
)
226+
227+
# Get device updates meant for this remote
228+
next_stream_id, device_updates = self.get_success(
229+
self.store.get_device_updates_by_remote("somehost", -1, limit=3)
230+
)
231+
232+
# Here we expect the device updates for `device_id1` and `device_id2`.
233+
# That means we only receive 2 updates this time around.
234+
# If we had a higher limit, we would expect to see the pair of
235+
# (unstable-prefixed & unprefixed) signing key updates for the device
236+
# represented by `fakeMaster` and `fakeSelfSigning`.
237+
# Our implementation only sends these two variants together, so we get
238+
# a short batch.
239+
self.assertEqual(len(device_updates), 2, device_updates)
240+
241+
# Check the first two devices (device_id1, device_id2) came out.
242+
self._check_devices_in_updates(device_ids[:2], device_updates)
243+
244+
# Get more device updates meant for this remote
245+
next_stream_id, device_updates = self.get_success(
246+
self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
247+
)
248+
249+
# The next 2 updates should be a cross-signing key update
250+
# (the master key update and the self-signing key update are combined into
251+
# one 'signing key update', but the cross-signing key update is emitted
252+
# twice, once with an unprefixed type and once again with an unstable-prefixed type)
253+
# (This is a temporary arrangement for backwards compatibility!)
254+
self.assertEqual(len(device_updates), 2, device_updates)
255+
self.assertEqual(
256+
device_updates[0][0], "m.signing_key_update", device_updates[0]
257+
)
258+
self.assertEqual(
259+
device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
260+
)
261+
262+
# Check there are no more device updates left.
263+
_, device_updates = self.get_success(
264+
self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
265+
)
266+
self.assertEqual(device_updates, [])
267+
158268
def _check_devices_in_updates(self, expected_device_ids, device_updates):
159269
"""Check that an specific device ids exist in a list of device update EDUs"""
160270
self.assertEqual(len(device_updates), len(expected_device_ids))

0 commit comments

Comments
 (0)