Skip to content

Commit 513e925

Browse files
committed
Use transaction for entire fetch keys step
1 parent d675cf3 commit 513e925

File tree

3 files changed

+67
-51
lines changed

3 files changed

+67
-51
lines changed

mautrix/crypto/device_lists.py

Lines changed: 56 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -59,61 +59,67 @@ async def _fetch_keys(
5959
data = {}
6060
for user_id, devices in resp.device_keys.items():
6161
missing_users.remove(user_id)
62-
63-
new_devices = {}
64-
existing_devices = (await self.crypto_store.get_devices(user_id)) or {}
65-
66-
self.log.trace(
67-
f"Updating devices for {user_id}, got {len(devices)}, "
68-
f"have {len(existing_devices)} in store"
69-
)
70-
changed = False
71-
ssks = resp.self_signing_keys.get(user_id)
72-
ssk = ssks.first_ed25519_key if ssks else None
73-
for device_id, device_keys in devices.items():
74-
try:
75-
existing = existing_devices[device_id]
76-
except KeyError:
77-
existing = None
78-
changed = True
79-
self.log.trace(f"Validating device {device_keys} of {user_id}")
80-
try:
81-
new_device = await self._validate_device(
82-
user_id, device_id, device_keys, existing
83-
)
84-
except DeviceValidationError as e:
85-
self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}")
86-
else:
87-
if new_device:
88-
new_devices[device_id] = new_device
89-
await self._store_device_self_signatures(device_keys, ssk)
90-
self.log.debug(
91-
f"Storing new device list for {user_id} containing {len(new_devices)} devices"
92-
)
93-
await self.crypto_store.put_devices(user_id, new_devices)
94-
data[user_id] = new_devices
95-
96-
if changed or len(new_devices) != len(existing_devices):
97-
if self.delete_keys_on_device_delete:
98-
for device_id in existing_devices.keys() - new_devices.keys():
99-
device = existing_devices[device_id]
100-
removed_ids = await self.crypto_store.redact_group_sessions(
101-
room_id=None, sender_key=device.identity_key, reason="device removed"
102-
)
103-
self.log.info(
104-
"Redacted megolm sessions sent by removed device "
105-
f"{device.user_id}/{device.device_id}: {removed_ids}"
106-
)
107-
await self.on_devices_changed(user_id)
62+
async with self.crypto_store.transaction():
63+
data[user_id] = await self._process_fetched_keys(user_id, devices, resp)
10864

10965
for user_id in missing_users:
11066
self.log.warning(f"Didn't get any devices for user {user_id}")
11167

112-
for user_id in users:
113-
await self._store_cross_signing_keys(resp, user_id)
114-
11568
return data
11669

70+
async def _process_fetched_keys(
71+
self,
72+
user_id: UserID,
73+
devices: dict[DeviceID, DeviceKeys],
74+
resp: QueryKeysResponse,
75+
) -> dict[DeviceID, DeviceIdentity]:
76+
new_devices = {}
77+
existing_devices = (await self.crypto_store.get_devices(user_id)) or {}
78+
79+
self.log.trace(
80+
f"Updating devices for {user_id}, got {len(devices)}, "
81+
f"have {len(existing_devices)} in store"
82+
)
83+
changed = False
84+
ssks = resp.self_signing_keys.get(user_id)
85+
ssk = ssks.first_ed25519_key if ssks else None
86+
for device_id, device_keys in devices.items():
87+
try:
88+
existing = existing_devices[device_id]
89+
except KeyError:
90+
existing = None
91+
changed = True
92+
self.log.trace(f"Validating device {device_keys} of {user_id}")
93+
try:
94+
new_device = await self._validate_device(user_id, device_id, device_keys, existing)
95+
except DeviceValidationError as e:
96+
self.log.warning(f"Failed to validate device {device_id} of {user_id}: {e}")
97+
else:
98+
if new_device:
99+
new_devices[device_id] = new_device
100+
await self._store_device_self_signatures(device_keys, ssk)
101+
self.log.debug(
102+
f"Storing new device list for {user_id} containing {len(new_devices)} devices"
103+
)
104+
await self.crypto_store.put_devices(user_id, new_devices)
105+
106+
if changed or len(new_devices) != len(existing_devices):
107+
if self.delete_keys_on_device_delete:
108+
for device_id in existing_devices.keys() - new_devices.keys():
109+
device = existing_devices[device_id]
110+
removed_ids = await self.crypto_store.redact_group_sessions(
111+
room_id=None, sender_key=device.identity_key, reason="device removed"
112+
)
113+
self.log.info(
114+
"Redacted megolm sessions sent by removed device "
115+
f"{device.user_id}/{device.device_id}: {removed_ids}"
116+
)
117+
await self.on_devices_changed(user_id)
118+
119+
await self._store_cross_signing_keys(resp, user_id)
120+
121+
return new_devices
122+
117123
async def _store_device_self_signatures(
118124
self, device_keys: DeviceKeys, self_signing_key: SigningKey | None
119125
) -> None:
@@ -343,7 +349,7 @@ async def _try_resolve_trust(
343349
ssk = their_keys[CrossSigningUsage.SELF]
344350
except KeyError as e:
345351
if allow_fetch:
346-
self.log.error(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}")
352+
self.log.warning(f"Didn't find cross-signing key {e.args[0]} of {device.user_id}")
347353
return TrustState.UNVERIFIED
348354
ssk_signed = await self.crypto_store.is_key_signed_by(
349355
target=CrossSigner(device.user_id, ssk.key),

mautrix/crypto/store/abstract.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8-
from typing import NamedTuple
8+
from typing import AsyncContextManager, NamedTuple
99
from abc import ABC, abstractmethod
1010

1111
from mautrix.types import (
@@ -87,6 +87,10 @@ async def close(self) -> None:
8787
async def flush(self) -> None:
8888
"""Flush the store. If all the methods persist data immediately, this can be a no-op."""
8989

90+
async def transaction(self) -> AsyncContextManager[None]:
91+
"""Run a database transaction. If the store doesn't support transactions, this can be a no-op."""
92+
pass
93+
9094
@abstractmethod
9195
async def delete(self) -> None:
9296
"""Delete the data in the store."""

mautrix/crypto/store/asyncpg/store.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
from collections import defaultdict
9+
from contextlib import asynccontextmanager
910
from datetime import timedelta
1011

1112
from asyncpg import UniqueViolationError
@@ -79,6 +80,11 @@ def __init__(self, account_id: str, pickle_key: str, db: Database) -> None:
7980
self._account = None
8081
self._olm_cache = defaultdict(lambda: {})
8182

83+
@asynccontextmanager
84+
async def transaction(self) -> None:
85+
async with self.db.acquire() as conn, conn.transaction():
86+
yield
87+
8288
async def delete(self) -> None:
8389
tables = ("crypto_account", "crypto_olm_session", "crypto_megolm_outbound_session")
8490
async with self.db.acquire() as conn, conn.transaction():

0 commit comments

Comments
 (0)