Skip to content

Commit 3f8d605

Browse files
committed
Add and update tests
1 parent 21ffe8b commit 3f8d605

File tree

7 files changed

+305
-15
lines changed

7 files changed

+305
-15
lines changed

mautrix/client/state_store/tests/store_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
# This Source Code Form is subject to the terms of the Mozilla Public
44
# License, v. 2.0. If a copy of the MPL was not distributed with this
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
6-
from typing import AsyncContextManager, AsyncIterator, Callable, Dict, List
6+
from __future__ import annotations
7+
8+
from typing import AsyncContextManager, AsyncIterator, Callable
79
from contextlib import asynccontextmanager
810
import json
911
import os
@@ -83,7 +85,7 @@ async def store(request) -> AsyncIterator[StateStore]:
8385
yield state_store
8486

8587

86-
def read_state_file(request, file) -> Dict[RoomID, List[StateEvent]]:
88+
def read_state_file(request, file) -> dict[RoomID, list[StateEvent]]:
8789
path = pathlib.Path(request.node.fspath).with_name(file)
8890
with path.open() as fp:
8991
content = json.load(fp)
@@ -122,7 +124,6 @@ async def get_joined_members(request, store: StateStore) -> None:
122124
await store.set_members(room_id, parsed_members, only_membership=Membership.JOIN)
123125

124126

125-
@pytest.mark.asyncio
126127
async def test_basic(store: StateStore) -> None:
127128
room_id = RoomID("!foo:example.com")
128129
user_id = UserID("@tulir:example.com")
@@ -136,7 +137,6 @@ async def test_basic(store: StateStore) -> None:
136137
assert await store.is_encrypted(RoomID("!unknown-room:example.com")) is None
137138

138139

139-
@pytest.mark.asyncio
140140
async def test_basic_updated(request, store: StateStore) -> None:
141141
await store_room_state(request, store)
142142
test_group = RoomID("!telegram-group:example.com")
@@ -145,7 +145,6 @@ async def test_basic_updated(request, store: StateStore) -> None:
145145
assert not await store.is_encrypted(RoomID("!unencrypted-room:example.com"))
146146

147147

148-
@pytest.mark.asyncio
149148
async def test_updates(request, store: StateStore) -> None:
150149
await store_room_state(request, store)
151150
room_id = RoomID("!telegram-group:example.com")
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright © 2019 Damir Jelić <[email protected]> (under the Apache 2.0 license)
2+
# Copyright © 2019 miruka <[email protected]> (under the Apache 2.0 license)
3+
# Copyright (c) 2022 Tulir Asokan
4+
#
5+
# This Source Code Form is subject to the terms of the Mozilla Public
6+
# License, v. 2.0. If a copy of the MPL was not distributed with this
7+
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
8+
from mautrix.types import EncryptedFile
9+
10+
from .async_attachments import async_encrypt_attachment, async_inplace_encrypt_attachment
11+
from .attachments import decrypt_attachment
12+
13+
try:
14+
from Crypto import Random
15+
except ImportError:
16+
from Cryptodome import Random
17+
18+
19+
async def _get_data_cypher_keys(data: bytes) -> tuple[bytes, EncryptedFile]:
20+
*chunks, keys = [i async for i in async_encrypt_attachment(data)]
21+
return b"".join(chunks), keys
22+
23+
24+
async def test_async_encrypt():
25+
data = b"Test bytes"
26+
27+
cyphertext, keys = await _get_data_cypher_keys(data)
28+
29+
plaintext = decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], keys.iv)
30+
31+
assert data == plaintext
32+
33+
34+
async def test_async_inplace_encrypt():
35+
orig_data = b"Test bytes"
36+
data = bytearray(orig_data)
37+
38+
keys = await async_inplace_encrypt_attachment(data)
39+
40+
assert data != orig_data
41+
42+
decrypt_attachment(data, keys.key.key, keys.hashes["sha256"], keys.iv, inplace=True)
43+
44+
assert data == orig_data

mautrix/crypto/attachments/attachments.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,42 +29,46 @@
2929
from Cryptodome.Util import Counter
3030

3131

32-
def decrypt_attachment(ciphertext: bytes, key: str, hash: str, iv: str) -> bytes:
32+
def decrypt_attachment(
33+
ciphertext: bytes | bytearray | memoryview, key: str, hash: str, iv: str, inplace: bool = False
34+
) -> bytes:
3335
"""Decrypt an encrypted attachment.
3436
3537
Args:
3638
ciphertext: The data to decrypt.
3739
key: AES_CTR JWK key object.
3840
hash: Base64 encoded SHA-256 hash of the ciphertext.
3941
iv: Base64 encoded 16 byte AES-CTR IV.
42+
inplace: Should the decryption be performed in-place?
43+
The input must be a bytearray or writable memoryview to use this.
4044
Returns:
4145
The plaintext bytes.
4246
Raises:
4347
EncryptionError: if the integrity check fails.
44-
45-
4648
"""
4749
expected_hash = unpaddedbase64.decode_base64(hash)
4850

4951
h = SHA256.new()
5052
h.update(ciphertext)
5153

5254
if h.digest() != expected_hash:
53-
raise DecryptionError("Mismatched SHA-256 digest.")
55+
raise DecryptionError("Mismatched SHA-256 digest")
5456

5557
try:
5658
byte_key: bytes = unpaddedbase64.decode_base64(key)
5759
except (binascii.Error, TypeError):
58-
raise DecryptionError("Error decoding key.")
60+
raise DecryptionError("Error decoding key")
5961

6062
try:
6163
byte_iv: bytes = unpaddedbase64.decode_base64(iv)
64+
if len(byte_iv) != 16:
65+
raise DecryptionError("Invalid IV length")
6266
prefix = byte_iv[:8]
6367
# A non-zero IV counter is not spec-compliant, but some clients still do it,
6468
# so decode the counter part too.
6569
initial_value = struct.unpack(">Q", byte_iv[8:])[0]
6670
except (binascii.Error, TypeError, IndexError, struct.error):
67-
raise DecryptionError("Error decoding initial values.")
71+
raise DecryptionError("Error decoding IV")
6872

6973
ctr = Counter.new(64, prefix=prefix, initial_value=initial_value)
7074

@@ -73,7 +77,11 @@ def decrypt_attachment(ciphertext: bytes, key: str, hash: str, iv: str) -> bytes
7377
except ValueError as e:
7478
raise DecryptionError("Failed to create AES cipher") from e
7579

76-
return cipher.decrypt(ciphertext)
80+
if inplace:
81+
cipher.decrypt(ciphertext, ciphertext)
82+
return ciphertext
83+
else:
84+
return cipher.decrypt(ciphertext)
7785

7886

7987
def encrypt_attachment(plaintext: bytes) -> tuple[bytes, EncryptedFile]:
@@ -103,7 +111,7 @@ def _prepare_encryption() -> tuple[bytes, bytes, AES, SHA256.SHA256Hash]:
103111
return key, iv, cipher, sha256
104112

105113

106-
def inplace_encrypt_attachment(data: bytearray) -> EncryptedFile:
114+
def inplace_encrypt_attachment(data: bytearray | memoryview) -> EncryptedFile:
107115
key, iv, cipher, sha256 = _prepare_encryption()
108116

109117
cipher.encrypt(plaintext=data, output=data)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright © 2019 Damir Jelić <[email protected]> (under the Apache 2.0 license)
2+
# Copyright (c) 2022 Tulir Asokan
3+
#
4+
# This Source Code Form is subject to the terms of the Mozilla Public
5+
# License, v. 2.0. If a copy of the MPL was not distributed with this
6+
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
7+
import pytest
8+
import unpaddedbase64
9+
10+
from mautrix.errors import DecryptionError
11+
12+
from .attachments import decrypt_attachment, encrypt_attachment, inplace_encrypt_attachment
13+
14+
try:
15+
from Crypto import Random
16+
except ImportError:
17+
from Cryptodome import Random
18+
19+
20+
def test_encrypt():
21+
data = b"Test bytes"
22+
23+
cyphertext, keys = encrypt_attachment(data)
24+
25+
plaintext = decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], keys.iv)
26+
27+
assert data == plaintext
28+
29+
30+
def test_inplace_encrypt():
31+
orig_data = b"Test bytes"
32+
data = bytearray(orig_data)
33+
34+
keys = inplace_encrypt_attachment(data)
35+
36+
assert data != orig_data
37+
38+
decrypt_attachment(data, keys.key.key, keys.hashes["sha256"], keys.iv, inplace=True)
39+
40+
assert data == orig_data
41+
42+
43+
def test_hash_verification():
44+
data = b"Test bytes"
45+
46+
cyphertext, keys = encrypt_attachment(data)
47+
48+
with pytest.raises(DecryptionError):
49+
decrypt_attachment(cyphertext, keys.key.key, "Fake hash", keys.iv)
50+
51+
52+
def test_invalid_key():
53+
data = b"Test bytes"
54+
55+
cyphertext, keys = encrypt_attachment(data)
56+
57+
with pytest.raises(DecryptionError):
58+
decrypt_attachment(cyphertext, "Fake key", keys.hashes["sha256"], keys.iv)
59+
60+
61+
def test_invalid_iv():
62+
data = b"Test bytes"
63+
64+
cyphertext, keys = encrypt_attachment(data)
65+
66+
with pytest.raises(DecryptionError):
67+
decrypt_attachment(cyphertext, keys.key.key, keys.hashes["sha256"], "Fake iv")
68+
69+
70+
def test_short_key():
71+
data = b"Test bytes"
72+
73+
cyphertext, keys = encrypt_attachment(data)
74+
75+
with pytest.raises(DecryptionError):
76+
decrypt_attachment(
77+
cyphertext,
78+
unpaddedbase64.encode_base64(b"Fake key", urlsafe=True),
79+
keys["hashes"]["sha256"],
80+
keys["iv"],
81+
)
82+
83+
84+
def test_short_iv():
85+
data = b"Test bytes"
86+
87+
cyphertext, keys = encrypt_attachment(data)
88+
89+
with pytest.raises(DecryptionError):
90+
decrypt_attachment(
91+
cyphertext,
92+
keys.key.key,
93+
keys.hashes["sha256"],
94+
unpaddedbase64.encode_base64(b"F" + b"\x00" * 8),
95+
)
96+
97+
98+
def test_fake_key():
99+
data = b"Test bytes"
100+
101+
cyphertext, keys = encrypt_attachment(data)
102+
103+
fake_key = Random.new().read(32)
104+
105+
plaintext = decrypt_attachment(
106+
cyphertext,
107+
unpaddedbase64.encode_base64(fake_key, urlsafe=True),
108+
keys["hashes"]["sha256"],
109+
keys["iv"],
110+
)
111+
assert plaintext != data

mautrix/crypto/store/tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)