Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 8e97394

Browse files
authored
Add unstable /keys/claim endpoint which always returns fallback keys. (#15462)
It can be useful to always return the fallback key when attempting to claim keys. This adds an unstable endpoint for `/keys/claim` which always returns fallback keys in addition to one-time-keys. The fallback key(s) are not marked as "used" unless there are no corresponding OTKs. This is currently defined in MSC3983 (although likely to be split out to a separate MSC). The endpoint shape may change or be requested differently (i.e. a keyword parameter on the current endpoint), but the core logic should be reasonable.
1 parent b39b02c commit 8e97394

File tree

9 files changed

+371
-29
lines changed

9 files changed

+371
-29
lines changed

changelog.d/15462.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request.

synapse/federation/federation_server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,15 +1005,17 @@ async def on_query_user_devices(
10051005

10061006
@trace
10071007
async def on_claim_client_keys(
1008-
self, origin: str, content: JsonDict
1008+
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
10091009
) -> Dict[str, Any]:
10101010
query = []
10111011
for user_id, device_keys in content.get("one_time_keys", {}).items():
10121012
for device_id, algorithm in device_keys.items():
10131013
query.append((user_id, device_id, algorithm))
10141014

10151015
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
1016-
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
1016+
results = await self._e2e_keys_handler.claim_local_one_time_keys(
1017+
query, always_include_fallback_keys=always_include_fallback_keys
1018+
)
10171019

10181020
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
10191021
for result in results:

synapse/federation/transport/server/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from synapse.federation.transport.server.federation import (
2626
FEDERATION_SERVLET_CLASSES,
2727
FederationAccountStatusServlet,
28+
FederationUnstableClientKeysClaimServlet,
2829
)
2930
from synapse.http.server import HttpServer, JsonResource
3031
from synapse.http.servlet import (
@@ -298,6 +299,11 @@ def register_servlets(
298299
and not hs.config.experimental.msc3720_enabled
299300
):
300301
continue
302+
if (
303+
servletclass == FederationUnstableClientKeysClaimServlet
304+
and not hs.config.experimental.msc3983_appservice_otk_claims
305+
):
306+
continue
301307

302308
servletclass(
303309
hs=hs,

synapse/federation/transport/server/federation.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
577577
async def on_POST(
578578
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
579579
) -> Tuple[int, JsonDict]:
580-
response = await self.handler.on_claim_client_keys(origin, content)
580+
response = await self.handler.on_claim_client_keys(
581+
origin, content, always_include_fallback_keys=False
582+
)
583+
return 200, response
584+
585+
586+
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
587+
"""
588+
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
589+
always includes fallback keys in the response.
590+
"""
591+
592+
PREFIX = FEDERATION_UNSTABLE_PREFIX
593+
PATH = "/user/keys/claim"
594+
CATEGORY = "Federation requests"
595+
596+
async def on_POST(
597+
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
598+
) -> Tuple[int, JsonDict]:
599+
response = await self.handler.on_claim_client_keys(
600+
origin, content, always_include_fallback_keys=True
601+
)
581602
return 200, response
582603

583604

synapse/handlers/appservice.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -842,9 +842,7 @@ async def _check_user_exists(self, user_id: str) -> bool:
842842

843843
async def claim_e2e_one_time_keys(
844844
self, query: Iterable[Tuple[str, str, str]]
845-
) -> Tuple[
846-
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
847-
]:
845+
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
848846
"""Claim one time keys from application services.
849847
850848
Users which are exclusively owned by an application service are sent a
@@ -856,7 +854,7 @@ async def claim_e2e_one_time_keys(
856854
857855
Returns:
858856
A tuple of:
859-
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
857+
A map of user ID -> a map device ID -> a map of key ID -> JSON.
860858
861859
A copy of the input which has not been fulfilled (either because
862860
they are not appservice users or the appservice does not support
@@ -897,12 +895,11 @@ async def claim_e2e_one_time_keys(
897895
)
898896

899897
# Patch together the results -- they are all independent (since they
900-
# require exclusive control over the users). They get returned as a list
901-
# and the caller combines them.
902-
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
898+
# require exclusive control over the users, which is the outermost key).
899+
claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
903900
for success, result in results:
904901
if success:
905-
claimed_keys.append(result[0])
902+
claimed_keys.update(result[0])
906903
missing.extend(result[1])
907904

908905
return claimed_keys, missing

synapse/handlers/e2e_keys.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,9 @@ async def on_federation_query_client_keys(
563563
return ret
564564

565565
async def claim_local_one_time_keys(
566-
self, local_query: List[Tuple[str, str, str]]
566+
self,
567+
local_query: List[Tuple[str, str, str]],
568+
always_include_fallback_keys: bool,
567569
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
568570
"""Claim one time keys for local users.
569571
@@ -573,6 +575,7 @@ async def claim_local_one_time_keys(
573575
574576
Args:
575577
local_query: An iterable of tuples of (user ID, device ID, algorithm).
578+
always_include_fallback_keys: True to always include fallback keys.
576579
577580
Returns:
578581
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
@@ -583,24 +586,73 @@ async def claim_local_one_time_keys(
583586
# If the application services have not provided any keys via the C-S
584587
# API, query it directly for one-time keys.
585588
if self._query_appservices_for_otks:
589+
# TODO Should this query for fallback keys of uploaded OTKs if
590+
# always_include_fallback_keys is True? The MSC is ambiguous.
586591
(
587592
appservice_results,
588593
not_found,
589594
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
590595
else:
591-
appservice_results = []
596+
appservice_results = {}
597+
598+
# Calculate which user ID / device ID / algorithm tuples to get fallback
599+
# keys for. This can be either only missing results *or* all results
600+
# (which don't already have a fallback key).
601+
if always_include_fallback_keys:
602+
# Build the fallback query as any part of the original query where
603+
# the appservice didn't respond with a fallback key.
604+
fallback_query = []
605+
606+
# Iterate each item in the original query and search the results
607+
# from the appservice for that user ID / device ID. If it is found,
608+
# check if any of the keys match the requested algorithm & are a
609+
# fallback key.
610+
for user_id, device_id, algorithm in local_query:
611+
# Check if the appservice responded for this query.
612+
as_result = appservice_results.get(user_id, {}).get(device_id, {})
613+
found_otk = False
614+
for key_id, key_json in as_result.items():
615+
if key_id.startswith(f"{algorithm}:"):
616+
# A OTK or fallback key was found for this query.
617+
found_otk = True
618+
# A fallback key was found for this query, no need to
619+
# query further.
620+
if key_json.get("fallback", False):
621+
break
622+
623+
else:
624+
# No fallback key was found from appservices, query for it.
625+
# Only mark the fallback key as used if no OTK was found
626+
# (from either the database or appservices).
627+
mark_as_used = not found_otk and not any(
628+
key_id.startswith(f"{algorithm}:")
629+
for key_id in otk_results.get(user_id, {})
630+
.get(device_id, {})
631+
.keys()
632+
)
633+
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
634+
635+
else:
636+
# All fallback keys get marked as used.
637+
fallback_query = [
638+
(user_id, device_id, algorithm, True)
639+
for user_id, device_id, algorithm in not_found
640+
]
592641

593642
# For each user that does not have a one-time keys available, see if
594643
# there is a fallback key.
595-
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
644+
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
596645

597646
# Return the results in order, each item from the input query should
598647
# only appear once in the combined list.
599-
return (otk_results, *appservice_results, fallback_results)
648+
return (otk_results, appservice_results, fallback_results)
600649

601650
@trace
602651
async def claim_one_time_keys(
603-
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
652+
self,
653+
query: Dict[str, Dict[str, Dict[str, str]]],
654+
timeout: Optional[int],
655+
always_include_fallback_keys: bool,
604656
) -> JsonDict:
605657
local_query: List[Tuple[str, str, str]] = []
606658
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@@ -617,15 +669,19 @@ async def claim_one_time_keys(
617669
set_tag("local_key_query", str(local_query))
618670
set_tag("remote_key_query", str(remote_queries))
619671

620-
results = await self.claim_local_one_time_keys(local_query)
672+
results = await self.claim_local_one_time_keys(
673+
local_query, always_include_fallback_keys
674+
)
621675

622676
# A map of user ID -> device ID -> key ID -> key.
623677
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
624678
for result in results:
625679
for user_id, device_keys in result.items():
626680
for device_id, keys in device_keys.items():
627681
for key_id, key in keys.items():
628-
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
682+
json_result.setdefault(user_id, {}).setdefault(
683+
device_id, {}
684+
).update({key_id: key})
629685

630686
# Remote failures.
631687
failures: Dict[str, JsonDict] = {}

synapse/rest/client/keys.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import logging
18+
import re
1819
from typing import TYPE_CHECKING, Any, Optional, Tuple
1920

2021
from synapse.api.errors import InvalidAPICallError, SynapseError
@@ -288,7 +289,33 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
288289
await self.auth.get_user_by_req(request, allow_guest=True)
289290
timeout = parse_integer(request, "timeout", 10 * 1000)
290291
body = parse_json_object_from_request(request)
291-
result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
292+
result = await self.e2e_keys_handler.claim_one_time_keys(
293+
body, timeout, always_include_fallback_keys=False
294+
)
295+
return 200, result
296+
297+
298+
class UnstableOneTimeKeyServlet(RestServlet):
299+
"""
300+
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
301+
fallback keys in the response.
302+
"""
303+
304+
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
305+
CATEGORY = "Encryption requests"
306+
307+
def __init__(self, hs: "HomeServer"):
308+
super().__init__()
309+
self.auth = hs.get_auth()
310+
self.e2e_keys_handler = hs.get_e2e_keys_handler()
311+
312+
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
313+
await self.auth.get_user_by_req(request, allow_guest=True)
314+
timeout = parse_integer(request, "timeout", 10 * 1000)
315+
body = parse_json_object_from_request(request)
316+
result = await self.e2e_keys_handler.claim_one_time_keys(
317+
body, timeout, always_include_fallback_keys=True
318+
)
292319
return 200, result
293320

294321

@@ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
394421
KeyQueryServlet(hs).register(http_server)
395422
KeyChangesServlet(hs).register(http_server)
396423
OneTimeKeyServlet(hs).register(http_server)
424+
if hs.config.experimental.msc3983_appservice_otk_claims:
425+
UnstableOneTimeKeyServlet(hs).register(http_server)
397426
if hs.config.worker.worker_app is None:
398427
SigningKeyUploadServlet(hs).register(http_server)
399428
SignaturesUploadServlet(hs).register(http_server)

synapse/storage/databases/main/end_to_end_keys.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,18 +1149,19 @@ def _claim_e2e_one_time_key_returning(
11491149
return results, missing
11501150

11511151
async def claim_e2e_fallback_keys(
1152-
self, query_list: Iterable[Tuple[str, str, str]]
1152+
self, query_list: Iterable[Tuple[str, str, str, bool]]
11531153
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
11541154
"""Take a list of fallback keys out of the database.
11551155
11561156
Args:
1157-
query_list: An iterable of tuples of (user ID, device ID, algorithm).
1157+
query_list: An iterable of tuples of
1158+
(user ID, device ID, algorithm, whether the key should be marked as used).
11581159
11591160
Returns:
11601161
A map of user ID -> a map device ID -> a map of key ID -> JSON.
11611162
"""
11621163
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
1163-
for user_id, device_id, algorithm in query_list:
1164+
for user_id, device_id, algorithm, mark_as_used in query_list:
11641165
row = await self.db_pool.simple_select_one(
11651166
table="e2e_fallback_keys_json",
11661167
keyvalues={
@@ -1180,7 +1181,7 @@ async def claim_e2e_fallback_keys(
11801181
used = row["used"]
11811182

11821183
# Mark fallback key as used if not already.
1183-
if not used:
1184+
if not used and mark_as_used:
11841185
await self.db_pool.simple_update_one(
11851186
table="e2e_fallback_keys_json",
11861187
keyvalues={

0 commit comments

Comments
 (0)