Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit de981ae

Browse files
David Robertsonclokep
andauthored
Claim local one-time-keys in bulk (#16565)
Co-authored-by: Patrick Cloke <[email protected]>
1 parent 91aa52c commit de981ae

File tree

4 files changed

+308
-114
lines changed

4 files changed

+308
-114
lines changed

changelog.d/16565.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve the performance of claiming encryption keys.

synapse/handlers/e2e_keys.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,16 @@ async def claim_client_keys(destination: str) -> None:
753753
async def upload_keys_for_user(
754754
self, user_id: str, device_id: str, keys: JsonDict
755755
) -> JsonDict:
756+
"""
757+
Args:
758+
user_id: user whose keys are being uploaded.
759+
device_id: device whose keys are being uploaded.
760+
keys: the body of a /keys/upload request.
761+
762+
Returns a dictionary with one field:
763+
"one_time_keys": A mapping from algorithm to number of keys for that
764+
algorithm, including those previously persisted.
765+
"""
756766
# This can only be called from the main process.
757767
assert isinstance(self.device_handler, DeviceHandler)
758768

synapse/storage/databases/main/end_to_end_keys.py

Lines changed: 139 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,7 +1111,7 @@ def get_device_stream_token(self) -> int:
11111111
...
11121112

11131113
async def claim_e2e_one_time_keys(
1114-
self, query_list: Iterable[Tuple[str, str, str, int]]
1114+
self, query_list: Collection[Tuple[str, str, str, int]]
11151115
) -> Tuple[
11161116
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
11171117
]:
@@ -1121,131 +1121,63 @@ async def claim_e2e_one_time_keys(
11211121
query_list: An iterable of tuples of (user ID, device ID, algorithm).
11221122
11231123
Returns:
1124-
A tuple pf:
1124+
A tuple (results, missing) of:
11251125
A map of user ID -> a map device ID -> a map of key ID -> JSON.
11261126
1127-
A copy of the input which has not been fulfilled.
1127+
A copy of the input which has not been fulfilled. The returned counts
1128+
may be less than the input counts. In this case, the returned counts
1129+
are the number of claims that were not fulfilled.
11281130
"""
1129-
1130-
@trace
1131-
def _claim_e2e_one_time_key_simple(
1132-
txn: LoggingTransaction,
1133-
user_id: str,
1134-
device_id: str,
1135-
algorithm: str,
1136-
count: int,
1137-
) -> List[Tuple[str, str]]:
1138-
"""Claim OTK for device for DBs that don't support RETURNING.
1139-
1140-
Returns:
1141-
A tuple of key name (algorithm + key ID) and key JSON, if an
1142-
OTK was found.
1143-
"""
1144-
1145-
sql = """
1146-
SELECT key_id, key_json FROM e2e_one_time_keys_json
1147-
WHERE user_id = ? AND device_id = ? AND algorithm = ?
1148-
LIMIT ?
1149-
"""
1150-
1151-
txn.execute(sql, (user_id, device_id, algorithm, count))
1152-
otk_rows = list(txn)
1153-
if not otk_rows:
1154-
return []
1155-
1156-
self.db_pool.simple_delete_many_txn(
1157-
txn,
1158-
table="e2e_one_time_keys_json",
1159-
column="key_id",
1160-
values=[otk_row[0] for otk_row in otk_rows],
1161-
keyvalues={
1162-
"user_id": user_id,
1163-
"device_id": device_id,
1164-
"algorithm": algorithm,
1165-
},
1166-
)
1167-
self._invalidate_cache_and_stream(
1168-
txn, self.count_e2e_one_time_keys, (user_id, device_id)
1169-
)
1170-
1171-
return [
1172-
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
1173-
]
1174-
1175-
@trace
1176-
def _claim_e2e_one_time_key_returning(
1177-
txn: LoggingTransaction,
1178-
user_id: str,
1179-
device_id: str,
1180-
algorithm: str,
1181-
count: int,
1182-
) -> List[Tuple[str, str]]:
1183-
"""Claim OTK for device for DBs that support RETURNING.
1184-
1185-
Returns:
1186-
A tuple of key name (algorithm + key ID) and key JSON, if an
1187-
OTK was found.
1188-
"""
1189-
1190-
# We can use RETURNING to do the fetch and DELETE in once step.
1191-
sql = """
1192-
DELETE FROM e2e_one_time_keys_json
1193-
WHERE user_id = ? AND device_id = ? AND algorithm = ?
1194-
AND key_id IN (
1195-
SELECT key_id FROM e2e_one_time_keys_json
1196-
WHERE user_id = ? AND device_id = ? AND algorithm = ?
1197-
LIMIT ?
1198-
)
1199-
RETURNING key_id, key_json
1200-
"""
1201-
1202-
txn.execute(
1203-
sql,
1204-
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
1205-
)
1206-
otk_rows = list(txn)
1207-
if not otk_rows:
1208-
return []
1209-
1210-
self._invalidate_cache_and_stream(
1211-
txn, self.count_e2e_one_time_keys, (user_id, device_id)
1212-
)
1213-
1214-
return [
1215-
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
1216-
]
1217-
12181131
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
12191132
missing: List[Tuple[str, str, str, int]] = []
1220-
for user_id, device_id, algorithm, count in query_list:
1221-
if self.database_engine.supports_returning:
1222-
# If we support RETURNING clause we can use a single query that
1223-
# allows us to use autocommit mode.
1224-
_claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
1225-
db_autocommit = True
1226-
else:
1227-
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
1228-
db_autocommit = False
1133+
if isinstance(self.database_engine, PostgresEngine):
1134+
# If we can use execute_values we can use a single batch query
1135+
# in autocommit mode.
1136+
unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
1137+
for user_id, device_id, algorithm, count in query_list:
1138+
unfulfilled_claim_counts[user_id, device_id, algorithm] = count
12291139

1230-
claim_rows = await self.db_pool.runInteraction(
1140+
bulk_claims = await self.db_pool.runInteraction(
12311141
"claim_e2e_one_time_keys",
1232-
_claim_e2e_one_time_key,
1233-
user_id,
1234-
device_id,
1235-
algorithm,
1236-
count,
1237-
db_autocommit=db_autocommit,
1142+
self._claim_e2e_one_time_keys_bulk,
1143+
query_list,
1144+
db_autocommit=True,
12381145
)
1239-
if claim_rows:
1146+
1147+
for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
12401148
device_results = results.setdefault(user_id, {}).setdefault(
12411149
device_id, {}
12421150
)
1243-
for claim_row in claim_rows:
1244-
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
1151+
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
1152+
unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
1153+
12451154
# Did we get enough OTKs?
1246-
count -= len(claim_rows)
1247-
if count:
1248-
missing.append((user_id, device_id, algorithm, count))
1155+
missing = [
1156+
(user, device, alg, count)
1157+
for (user, device, alg), count in unfulfilled_claim_counts.items()
1158+
if count > 0
1159+
]
1160+
else:
1161+
for user_id, device_id, algorithm, count in query_list:
1162+
claim_rows = await self.db_pool.runInteraction(
1163+
"claim_e2e_one_time_keys",
1164+
self._claim_e2e_one_time_key_simple,
1165+
user_id,
1166+
device_id,
1167+
algorithm,
1168+
count,
1169+
db_autocommit=False,
1170+
)
1171+
if claim_rows:
1172+
device_results = results.setdefault(user_id, {}).setdefault(
1173+
device_id, {}
1174+
)
1175+
for claim_row in claim_rows:
1176+
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
1177+
# Did we get enough OTKs?
1178+
count -= len(claim_rows)
1179+
if count:
1180+
missing.append((user_id, device_id, algorithm, count))
12491181

12501182
return results, missing
12511183

@@ -1362,6 +1294,99 @@ async def _claim_e2e_fallback_keys_simple(
13621294

13631295
return results
13641296

1297+
@trace
1298+
def _claim_e2e_one_time_key_simple(
1299+
self,
1300+
txn: LoggingTransaction,
1301+
user_id: str,
1302+
device_id: str,
1303+
algorithm: str,
1304+
count: int,
1305+
) -> List[Tuple[str, str]]:
1306+
"""Claim OTK for device for DBs that don't support RETURNING.
1307+
1308+
Returns:
1309+
A tuple of key name (algorithm + key ID) and key JSON, if an
1310+
OTK was found.
1311+
"""
1312+
1313+
sql = """
1314+
SELECT key_id, key_json FROM e2e_one_time_keys_json
1315+
WHERE user_id = ? AND device_id = ? AND algorithm = ?
1316+
LIMIT ?
1317+
"""
1318+
1319+
txn.execute(sql, (user_id, device_id, algorithm, count))
1320+
otk_rows = list(txn)
1321+
if not otk_rows:
1322+
return []
1323+
1324+
self.db_pool.simple_delete_many_txn(
1325+
txn,
1326+
table="e2e_one_time_keys_json",
1327+
column="key_id",
1328+
values=[otk_row[0] for otk_row in otk_rows],
1329+
keyvalues={
1330+
"user_id": user_id,
1331+
"device_id": device_id,
1332+
"algorithm": algorithm,
1333+
},
1334+
)
1335+
self._invalidate_cache_and_stream(
1336+
txn, self.count_e2e_one_time_keys, (user_id, device_id)
1337+
)
1338+
1339+
return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
1340+
1341+
@trace
1342+
def _claim_e2e_one_time_keys_bulk(
1343+
self,
1344+
txn: LoggingTransaction,
1345+
query_list: Iterable[Tuple[str, str, str, int]],
1346+
) -> List[Tuple[str, str, str, str, str]]:
1347+
"""Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
1348+
1349+
Args:
1350+
query_list: Collection of tuples (user_id, device_id, algorithm, count)
1351+
as passed to claim_e2e_one_time_keys.
1352+
1353+
Returns:
1354+
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
1355+
for each OTK claimed.
1356+
"""
1357+
sql = """
1358+
WITH claims(user_id, device_id, algorithm, claim_count) AS (
1359+
VALUES ?
1360+
), ranked_keys AS (
1361+
SELECT
1362+
user_id, device_id, algorithm, key_id, claim_count,
1363+
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
1364+
FROM e2e_one_time_keys_json
1365+
JOIN claims USING (user_id, device_id, algorithm)
1366+
)
1367+
DELETE FROM e2e_one_time_keys_json k
1368+
WHERE (user_id, device_id, algorithm, key_id) IN (
1369+
SELECT user_id, device_id, algorithm, key_id
1370+
FROM ranked_keys
1371+
WHERE r <= claim_count
1372+
)
1373+
RETURNING user_id, device_id, algorithm, key_id, key_json;
1374+
"""
1375+
otk_rows = cast(
1376+
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
1377+
)
1378+
1379+
seen_user_device: Set[Tuple[str, str]] = set()
1380+
for user_id, device_id, _, _, _ in otk_rows:
1381+
if (user_id, device_id) in seen_user_device:
1382+
continue
1383+
seen_user_device.add((user_id, device_id))
1384+
self._invalidate_cache_and_stream(
1385+
txn, self.count_e2e_one_time_keys, (user_id, device_id)
1386+
)
1387+
1388+
return otk_rows
1389+
13651390

13661391
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
13671392
def __init__(

0 commit comments

Comments
 (0)