Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit fdce83e

Browse files
author
David Robertson
authored
Claim fallback keys in bulk (#16570)
1 parent a3f6200 commit fdce83e

File tree

5 files changed

+162
-0
lines changed

5 files changed

+162
-0
lines changed

changelog.d/16570.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve the performance of claiming encryption keys.

synapse/handlers/e2e_keys.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,20 @@ async def claim_one_time_keys(
659659
timeout: Optional[int],
660660
always_include_fallback_keys: bool,
661661
) -> JsonDict:
662+
"""
663+
Args:
664+
query: A chain of maps from (user_id, device_id, algorithm) to the requested
665+
number of keys to claim.
666+
user: The user who is claiming these keys.
667+
timeout: How long to wait for any federation key claim requests before
668+
giving up.
669+
always_include_fallback_keys: always include a fallback key for local users'
670+
devices, even if we managed to claim a one-time-key.
671+
672+
Returns: a heterogeneous dict with two keys:
673+
one_time_keys: chain of maps user ID -> device ID -> key ID -> key.
674+
failures: map from remote destination to a JsonDict describing the error.
675+
"""
662676
local_query: List[Tuple[str, str, str, int]] = []
663677
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
664678

synapse/storage/database.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,16 @@ def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
420420
self._do_execute(self.txn.execute, sql, parameters)
421421

422422
def executemany(self, sql: str, *args: Any) -> None:
423+
"""Repeatedly execute the same piece of SQL with different parameters.
424+
425+
See https://peps.python.org/pep-0249/#executemany. Note in particular that
426+
427+
> Use of this method for an operation which produces one or more result sets
428+
> constitutes undefined behavior
429+
430+
so you can't use this for e.g. a SELECT, an UPDATE ... RETURNING, or a
431+
DELETE FROM... RETURNING.
432+
"""
423433
# TODO: we should add a type for *args here. Looking at Cursor.executemany
424434
# and DBAPI2 it ought to be Sequence[_Parameter], but we pass in
425435
# Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy

synapse/storage/databases/main/end_to_end_keys.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Mapping,
2525
Optional,
2626
Sequence,
27+
Set,
2728
Tuple,
2829
Union,
2930
cast,
@@ -1260,6 +1261,65 @@ async def claim_e2e_fallback_keys(
12601261
Returns:
12611262
A map of user ID -> a map device ID -> a map of key ID -> JSON.
12621263
"""
1264+
if isinstance(self.database_engine, PostgresEngine):
1265+
return await self.db_pool.runInteraction(
1266+
"_claim_e2e_fallback_keys_bulk",
1267+
self._claim_e2e_fallback_keys_bulk_txn,
1268+
query_list,
1269+
db_autocommit=True,
1270+
)
1271+
# Use an UPDATE FROM... RETURNING combined with a VALUES block to do
1272+
# everything in one query. Note: this is also supported in SQLite 3.33.0,
1273+
# (see https://www.sqlite.org/lang_update.html#update_from), but we do not
1274+
# have an equivalent of psycopg2's execute_values to do this in one query.
1275+
else:
1276+
return await self._claim_e2e_fallback_keys_simple(query_list)
1277+
1278+
def _claim_e2e_fallback_keys_bulk_txn(
1279+
self,
1280+
txn: LoggingTransaction,
1281+
query_list: Iterable[Tuple[str, str, str, bool]],
1282+
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
1283+
"""Efficient implementation of claim_e2e_fallback_keys for Postgres.
1284+
1285+
Safe to autocommit: this is a single query.
1286+
"""
1287+
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
1288+
1289+
sql = """
1290+
WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
1291+
VALUES ?
1292+
)
1293+
UPDATE e2e_fallback_keys_json k
1294+
SET used = used OR mark_as_used
1295+
FROM claims
1296+
WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
1297+
RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
1298+
"""
1299+
claimed_keys = cast(
1300+
List[Tuple[str, str, str, str, str]],
1301+
txn.execute_values(sql, query_list),
1302+
)
1303+
1304+
seen_user_device: Set[Tuple[str, str]] = set()
1305+
for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
1306+
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
1307+
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
1308+
1309+
if (user_id, device_id) in seen_user_device:
1310+
continue
1311+
seen_user_device.add((user_id, device_id))
1312+
self._invalidate_cache_and_stream(
1313+
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
1314+
)
1315+
1316+
return results
1317+
1318+
async def _claim_e2e_fallback_keys_simple(
1319+
self,
1320+
query_list: Iterable[Tuple[str, str, str, bool]],
1321+
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
1322+
"""Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite."""
12631323
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
12641324
for user_id, device_id, algorithm, mark_as_used in query_list:
12651325
row = await self.db_pool.simple_select_one(

tests/handlers/test_e2e_keys.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,83 @@ def test_fallback_key(self) -> None:
322322
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
323323
)
324324

325+
def test_fallback_key_bulk(self) -> None:
326+
"""Like test_fallback_key, but claims multiple keys in one handler call."""
327+
alice = f"@alice:{self.hs.hostname}"
328+
brian = f"@brian:{self.hs.hostname}"
329+
chris = f"@chris:{self.hs.hostname}"
330+
331+
# Have three users upload fallback keys for two devices.
332+
fallback_keys = {
333+
alice: {
334+
"alice_dev_1": {"alg1:k1": "fallback_key1"},
335+
"alice_dev_2": {"alg2:k2": "fallback_key2"},
336+
},
337+
brian: {
338+
"brian_dev_1": {"alg1:k3": "fallback_key3"},
339+
"brian_dev_2": {"alg2:k4": "fallback_key4"},
340+
},
341+
chris: {
342+
"chris_dev_1": {"alg1:k5": "fallback_key5"},
343+
"chris_dev_2": {"alg2:k6": "fallback_key6"},
344+
},
345+
}
346+
347+
for user_id, devices in fallback_keys.items():
348+
for device_id, key_dict in devices.items():
349+
self.get_success(
350+
self.handler.upload_keys_for_user(
351+
user_id,
352+
device_id,
353+
{"fallback_keys": key_dict},
354+
)
355+
)
356+
357+
# Each device should have an unused fallback key.
358+
for user_id, devices in fallback_keys.items():
359+
for device_id in devices:
360+
fallback_res = self.get_success(
361+
self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
362+
)
363+
expected_algorithm_name = f"alg{device_id[-1]}"
364+
self.assertEqual(fallback_res, [expected_algorithm_name])
365+
366+
# Claim the fallback key for one device per user.
367+
claim_res = self.get_success(
368+
self.handler.claim_one_time_keys(
369+
{
370+
alice: {"alice_dev_1": {"alg1": 1}},
371+
brian: {"brian_dev_2": {"alg2": 1}},
372+
chris: {"chris_dev_2": {"alg2": 1}},
373+
},
374+
self.requester,
375+
timeout=None,
376+
always_include_fallback_keys=False,
377+
)
378+
)
379+
expected_claims = {
380+
alice: {"alice_dev_1": {"alg1:k1": "fallback_key1"}},
381+
brian: {"brian_dev_2": {"alg2:k4": "fallback_key4"}},
382+
chris: {"chris_dev_2": {"alg2:k6": "fallback_key6"}},
383+
}
384+
self.assertEqual(
385+
claim_res,
386+
{"failures": {}, "one_time_keys": expected_claims},
387+
)
388+
389+
for user_id, devices in fallback_keys.items():
390+
for device_id in devices:
391+
fallback_res = self.get_success(
392+
self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
393+
)
394+
# Claimed fallback keys should no longer show up as unused.
395+
# Unclaimed fallback keys should still be unused.
396+
if device_id in expected_claims[user_id]:
397+
self.assertEqual(fallback_res, [])
398+
else:
399+
expected_algorithm_name = f"alg{device_id[-1]}"
400+
self.assertEqual(fallback_res, [expected_algorithm_name])
401+
325402
def test_fallback_key_always_returned(self) -> None:
326403
local_user = "@boris:" + self.hs.hostname
327404
device_id = "xyz"

0 commit comments

Comments
 (0)