88from typing import Optional
99import asyncio
1010import logging
11+ import time
1112
1213from mautrix import client as cli
1314from mautrix .errors import GroupSessionWithheldError
1819 DeviceLists ,
1920 DeviceOTKCount ,
2021 EncryptionAlgorithm ,
22+ EncryptionKeyAlgorithm ,
2123 EventType ,
2224 Member ,
2325 Membership ,
@@ -87,6 +89,8 @@ def __init__(
8789
8890 self ._fetch_keys_lock = asyncio .Lock ()
8991 self ._megolm_decrypt_lock = asyncio .Lock ()
92+ self ._share_keys_lock = asyncio .Lock ()
93+ self ._last_key_share = time .monotonic () - 60
9094 self ._key_request_waiters = {}
9195 self ._inbound_session_waiters = {}
9296 self ._prev_unwedge = {}
@@ -267,14 +271,29 @@ async def handle_beep_room_key_ack(self, evt: ToDeviceEvent) -> None:
267271 else :
268272 self .log .debug (f"Received room key ack for { sess .id } " )
269273
270- async def share_keys (self , current_otk_count : int ) -> None :
274+ async def share_keys (self , current_otk_count : int | None = None ) -> None :
271275 """
272276 Share any keys that need to be shared. This is automatically called from
273277 :meth:`handle_otk_count`, so you should not need to call this yourself.
274278
275279 Args:
276280 current_otk_count: The current number of signed curve25519 keys present on the server.
281+ If omitted, the count will be fetched from the server.
277282 """
283+ async with self ._share_keys_lock :
284+ await self ._share_keys (current_otk_count )
285+
286+ async def _share_keys (self , current_otk_count : int | None ) -> None :
287+ if current_otk_count is None or (
288+ # If the last key share was recent and the new count is very low, re-check the count
289+ # from the server to avoid any race conditions.
290+ self ._last_key_share + 60 > time .monotonic ()
291+ and current_otk_count < 10
292+ ):
293+ self .log .debug ("Checking OTK count on server" )
294+ current_otk_count = (await self .client .upload_keys ()).get (
295+ EncryptionKeyAlgorithm .SIGNED_CURVE25519
296+ )
278297 device_keys = (
279298 self .account .get_device_keys (self .client .mxid , self .client .device_id )
280299 if not self .account .shared
@@ -289,7 +308,8 @@ async def share_keys(self, current_otk_count: int) -> None:
289308 if device_keys :
290309 self .log .debug ("Going to upload initial account keys" )
291310 self .log .debug (f"Uploading { len (one_time_keys )} one-time keys" )
292- await self .client .upload_keys (one_time_keys = one_time_keys , device_keys = device_keys )
311+ resp = await self .client .upload_keys (one_time_keys = one_time_keys , device_keys = device_keys )
293312 self .account .shared = True
313+ self ._last_key_share = time .monotonic ()
294314 await self .crypto_store .put_account (self .account )
295- self .log .debug ("Shared keys and saved account" )
315+ self .log .debug (f "Shared keys and saved account, new keys: { resp } " )
0 commit comments