Skip to content

Commit e496c2f

Browse files
committed
Add utilities for generating and using recovery keys
1 parent 0ea9586 commit e496c2f

File tree

6 files changed

+287
-3
lines changed

6 files changed

+287
-3
lines changed

mautrix/client/api/modules/crypto.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from mautrix.types import (
1313
JSON,
1414
ClaimKeysResponse,
15+
CrossSigningKeys,
16+
CrossSigningUsage,
1517
DeviceID,
1618
DeviceKeys,
1719
EncryptionKeyAlgorithm,
@@ -122,6 +124,43 @@ async def upload_keys(
122124
except AttributeError as e:
123125
raise MatrixResponseError("Invalid `one_time_key_counts` field in response.") from e
124126

127+
async def upload_cross_signing_keys(
128+
self,
129+
keys: dict[CrossSigningUsage, CrossSigningKeys],
130+
auth: dict[str, JSON] | None = None,
131+
) -> None:
132+
await self.api.request(
133+
Method.POST,
134+
Path.v3.keys.device_signing.upload,
135+
{f"{usage}_key": key.serialize() for usage, key in keys.items()}
136+
| ({"auth": auth} if auth else {}),
137+
)
138+
139+
async def upload_one_signature(
140+
self,
141+
user_id: UserID,
142+
device_id: DeviceID,
143+
keys: Union[DeviceKeys, CrossSigningKeys],
144+
) -> None:
145+
await self.api.request(
146+
Method.POST, Path.v3.keys.signatures.upload, {user_id: {device_id: keys.serialize()}}
147+
)
148+
# TODO check failures
149+
150+
async def upload_many_signatures(
151+
self,
152+
signatures: dict[UserID, dict[DeviceID, Union[DeviceKeys, CrossSigningKeys]]],
153+
) -> None:
154+
await self.api.request(
155+
Method.POST,
156+
Path.v3.keys.signatures.upload,
157+
{
158+
user_id: {device_id: keys.serialize() for device_id, keys in devices.items()}
159+
for user_id, devices in signatures.items()
160+
},
161+
)
162+
# TODO check failures
163+
125164
async def query_keys(
126165
self,
127166
device_keys: list[UserID] | set[UserID] | dict[UserID, list[DeviceID]],

mautrix/crypto/cross_signing.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# Copyright (c) 2025 Tulir Asokan
2+
#
3+
# This Source Code Form is subject to the terms of the Mozilla Public
4+
# License, v. 2.0. If a copy of the MPL was not distributed with this
5+
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
6+
from ..types import (
7+
JSON,
8+
CrossSigner,
9+
CrossSigningKeys,
10+
CrossSigningUsage,
11+
DeviceIdentity,
12+
EventType,
13+
KeyID,
14+
UserID,
15+
)
16+
from .cross_signing_key import CrossSigningPrivateKeys, CrossSigningPublicKeys, CrossSigningSeeds
17+
from .device_lists import DeviceListMachine
18+
from .signature import sign_olm
19+
from .ssss import Key as SSSSKey
20+
21+
22+
class CrossSigningMachine(DeviceListMachine):
23+
_cross_signing_public_keys: CrossSigningPublicKeys | None
24+
_cross_signing_public_keys_fetched: bool
25+
_cross_signing_private_keys: CrossSigningPrivateKeys | None
26+
27+
async def verify_with_recovery_key(self, recovery_key: str) -> None:
28+
key_id, key_data = await self.ssss.get_default_key_data()
29+
ssss_key = key_data.verify_recovery_key(key_id, recovery_key)
30+
seeds = await self._fetch_cross_signing_keys_from_ssss(ssss_key)
31+
self._import_cross_signing_keys(seeds)
32+
await self.sign_own_device(self.own_identity)
33+
34+
def _import_cross_signing_keys(self, seeds: CrossSigningSeeds) -> None:
35+
self._cross_signing_private_keys = seeds.to_keys()
36+
self._cross_signing_public_keys = self._cross_signing_private_keys.public_keys
37+
38+
async def generate_recovery_key(
39+
self, passphrase: str | None = None, seeds: CrossSigningSeeds | None = None
40+
) -> str:
41+
seeds = seeds or CrossSigningSeeds.generate()
42+
ssss_key = await self.ssss.generate_and_upload_key(passphrase)
43+
await self._upload_cross_signing_keys_to_ssss(ssss_key, seeds)
44+
await self._publish_cross_signing_keys(seeds.to_keys())
45+
await self.ssss.set_default_key_id(ssss_key.id)
46+
await self.sign_own_device(self.own_identity)
47+
return ssss_key.recovery_key
48+
49+
async def _fetch_cross_signing_keys_from_ssss(self, key: SSSSKey) -> CrossSigningSeeds:
50+
return CrossSigningSeeds(
51+
master_key=await self.ssss.get_decrypted_account_data(
52+
EventType.CROSS_SIGNING_MASTER, key
53+
),
54+
user_signing_key=await self.ssss.get_decrypted_account_data(
55+
EventType.CROSS_SIGNING_USER_SIGNING, key
56+
),
57+
self_signing_key=await self.ssss.get_decrypted_account_data(
58+
EventType.CROSS_SIGNING_SELF_SIGNING, key
59+
),
60+
)
61+
62+
async def _upload_cross_signing_keys_to_ssss(
63+
self, key: SSSSKey, seeds: CrossSigningSeeds
64+
) -> None:
65+
await self.ssss.set_encrypted_account_data(
66+
EventType.CROSS_SIGNING_MASTER, seeds.master_key, key
67+
)
68+
await self.ssss.set_encrypted_account_data(
69+
EventType.CROSS_SIGNING_USER_SIGNING, seeds.user_signing_key, key
70+
)
71+
await self.ssss.set_encrypted_account_data(
72+
EventType.CROSS_SIGNING_SELF_SIGNING, seeds.self_signing_key, key
73+
)
74+
75+
async def get_own_cross_signing_public_keys(self) -> CrossSigningPublicKeys | None:
76+
if self._cross_signing_public_keys or self._cross_signing_public_keys_fetched:
77+
return self._cross_signing_public_keys
78+
keys = await self.get_cross_signing_public_keys(self.client.mxid)
79+
self._cross_signing_public_keys_fetched = True
80+
if keys:
81+
self._cross_signing_public_keys = keys
82+
return keys
83+
84+
async def get_cross_signing_public_keys(
85+
self, user_id: UserID
86+
) -> CrossSigningPublicKeys | None:
87+
db_keys = await self.crypto_store.get_cross_signing_keys(user_id)
88+
if CrossSigningUsage.MASTER not in db_keys:
89+
await self._fetch_keys([user_id], include_untracked=True)
90+
db_keys = await self.crypto_store.get_cross_signing_keys(user_id)
91+
if CrossSigningUsage.MASTER not in db_keys:
92+
return None
93+
return CrossSigningPublicKeys(
94+
master_key=db_keys[CrossSigningUsage.MASTER].key,
95+
self_signing_key=(
96+
db_keys[CrossSigningUsage.SELF].key if CrossSigningUsage.SELF in db_keys else None
97+
),
98+
user_signing_key=(
99+
db_keys[CrossSigningUsage.USER].key if CrossSigningUsage.USER in db_keys else None
100+
),
101+
)
102+
103+
async def sign_own_device(self, device: DeviceIdentity) -> None:
104+
full_keys = await self._get_full_device_keys(device)
105+
ssk = self._cross_signing_private_keys.self_signing_key
106+
signature = sign_olm(full_keys, ssk)
107+
full_keys.signatures = {self.client.mxid: {KeyID.ed25519(ssk.public_key): signature}}
108+
await self.client.upload_one_signature(device.user_id, device.device_id, full_keys)
109+
await self.crypto_store.put_signature(
110+
CrossSigner(device.user_id, device.signing_key),
111+
CrossSigner(self.client.mxid, ssk.public_key),
112+
signature,
113+
)
114+
115+
async def _publish_cross_signing_keys(
116+
self,
117+
keys: CrossSigningPrivateKeys,
118+
auth: dict[str, JSON] | None = None,
119+
) -> None:
120+
public = keys.public_keys
121+
master_key = CrossSigningKeys(
122+
user_id=self.client.mxid,
123+
usage=[CrossSigningUsage.MASTER],
124+
keys={KeyID.ed25519(public.master_key): public.master_key},
125+
)
126+
master_key.signatures = {
127+
self.client.mxid: {
128+
KeyID.ed25519(self.client.device_id): sign_olm(master_key, self.account),
129+
}
130+
}
131+
self_key = CrossSigningKeys(
132+
user_id=self.client.mxid,
133+
usage=[CrossSigningUsage.SELF],
134+
keys={KeyID.ed25519(public.self_signing_key): public.self_signing_key},
135+
)
136+
self_key.signatures = {
137+
self.client.mxid: {
138+
KeyID.ed25519(public.master_key): sign_olm(self_key, keys.master_key),
139+
}
140+
}
141+
user_key = CrossSigningKeys(
142+
user_id=self.client.mxid,
143+
usage=[CrossSigningUsage.USER],
144+
keys={KeyID.ed25519(public.user_signing_key): public.user_signing_key},
145+
)
146+
user_key.signatures = {
147+
self.client.mxid: {
148+
KeyID.ed25519(public.master_key): sign_olm(user_key, keys.master_key),
149+
}
150+
}
151+
await self.client.upload_cross_signing_keys(
152+
keys={
153+
CrossSigningUsage.MASTER: master_key,
154+
CrossSigningUsage.SELF: self_key,
155+
CrossSigningUsage.USER: user_key,
156+
},
157+
auth=auth,
158+
)
159+
await self.crypto_store.put_cross_signing_key(
160+
self.client.mxid, CrossSigningUsage.MASTER, public.master_key
161+
)
162+
await self.crypto_store.put_cross_signing_key(
163+
self.client.mxid, CrossSigningUsage.SELF, public.self_signing_key
164+
)
165+
await self.crypto_store.put_cross_signing_key(
166+
self.client.mxid, CrossSigningUsage.USER, public.user_signing_key
167+
)
168+
self._cross_signing_private_keys = keys
169+
self._cross_signing_public_keys = public
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) 2025 Tulir Asokan
2+
#
3+
# This Source Code Form is subject to the terms of the Mozilla Public
4+
# License, v. 2.0. If a copy of the MPL was not distributed with this
5+
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
6+
from typing import NamedTuple
7+
8+
import olm
9+
10+
from mautrix.crypto.ssss.util import cryptorand
11+
from mautrix.types import SigningKey
12+
13+
14+
class CrossSigningPublicKeys(NamedTuple):
15+
master_key: SigningKey
16+
self_signing_key: SigningKey
17+
user_signing_key: SigningKey
18+
19+
20+
class CrossSigningPrivateKeys(NamedTuple):
21+
master_key: olm.PkSigning
22+
self_signing_key: olm.PkSigning
23+
user_signing_key: olm.PkSigning
24+
25+
@property
26+
def public_keys(self) -> CrossSigningPublicKeys:
27+
return CrossSigningPublicKeys(
28+
master_key=self.master_key.public_key,
29+
self_signing_key=self.self_signing_key.public_key,
30+
user_signing_key=self.user_signing_key.public_key,
31+
)
32+
33+
34+
class CrossSigningSeeds(NamedTuple):
35+
master_key: bytes
36+
self_signing_key: bytes
37+
user_signing_key: bytes
38+
39+
def to_keys(self) -> CrossSigningPrivateKeys:
40+
return CrossSigningPrivateKeys(
41+
master_key=olm.PkSigning(self.master_key),
42+
self_signing_key=olm.PkSigning(self.self_signing_key),
43+
user_signing_key=olm.PkSigning(self.user_signing_key),
44+
)
45+
46+
@classmethod
47+
def generate(cls) -> "CrossSigningSeeds":
48+
return cls(
49+
master_key=cryptorand.read(32),
50+
self_signing_key=cryptorand.read(32),
51+
user_signing_key=cryptorand.read(32),
52+
)

mautrix/crypto/device_lists.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,18 @@
2828

2929

3030
class DeviceListMachine(BaseOlmMachine):
31+
@property
32+
def own_identity(self) -> DeviceIdentity:
33+
return DeviceIdentity(
34+
user_id=self.client.mxid,
35+
device_id=self.client.device_id,
36+
identity_key=self.account.identity_key,
37+
signing_key=self.account.signing_key,
38+
trust=TrustState.VERIFIED,
39+
deleted=False,
40+
name="",
41+
)
42+
3143
async def _fetch_keys(
3244
self, users: list[UserID], since: SyncToken = "", include_untracked: bool = False
3345
) -> dict[UserID, dict[DeviceID, DeviceIdentity]]:
@@ -220,6 +232,12 @@ async def _store_cross_signing_keys(self, resp: QueryKeysResponse, user_id: User
220232
else:
221233
self.log.warning(f"Invalid signature from {signing_key_log} for {key_id}")
222234

235+
async def _get_full_device_keys(self, device: DeviceIdentity) -> DeviceKeys:
236+
resp = await self.client.query_keys({device.user_id: [device.device_id]})
237+
keys = resp.device_keys[device.user_id][device.device_id]
238+
await self._validate_device(device.user_id, device.device_id, keys, device)
239+
return keys
240+
223241
async def get_or_fetch_device(
224242
self, user_id: UserID, device_id: DeviceID
225243
) -> DeviceIdentity | None:

mautrix/crypto/machine.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from mautrix.util.logging import TraceLogger
3333

3434
from .account import OlmAccount
35+
from .cross_signing import CrossSigningMachine
3536
from .decrypt_megolm import MegolmDecryptionMachine
3637
from .encrypt_megolm import MegolmEncryptionMachine
3738
from .key_request import KeyRequestingMachine
@@ -47,6 +48,7 @@ class OlmMachine(
4748
OlmUnwedgingMachine,
4849
KeySharingMachine,
4950
KeyRequestingMachine,
51+
CrossSigningMachine,
5052
):
5153
"""
5254
OlmMachine is the main class for handling things related to Matrix end-to-end encryption with
@@ -99,6 +101,10 @@ def __init__(
99101
self._prev_unwedge = {}
100102
self._cs_fetch_attempted = set()
101103

104+
self._cross_signing_public_keys = None
105+
self._cross_signing_public_keys_fetched = False
106+
self._cross_signing_private_keys = None
107+
102108
self.client.add_event_handler(
103109
cli.InternalEventType.DEVICE_OTK_COUNT, self.handle_otk_count, wait_sync=True
104110
)

mautrix/types/crypto.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ def curve25519(self) -> Optional[IdentityKey]:
4747

4848

4949
class CrossSigningUsage(ExtensibleEnum):
50-
MASTER = "master"
51-
SELF = "self_signing"
52-
USER = "user_signing"
50+
MASTER: "CrossSigningUsage" = "master"
51+
SELF: "CrossSigningUsage" = "self_signing"
52+
USER: "CrossSigningUsage" = "user_signing"
5353

5454

5555
@dataclass

0 commit comments

Comments
 (0)