Skip to content

Commit 0ea9586

Browse files
committed
Use more dataclasses for uploading e2ee device keys
1 parent 71f4fae commit 0ea9586

File tree

5 files changed

+59
-21
lines changed

5 files changed

+59
-21
lines changed

mautrix/client/api/modules/crypto.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8-
from typing import Any
8+
from typing import Any, Union
99

1010
from mautrix.api import Method, Path
1111
from mautrix.errors import MatrixResponseError
1212
from mautrix.types import (
13+
JSON,
1314
ClaimKeysResponse,
1415
DeviceID,
16+
DeviceKeys,
1517
EncryptionKeyAlgorithm,
1618
EventType,
1719
QueryKeysResponse,
@@ -82,7 +84,7 @@ async def send_to_one_device(
8284
async def upload_keys(
8385
self,
8486
one_time_keys: dict[str, Any] | None = None,
85-
device_keys: dict[str, Any] | None = None,
87+
device_keys: DeviceKeys | dict[str, Any] | None = None,
8688
) -> dict[EncryptionKeyAlgorithm, int]:
8789
"""
8890
Publishes end-to-end encryption keys for the device.
@@ -102,8 +104,12 @@ async def upload_keys(
102104
"""
103105
data = {}
104106
if device_keys:
107+
if isinstance(device_keys, Serializable):
108+
device_keys = device_keys.serialize()
105109
data["device_keys"] = device_keys
106110
if one_time_keys:
111+
if isinstance(one_time_keys, Serializable):
112+
one_time_keys = one_time_keys.serialize()
107113
data["one_time_keys"] = one_time_keys
108114
resp = await self.api.request(Method.POST, Path.v3.keys.upload, data)
109115
try:

mautrix/crypto/account.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@
1010

1111
from mautrix.types import (
1212
DeviceID,
13+
DeviceKeys,
1314
EncryptionAlgorithm,
1415
EncryptionKeyAlgorithm,
1516
IdentityKey,
17+
KeyID,
1618
SigningKey,
1719
UserID,
1820
)
1921

20-
from . import base
2122
from .sessions import Session
23+
from .signature import sign_olm
2224

2325

2426
class OlmAccount(olm.Account):
@@ -74,19 +76,18 @@ def new_outbound_session(self, target_key: IdentityKey, one_time_key: IdentityKe
7476
session.pickle("roundtrip"), passphrase="roundtrip", creation_time=datetime.now()
7577
)
7678

77-
def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> Dict[str, Any]:
78-
device_keys = {
79-
"user_id": user_id,
80-
"device_id": device_id,
81-
"algorithms": [EncryptionAlgorithm.OLM_V1.value, EncryptionAlgorithm.MEGOLM_V1.value],
82-
"keys": {
83-
f"{algorithm}:{device_id}": key for algorithm, key in self.identity_keys.items()
79+
def get_device_keys(self, user_id: UserID, device_id: DeviceID) -> DeviceKeys:
80+
device_keys = DeviceKeys(
81+
user_id=user_id,
82+
device_id=device_id,
83+
algorithms=[EncryptionAlgorithm.OLM_V1, EncryptionAlgorithm.MEGOLM_V1],
84+
keys={
85+
KeyID(algorithm=EncryptionKeyAlgorithm(algorithm), key_id=key): key
86+
for algorithm, key in self.identity_keys.items()
8487
},
85-
}
86-
signature = self.sign(base.canonical_json(device_keys))
87-
device_keys["signatures"] = {
88-
user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature}
89-
}
88+
signatures={},
89+
)
90+
device_keys.signatures[user_id] = {KeyID.ed25519(device_id): sign_olm(device_keys, self)}
9091
return device_keys
9192

9293
def get_one_time_keys(
@@ -97,12 +98,12 @@ def get_one_time_keys(
9798
self.generate_one_time_keys(new_count)
9899
keys = {}
99100
for key_id, key in self.one_time_keys.get("curve25519", {}).items():
100-
signature = self.sign(base.canonical_json({"key": key}))
101-
keys[f"{EncryptionKeyAlgorithm.SIGNED_CURVE25519}:{key_id}"] = {
101+
keys[str(KeyID.signed_curve25519(IdentityKey(key_id)))] = {
102102
"key": key,
103103
"signatures": {
104-
user_id: {f"{EncryptionKeyAlgorithm.ED25519}:{device_id}": signature}
104+
user_id: {
105+
str(KeyID.ed25519(device_id)): sign_olm({"key": key}, self),
106+
}
105107
},
106108
}
107-
self.mark_keys_as_published()
108109
return keys

mautrix/crypto/machine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ async def _share_keys(self, current_otk_count: int | None) -> None:
313313
self.log.debug(f"Uploading {len(one_time_keys)} one-time keys")
314314
resp = await self.client.upload_keys(one_time_keys=one_time_keys, device_keys=device_keys)
315315
self.account.shared = True
316+
self.account.mark_keys_as_published()
316317
self._last_key_share = time.monotonic()
317318
await self.crypto_store.put_account(self.account)
318319
self.log.debug(f"Shared keys and saved account, new keys: {resp}")

mautrix/crypto/signature.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,19 @@
77
import functools
88
import json
99

10+
import olm
1011
import unpaddedbase64
1112

12-
from mautrix.types import DeviceID, EncryptionKeyAlgorithm, KeyID, SigningKey, UserID
13+
from mautrix.types import (
14+
JSON,
15+
DeviceID,
16+
EncryptionKeyAlgorithm,
17+
KeyID,
18+
Serializable,
19+
Signature,
20+
SigningKey,
21+
UserID,
22+
)
1323

1424
try:
1525
from Crypto.PublicKey import ECC
@@ -28,6 +38,14 @@ class SignedObject(TypedDict):
2838
unsigned: Any
2939

3040

41+
def sign_olm(data: dict[str, JSON] | Serializable, key: olm.PkSigning | olm.Account) -> Signature:
42+
if isinstance(data, Serializable):
43+
data = data.serialize()
44+
data.pop("signatures", None)
45+
data.pop("unsigned", None)
46+
return Signature(key.sign(canonical_json(data)))
47+
48+
3149
def verify_signature_json(
3250
data: "SignedObject", user_id: UserID, key_name: DeviceID | str, key: SigningKey
3351
) -> bool:

mautrix/types/event/encrypted.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from attr import dataclass
1111

12-
from ..primitive import JSON, DeviceID, IdentityKey, SessionID
12+
from ..primitive import JSON, DeviceID, IdentityKey, SessionID, SigningKey
1313
from ..util import ExtensibleEnum, Obj, Serializable, SerializableAttrs, deserializer, field
1414
from .base import BaseRoomEvent, BaseUnsigned
1515
from .message import RelatesTo
@@ -43,6 +43,18 @@ def deserialize(cls, raw: JSON) -> "KeyID":
4343
def __str__(self) -> str:
4444
return f"{self.algorithm.value}:{self.key_id}"
4545

46+
@classmethod
47+
def ed25519(cls, key_id: SigningKey | DeviceID) -> "KeyID":
48+
return cls(EncryptionKeyAlgorithm.ED25519, key_id)
49+
50+
@classmethod
51+
def curve25519(cls, key_id: IdentityKey) -> "KeyID":
52+
return cls(EncryptionKeyAlgorithm.CURVE25519, key_id)
53+
54+
@classmethod
55+
def signed_curve25519(cls, key_id: IdentityKey) -> "KeyID":
56+
return cls(EncryptionKeyAlgorithm.SIGNED_CURVE25519, key_id)
57+
4658

4759
class OlmMsgType(Serializable, IntEnum):
4860
PREKEY = 0

0 commit comments

Comments
 (0)