Skip to content

Commit 387e509

Browse files
authored
Convert AESSIV AEAD to Rust (#9359)
1 parent 46930d2 commit 387e509

File tree

7 files changed

+258
-127
lines changed

7 files changed

+258
-127
lines changed

src/cryptography/hazmat/backends/openssl/aead.py

Lines changed: 12 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,10 @@
1414
AESCCM,
1515
AESGCM,
1616
AESOCB3,
17-
AESSIV,
1817
ChaCha20Poly1305,
1918
)
2019

21-
_AEADTypes = typing.Union[
22-
AESCCM, AESGCM, AESOCB3, AESSIV, ChaCha20Poly1305
23-
]
20+
_AEADTypes = typing.Union[AESCCM, AESGCM, AESOCB3, ChaCha20Poly1305]
2421

2522

2623
def _is_evp_aead_supported_cipher(
@@ -44,16 +41,9 @@ def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool:
4441
cipher_name = _evp_cipher_cipher_name(cipher)
4542
if backend._fips_enabled and cipher_name not in backend._fips_aead:
4643
return False
47-
# SIV isn't loaded through get_cipherbyname but instead a new fetch API
48-
# only available in 3.0+. But if we know we're on 3.0+ then we know
49-
# it's supported.
50-
if cipher_name.endswith(b"-siv"):
51-
return backend._lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER == 1
52-
else:
53-
return (
54-
backend._lib.EVP_get_cipherbyname(cipher_name)
55-
!= backend._ffi.NULL
56-
)
44+
return (
45+
backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL
46+
)
5747

5848

5949
def _aead_create_ctx(
@@ -231,7 +221,6 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes:
231221
AESCCM,
232222
AESGCM,
233223
AESOCB3,
234-
AESSIV,
235224
ChaCha20Poly1305,
236225
)
237226

@@ -241,26 +230,14 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes:
241230
return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii")
242231
elif isinstance(cipher, AESOCB3):
243232
return f"aes-{len(cipher._key) * 8}-ocb".encode("ascii")
244-
elif isinstance(cipher, AESSIV):
245-
return f"aes-{len(cipher._key) * 8 // 2}-siv".encode("ascii")
246233
else:
247234
assert isinstance(cipher, AESGCM)
248235
return f"aes-{len(cipher._key) * 8}-gcm".encode("ascii")
249236

250237

251238
def _evp_cipher(cipher_name: bytes, backend: Backend):
252-
if cipher_name.endswith(b"-siv"):
253-
evp_cipher = backend._lib.EVP_CIPHER_fetch(
254-
backend._ffi.NULL,
255-
cipher_name,
256-
backend._ffi.NULL,
257-
)
258-
backend.openssl_assert(evp_cipher != backend._ffi.NULL)
259-
evp_cipher = backend._ffi.gc(evp_cipher, backend._lib.EVP_CIPHER_free)
260-
else:
261-
evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name)
262-
backend.openssl_assert(evp_cipher != backend._ffi.NULL)
263-
239+
evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name)
240+
backend.openssl_assert(evp_cipher != backend._ffi.NULL)
264241
return evp_cipher
265242

266243

@@ -389,10 +366,7 @@ def _evp_cipher_process_data(backend: Backend, ctx, data: bytes) -> bytes:
389366
buf = backend._ffi.new("unsigned char[]", len(data))
390367
data_ptr = backend._ffi.from_buffer(data)
391368
res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data_ptr, len(data))
392-
if res == 0:
393-
# AES SIV can error here if the data is invalid on decrypt
394-
backend._consume_errors()
395-
raise InvalidTag
369+
backend.openssl_assert(res != 0)
396370
return backend._ffi.buffer(buf, outlen[0])[:]
397371

398372

@@ -405,7 +379,7 @@ def _evp_cipher_encrypt(
405379
tag_length: int,
406380
ctx: typing.Any = None,
407381
) -> bytes:
408-
from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESSIV
382+
from cryptography.hazmat.primitives.ciphers.aead import AESCCM
409383

410384
if ctx is None:
411385
cipher_name = _evp_cipher_cipher_name(cipher)
@@ -445,14 +419,7 @@ def _evp_cipher_encrypt(
445419
backend.openssl_assert(res != 0)
446420
tag = backend._ffi.buffer(tag_buf)[:]
447421

448-
if isinstance(cipher, AESSIV):
449-
# RFC 5297 defines the output as IV || C, where the tag we generate
450-
# is the "IV" and C is the ciphertext. This is the opposite of our
451-
# other AEADs, which are Ciphertext || Tag
452-
backend.openssl_assert(len(tag) == 16)
453-
return tag + processed_data
454-
else:
455-
return processed_data + tag
422+
return processed_data + tag
456423

457424

458425
def _evp_cipher_decrypt(
@@ -464,20 +431,13 @@ def _evp_cipher_decrypt(
464431
tag_length: int,
465432
ctx: typing.Any = None,
466433
) -> bytes:
467-
from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESSIV
434+
from cryptography.hazmat.primitives.ciphers.aead import AESCCM
468435

469436
if len(data) < tag_length:
470437
raise InvalidTag
471438

472-
if isinstance(cipher, AESSIV):
473-
# RFC 5297 defines the output as IV || C, where the tag we generate
474-
# is the "IV" and C is the ciphertext. This is the opposite of our
475-
# other AEADs, which are Ciphertext || Tag
476-
tag = data[:tag_length]
477-
data = data[tag_length:]
478-
else:
479-
tag = data[-tag_length:]
480-
data = data[:-tag_length]
439+
tag = data[-tag_length:]
440+
data = data[:-tag_length]
481441
if ctx is None:
482442
cipher_name = _evp_cipher_cipher_name(cipher)
483443
ctx = _evp_cipher_aead_setup(

src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import typing
66

77
from cryptography.hazmat.bindings._rust.openssl import (
8+
aead,
89
dh,
910
dsa,
1011
ec,
@@ -21,6 +22,7 @@ from cryptography.hazmat.bindings._rust.openssl import (
2122
__all__ = [
2223
"openssl_version",
2324
"raise_openssl_error",
25+
"aead",
2426
"dh",
2527
"dsa",
2628
"ec",
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# This file is dual licensed under the terms of the Apache License, Version
2+
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3+
# for complete details.
4+
5+
import typing
6+
7+
class AESSIV:
8+
def __init__(self, key: bytes) -> None: ...
9+
@staticmethod
10+
def generate_key(key_size: int) -> bytes: ...
11+
def encrypt(
12+
self,
13+
nonce: bytes,
14+
associated_data: typing.Optional[typing.List[bytes]],
15+
) -> bytes: ...
16+
def decrypt(
17+
self,
18+
nonce: bytes,
19+
associated_data: typing.Optional[typing.List[bytes]],
20+
) -> bytes: ...

src/cryptography/hazmat/primitives/ciphers/aead.py

Lines changed: 11 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,17 @@
1111
from cryptography.hazmat.backends.openssl import aead
1212
from cryptography.hazmat.backends.openssl.backend import backend
1313
from cryptography.hazmat.bindings._rust import FixedPool
14+
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
15+
16+
__all__ = [
17+
"ChaCha20Poly1305",
18+
"AESCCM",
19+
"AESGCM",
20+
"AESOCB3",
21+
"AESSIV",
22+
]
23+
24+
AESSIV = rust_openssl.aead.AESSIV
1425

1526

1627
class ChaCha20Poly1305:
@@ -301,78 +312,3 @@ def _check_params(
301312
utils._check_byteslike("associated_data", associated_data)
302313
if len(nonce) < 12 or len(nonce) > 15:
303314
raise ValueError("Nonce must be between 12 and 15 bytes")
304-
305-
306-
class AESSIV:
307-
_MAX_SIZE = 2**31 - 1
308-
309-
def __init__(self, key: bytes):
310-
utils._check_byteslike("key", key)
311-
if len(key) not in (32, 48, 64):
312-
raise ValueError("AESSIV key must be 256, 384, or 512 bits.")
313-
314-
self._key = key
315-
316-
if not backend.aead_cipher_supported(self):
317-
raise exceptions.UnsupportedAlgorithm(
318-
"AES-SIV is not supported by this version of OpenSSL",
319-
exceptions._Reasons.UNSUPPORTED_CIPHER,
320-
)
321-
322-
@classmethod
323-
def generate_key(cls, bit_length: int) -> bytes:
324-
if not isinstance(bit_length, int):
325-
raise TypeError("bit_length must be an integer")
326-
327-
if bit_length not in (256, 384, 512):
328-
raise ValueError("bit_length must be 256, 384, or 512")
329-
330-
return os.urandom(bit_length // 8)
331-
332-
def encrypt(
333-
self,
334-
data: bytes,
335-
associated_data: typing.Optional[typing.List[bytes]],
336-
) -> bytes:
337-
if associated_data is None:
338-
associated_data = []
339-
340-
self._check_params(data, associated_data)
341-
342-
if len(data) > self._MAX_SIZE or any(
343-
len(ad) > self._MAX_SIZE for ad in associated_data
344-
):
345-
# This is OverflowError to match what cffi would raise
346-
raise OverflowError(
347-
"Data or associated data too long. Max 2**31 - 1 bytes"
348-
)
349-
350-
return aead._encrypt(backend, self, b"", data, associated_data, 16)
351-
352-
def decrypt(
353-
self,
354-
data: bytes,
355-
associated_data: typing.Optional[typing.List[bytes]],
356-
) -> bytes:
357-
if associated_data is None:
358-
associated_data = []
359-
360-
self._check_params(data, associated_data)
361-
362-
return aead._decrypt(backend, self, b"", data, associated_data, 16)
363-
364-
def _check_params(
365-
self,
366-
data: bytes,
367-
associated_data: typing.List[bytes],
368-
) -> None:
369-
utils._check_byteslike("data", data)
370-
if len(data) == 0:
371-
raise ValueError("data must not be zero length")
372-
373-
if not isinstance(associated_data, list):
374-
raise TypeError(
375-
"associated_data must be a list of bytes-like objects or None"
376-
)
377-
for x in associated_data:
378-
utils._check_byteslike("associated_data elements", x)

0 commit comments

Comments
 (0)