@@ -191,7 +191,7 @@ async def get_devices_by_auth_provider_session_id(
191191 @trace
192192 async def get_device_updates_by_remote (
193193 self , destination : str , from_stream_id : int , limit : int
194- ) -> Tuple [int , List [Tuple [str , dict ]]]:
194+ ) -> Tuple [int , List [Tuple [str , JsonDict ]]]:
195195 """Get a stream of device updates to send to the given remote server.
196196
197197 Args:
@@ -200,9 +200,10 @@ async def get_device_updates_by_remote(
200200 limit: Maximum number of device updates to return
201201
202202 Returns:
203- A mapping from the current stream id (ie, the stream id of the last
204- update included in the response), and the list of updates, where
205- each update is a pair of EDU type and EDU contents.
203+ - The current stream id (i.e. the stream id of the last update included
204+ in the response); and
205+ - The list of updates, where each update is a pair of EDU type and
206+ EDU contents.
206207 """
207208 now_stream_id = self .get_device_stream_token ()
208209
@@ -221,6 +222,9 @@ async def get_device_updates_by_remote(
221222 limit ,
222223 )
223224
225+ # We need to ensure `updates` doesn't grow too big.
226+ # Currently: `len(updates) <= limit`.
227+
224228 # Return an empty list if there are no updates
225229 if not updates :
226230 return now_stream_id , []
@@ -277,40 +281,88 @@ async def get_device_updates_by_remote(
277281 query_map = {}
278282 cross_signing_keys_by_user = {}
279283 for user_id , device_id , update_stream_id , update_context in updates :
280- if (
284+ # Calculate the remaining length budget.
285+ # Note that, for now, each entry in `cross_signing_keys_by_user`
286+ # gives rise to two device updates in the result, so those cost twice
287+ # as much (and are the whole reason we need to separately calculate
288+ # the budget; we know len(updates) <= limit otherwise!)
289+ # N.B. len() on dicts is cheap since they store their size.
290+ remaining_length_budget = limit - (
291+ len (query_map ) + 2 * len (cross_signing_keys_by_user )
292+ )
293+ assert remaining_length_budget >= 0
294+
295+ is_master_key_update = (
281296 user_id in master_key_by_user
282297 and device_id == master_key_by_user [user_id ]["device_id" ]
283- ):
284- result = cross_signing_keys_by_user .setdefault (user_id , {})
285- result ["master_key" ] = master_key_by_user [user_id ]["key_info" ]
286- elif (
298+ )
299+ is_self_signing_key_update = (
287300 user_id in self_signing_key_by_user
288301 and device_id == self_signing_key_by_user [user_id ]["device_id" ]
302+ )
303+
304+ is_cross_signing_key_update = (
305+ is_master_key_update or is_self_signing_key_update
306+ )
307+
308+ if (
309+ is_cross_signing_key_update
310+ and user_id not in cross_signing_keys_by_user
289311 ):
312+ # This will give rise to 2 device updates.
313+ # If we don't have the budget, stop here!
314+ if remaining_length_budget < 2 :
315+ break
316+
317+ if is_master_key_update :
318+ result = cross_signing_keys_by_user .setdefault (user_id , {})
319+ result ["master_key" ] = master_key_by_user [user_id ]["key_info" ]
320+ elif is_self_signing_key_update :
290321 result = cross_signing_keys_by_user .setdefault (user_id , {})
291322 result ["self_signing_key" ] = self_signing_key_by_user [user_id ][
292323 "key_info"
293324 ]
294325 else :
295326 key = (user_id , device_id )
296327
328+ if key not in query_map and remaining_length_budget < 1 :
329+ # We don't have space for a new entry
330+ break
331+
297332 previous_update_stream_id , _ = query_map .get (key , (0 , None ))
298333
299334 if update_stream_id > previous_update_stream_id :
335+ # FIXME If this overwrites an older update, this discards the
336+ # previous OpenTracing context.
337+ # It might make it harder to track down issues using OpenTracing.
338+ # If there's a good reason why it doesn't matter, a comment here
339+ # about that would not hurt.
300340 query_map [key ] = (update_stream_id , update_context )
301341
342+ # As this update has been added to the response, advance the stream
343+ # position.
302344 last_processed_stream_id = update_stream_id
303345
346+ # In the worst case scenario, each update is for a distinct user and is
347+ # added either to the query_map or to cross_signing_keys_by_user,
348+ # but not both:
349+ # len(query_map) + len(cross_signing_keys_by_user) <= len(updates) here,
350+ # so len(query_map) + len(cross_signing_keys_by_user) <= limit.
351+
304352 results = await self ._get_device_update_edus_by_remote (
305353 destination , from_stream_id , query_map
306354 )
307355
308- # add the updated cross-signing keys to the results list
356+ # len(results) <= len(query_map) here,
357+ # so len(results) + len(cross_signing_keys_by_user) <= limit.
358+
359+ # Add the updated cross-signing keys to the results list
309360 for user_id , result in cross_signing_keys_by_user .items ():
310361 result ["user_id" ] = user_id
311362 results .append (("m.signing_key_update" , result ))
312363 # also send the unstable version
313364 # FIXME: remove this when enough servers have upgraded
365+ # and remove the length budgeting above.
314366 results .append (("org.matrix.signing_key_update" , result ))
315367
316368 return last_processed_stream_id , results
@@ -322,7 +374,7 @@ def _get_device_updates_by_remote_txn(
322374 from_stream_id : int ,
323375 now_stream_id : int ,
324376 limit : int ,
325- ):
377+ ) -> List [ Tuple [ str , str , int , Optional [ str ]]] :
326378 """Return device update information for a given remote destination
327379
328380 Args:
@@ -333,7 +385,11 @@ def _get_device_updates_by_remote_txn(
333385 limit: Maximum number of device updates to return
334386
335387 Returns:
336- List: List of device updates
388+ List: List of device update tuples:
389+ - user_id
390+ - device_id
391+ - stream_id
392+ - opentracing_context
337393 """
338394 # get the list of device updates that need to be sent
339395 sql = """
@@ -357,15 +413,21 @@ async def _get_device_update_edus_by_remote(
357413 Args:
358414 destination: The host the device updates are intended for
359415 from_stream_id: The minimum stream_id to filter updates by, exclusive
360- query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
361- user_id/device_id to update stream_id and the relevant json-encoded
362- opentracing context
416+ query_map: Dictionary mapping (user_id, device_id) to
417+ (update stream_id, the relevant json-encoded opentracing context)
363418
364419 Returns:
365- List of objects representing an device update EDU
420+ List of objects representing a device update EDU.
421+
422+ Postconditions:
423+ The returned list has a length not exceeding that of the query_map:
424+ len(result) <= len(query_map)
366425 """
367426 devices = (
368427 await self .get_e2e_device_keys_and_signatures (
428+ # Because these are (user_id, device_id) tuples with all
429+ # device_ids not being None, the returned list's length will not
430+ # exceed that of query_map.
369431 query_map .keys (),
370432 include_all_devices = True ,
371433 include_deleted_devices = True ,
0 commit comments