Skip to content

Commit e1b28d3

Browse files
committed
Allow using OlmMachine.share_keys without OTK count
1 parent c6979e8 commit e1b28d3

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

mautrix/crypto/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class BaseOlmMachine:
6666
_prev_unwedge: dict[IdentityKey, float]
6767
_fetch_keys_lock: asyncio.Lock
6868
_megolm_decrypt_lock: asyncio.Lock
69+
_share_keys_lock: asyncio.Lock
70+
_last_key_share: float
6971
_cs_fetch_attempted: set[UserID]
7072

7173
async def wait_for_session(

mautrix/crypto/machine.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Optional
99
import asyncio
1010
import logging
11+
import time
1112

1213
from mautrix import client as cli
1314
from mautrix.errors import GroupSessionWithheldError
@@ -18,6 +19,7 @@
1819
DeviceLists,
1920
DeviceOTKCount,
2021
EncryptionAlgorithm,
22+
EncryptionKeyAlgorithm,
2123
EventType,
2224
Member,
2325
Membership,
@@ -87,6 +89,8 @@ def __init__(
8789

8890
self._fetch_keys_lock = asyncio.Lock()
8991
self._megolm_decrypt_lock = asyncio.Lock()
92+
self._share_keys_lock = asyncio.Lock()
93+
self._last_key_share = time.monotonic() - 60
9094
self._key_request_waiters = {}
9195
self._inbound_session_waiters = {}
9296
self._prev_unwedge = {}
@@ -267,14 +271,29 @@ async def handle_beep_room_key_ack(self, evt: ToDeviceEvent) -> None:
267271
else:
268272
self.log.debug(f"Received room key ack for {sess.id}")
269273

270-
async def share_keys(self, current_otk_count: int) -> None:
274+
async def share_keys(self, current_otk_count: int | None = None) -> None:
271275
"""
272276
Share any keys that need to be shared. This is automatically called from
273277
:meth:`handle_otk_count`, so you should not need to call this yourself.
274278
275279
Args:
276280
current_otk_count: The current number of signed curve25519 keys present on the server.
281+
If omitted, the count will be fetched from the server.
277282
"""
283+
async with self._share_keys_lock:
284+
await self._share_keys(current_otk_count)
285+
286+
async def _share_keys(self, current_otk_count: int | None) -> None:
287+
if current_otk_count is None or (
288+
# If the last key share was recent and the new count is very low, re-check the count
289+
# from the server to avoid any race conditions.
290+
self._last_key_share + 60 > time.monotonic()
291+
and current_otk_count < 10
292+
):
293+
self.log.debug("Checking OTK count on server")
294+
current_otk_count = (await self.client.upload_keys()).get(
295+
EncryptionKeyAlgorithm.SIGNED_CURVE25519
296+
)
278297
device_keys = (
279298
self.account.get_device_keys(self.client.mxid, self.client.device_id)
280299
if not self.account.shared
@@ -289,7 +308,8 @@ async def share_keys(self, current_otk_count: int) -> None:
289308
if device_keys:
290309
self.log.debug("Going to upload initial account keys")
291310
self.log.debug(f"Uploading {len(one_time_keys)} one-time keys")
292-
await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys)
311+
resp = await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys)
293312
self.account.shared = True
313+
self._last_key_share = time.monotonic()
294314
await self.crypto_store.put_account(self.account)
295-
self.log.debug("Shared keys and saved account")
315+
self.log.debug(f"Shared keys and saved account, new keys: {resp}")

0 commit comments

Comments
 (0)