Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 37db625

Browse files
authored
Convert additional databases to async/await part 3 (#8201)
1 parent 7d103a5 commit 37db625

File tree

7 files changed

+121
-87
lines changed

7 files changed

+121
-87
lines changed

changelog.d/8201.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Convert various parts of the codebase to async/await.

synapse/storage/background_updates.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,15 +433,15 @@ async def _end_background_update(self, update_name: str) -> None:
433433
"background_updates", keyvalues={"update_name": update_name}
434434
)
435435

436-
def _background_update_progress(self, update_name: str, progress: dict):
436+
async def _background_update_progress(self, update_name: str, progress: dict):
437437
"""Update the progress of a background update
438438
439439
Args:
440440
update_name: The name of the background update task
441441
progress: The progress of the update.
442442
"""
443443

444-
return self.db_pool.runInteraction(
444+
return await self.db_pool.runInteraction(
445445
"background_update_progress",
446446
self._background_update_progress_txn,
447447
update_name,

synapse/storage/databases/main/account_data.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616

1717
import abc
1818
import logging
19-
from typing import List, Optional, Tuple
20-
21-
from twisted.internet import defer
19+
from typing import Dict, List, Optional, Tuple
2220

2321
from synapse.storage._base import SQLBaseStore, db_to_json
2422
from synapse.storage.database import DatabasePool
@@ -58,14 +56,16 @@ def get_max_account_data_stream_id(self):
5856
raise NotImplementedError()
5957

6058
@cached()
61-
def get_account_data_for_user(self, user_id):
59+
async def get_account_data_for_user(
60+
self, user_id: str
61+
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
6262
"""Get all the client account_data for a user.
6363
6464
Args:
65-
user_id(str): The user to get the account_data for.
65+
user_id: The user to get the account_data for.
6666
Returns:
67-
A deferred pair of a dict of global account_data and a dict
68-
mapping from room_id string to per room account_data dicts.
67+
A 2-tuple of a dict of global account_data and a dict mapping from
68+
room_id string to per room account_data dicts.
6969
"""
7070

7171
def get_account_data_for_user_txn(txn):
@@ -94,7 +94,7 @@ def get_account_data_for_user_txn(txn):
9494

9595
return global_account_data, by_room
9696

97-
return self.db_pool.runInteraction(
97+
return await self.db_pool.runInteraction(
9898
"get_account_data_for_user", get_account_data_for_user_txn
9999
)
100100

@@ -120,14 +120,16 @@ async def get_global_account_data_by_type_for_user(
120120
return None
121121

122122
@cached(num_args=2)
123-
def get_account_data_for_room(self, user_id, room_id):
123+
async def get_account_data_for_room(
124+
self, user_id: str, room_id: str
125+
) -> Dict[str, JsonDict]:
124126
"""Get all the client account_data for a user for a room.
125127
126128
Args:
127-
user_id(str): The user to get the account_data for.
128-
room_id(str): The room to get the account_data for.
129+
user_id: The user to get the account_data for.
130+
room_id: The room to get the account_data for.
129131
Returns:
130-
A deferred dict of the room account_data
132+
A dict of the room account_data
131133
"""
132134

133135
def get_account_data_for_room_txn(txn):
@@ -142,21 +144,22 @@ def get_account_data_for_room_txn(txn):
142144
row["account_data_type"]: db_to_json(row["content"]) for row in rows
143145
}
144146

145-
return self.db_pool.runInteraction(
147+
return await self.db_pool.runInteraction(
146148
"get_account_data_for_room", get_account_data_for_room_txn
147149
)
148150

149151
@cached(num_args=3, max_entries=5000)
150-
def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
152+
async def get_account_data_for_room_and_type(
153+
self, user_id: str, room_id: str, account_data_type: str
154+
) -> Optional[JsonDict]:
151155
"""Get the client account_data of given type for a user for a room.
152156
153157
Args:
154-
user_id(str): The user to get the account_data for.
155-
room_id(str): The room to get the account_data for.
156-
account_data_type (str): The account data type to get.
158+
user_id: The user to get the account_data for.
159+
room_id: The room to get the account_data for.
160+
account_data_type: The account data type to get.
157161
Returns:
158-
A deferred of the room account_data for that type, or None if
159-
there isn't any set.
162+
The room account_data for that type, or None if there isn't any set.
160163
"""
161164

162165
def get_account_data_for_room_and_type_txn(txn):
@@ -174,7 +177,7 @@ def get_account_data_for_room_and_type_txn(txn):
174177

175178
return db_to_json(content_json) if content_json else None
176179

177-
return self.db_pool.runInteraction(
180+
return await self.db_pool.runInteraction(
178181
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
179182
)
180183

@@ -238,12 +241,14 @@ def get_updated_room_account_data_txn(txn):
238241
"get_updated_room_account_data", get_updated_room_account_data_txn
239242
)
240243

241-
def get_updated_account_data_for_user(self, user_id, stream_id):
244+
async def get_updated_account_data_for_user(
245+
self, user_id: str, stream_id: int
246+
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
242247
"""Get all the client account_data for a that's changed for a user
243248
244249
Args:
245-
user_id(str): The user to get the account_data for.
246-
stream_id(int): The point in the stream since which to get updates
250+
user_id: The user to get the account_data for.
251+
stream_id: The point in the stream since which to get updates
247252
Returns:
248253
A deferred pair of a dict of global account_data and a dict
249254
mapping from room_id string to per room account_data dicts.
@@ -277,9 +282,9 @@ def get_updated_account_data_for_user_txn(txn):
277282
user_id, int(stream_id)
278283
)
279284
if not changed:
280-
return defer.succeed(({}, {}))
285+
return ({}, {})
281286

282-
return self.db_pool.runInteraction(
287+
return await self.db_pool.runInteraction(
283288
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
284289
)
285290

@@ -416,7 +421,7 @@ async def add_account_data_for_user(
416421

417422
return self._account_data_id_gen.get_current_token()
418423

419-
def _update_max_stream_id(self, next_id: int):
424+
async def _update_max_stream_id(self, next_id: int) -> None:
420425
"""Update the max stream_id
421426
422427
Args:
@@ -435,4 +440,4 @@ def _update(txn):
435440
)
436441
txn.execute(update_max_id_sql, (next_id, next_id))
437442

438-
return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
443+
await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)

synapse/storage/databases/main/end_to_end_keys.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@
3434

3535

3636
class EndToEndKeyWorkerStore(SQLBaseStore):
37-
def get_e2e_device_keys_for_federation_query(self, user_id: str):
37+
async def get_e2e_device_keys_for_federation_query(
38+
self, user_id: str
39+
) -> Tuple[int, List[JsonDict]]:
3840
"""Get all devices (with any device keys) for a user
3941
4042
Returns:
41-
Deferred which resolves to (stream_id, devices)
43+
(stream_id, devices)
4244
"""
43-
return self.db_pool.runInteraction(
45+
return await self.db_pool.runInteraction(
4446
"get_e2e_device_keys_for_federation_query",
4547
self._get_e2e_device_keys_for_federation_query_txn,
4648
user_id,
@@ -292,10 +294,12 @@ def _add_e2e_one_time_keys(txn):
292294
)
293295

294296
@cached(max_entries=10000)
295-
def count_e2e_one_time_keys(self, user_id, device_id):
297+
async def count_e2e_one_time_keys(
298+
self, user_id: str, device_id: str
299+
) -> Dict[str, int]:
296300
""" Count the number of one time keys the server has for a device
297301
Returns:
298-
Dict mapping from algorithm to number of keys for that algorithm.
302+
A mapping from algorithm to number of keys for that algorithm.
299303
"""
300304

301305
def _count_e2e_one_time_keys(txn):
@@ -310,7 +314,7 @@ def _count_e2e_one_time_keys(txn):
310314
result[algorithm] = key_count
311315
return result
312316

313-
return self.db_pool.runInteraction(
317+
return await self.db_pool.runInteraction(
314318
"count_e2e_one_time_keys", _count_e2e_one_time_keys
315319
)
316320

@@ -348,24 +352,23 @@ def _get_bare_e2e_cross_signing_keys(self, user_id):
348352
list_name="user_ids",
349353
num_args=1,
350354
)
351-
def _get_bare_e2e_cross_signing_keys_bulk(
355+
async def _get_bare_e2e_cross_signing_keys_bulk(
352356
self, user_ids: List[str]
353357
) -> Dict[str, Dict[str, dict]]:
354358
"""Returns the cross-signing keys for a set of users. The output of this
355359
function should be passed to _get_e2e_cross_signing_signatures_txn if
356360
the signatures for the calling user need to be fetched.
357361
358362
Args:
359-
user_ids (list[str]): the users whose keys are being requested
363+
user_ids: the users whose keys are being requested
360364
361365
Returns:
362-
dict[str, dict[str, dict]]: mapping from user ID to key type to key
363-
data. If a user's cross-signing keys were not found, either
364-
their user ID will not be in the dict, or their user ID will map
365-
to None.
366+
A mapping from user ID to key type to key data. If a user's cross-signing
367+
keys were not found, either their user ID will not be in the dict, or
368+
their user ID will map to None.
366369
367370
"""
368-
return self.db_pool.runInteraction(
371+
return await self.db_pool.runInteraction(
369372
"get_bare_e2e_cross_signing_keys_bulk",
370373
self._get_bare_e2e_cross_signing_keys_bulk_txn,
371374
user_ids,
@@ -588,7 +591,9 @@ def get_device_stream_token(self) -> int:
588591

589592

590593
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
591-
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
594+
async def set_e2e_device_keys(
595+
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
596+
) -> bool:
592597
"""Stores device keys for a device. Returns whether there was a change
593598
or the keys were already in the database.
594599
"""
@@ -624,12 +629,21 @@ def _set_e2e_device_keys_txn(txn):
624629
log_kv({"message": "Device keys stored."})
625630
return True
626631

627-
return self.db_pool.runInteraction(
632+
return await self.db_pool.runInteraction(
628633
"set_e2e_device_keys", _set_e2e_device_keys_txn
629634
)
630635

631-
def claim_e2e_one_time_keys(self, query_list):
632-
"""Take a list of one time keys out of the database"""
636+
async def claim_e2e_one_time_keys(
637+
self, query_list: Iterable[Tuple[str, str, str]]
638+
) -> Dict[str, Dict[str, Dict[str, bytes]]]:
639+
"""Take a list of one time keys out of the database.
640+
641+
Args:
642+
query_list: An iterable of tuples of (user ID, device ID, algorithm).
643+
644+
Returns:
645+
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
646+
"""
633647

634648
@trace
635649
def _claim_e2e_one_time_keys(txn):
@@ -665,11 +679,11 @@ def _claim_e2e_one_time_keys(txn):
665679
)
666680
return result
667681

668-
return self.db_pool.runInteraction(
682+
return await self.db_pool.runInteraction(
669683
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
670684
)
671685

672-
def delete_e2e_keys_by_device(self, user_id, device_id):
686+
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
673687
def delete_e2e_keys_by_device_txn(txn):
674688
log_kv(
675689
{
@@ -692,7 +706,7 @@ def delete_e2e_keys_by_device_txn(txn):
692706
txn, self.count_e2e_one_time_keys, (user_id, device_id)
693707
)
694708

695-
return self.db_pool.runInteraction(
709+
await self.db_pool.runInteraction(
696710
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
697711
)
698712

0 commit comments

Comments
 (0)