diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi index 3380c7a785ff..00b530bb241e 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi @@ -91,3 +91,26 @@ class X963KDF: ) -> None: ... def derive(self, key_material: Buffer) -> bytes: ... def verify(self, key_material: bytes, expected_key: bytes) -> None: ... + +class ConcatKDFHash: + def __init__( + self, + algorithm: HashAlgorithm, + length: int, + otherinfo: bytes | None, + backend: typing.Any = None, + ) -> None: ... + def derive(self, key_material: Buffer) -> bytes: ... + def verify(self, key_material: bytes, expected_key: bytes) -> None: ... + +class ConcatKDFHMAC: + def __init__( + self, + algorithm: HashAlgorithm, + length: int, + salt: bytes | None, + otherinfo: bytes | None, + backend: typing.Any = None, + ) -> None: ... + def derive(self, key_material: Buffer) -> bytes: ... + def verify(self, key_material: bytes, expected_key: bytes) -> None: ... diff --git a/src/cryptography/hazmat/primitives/kdf/concatkdf.py b/src/cryptography/hazmat/primitives/kdf/concatkdf.py index 1b928415c5c1..398dc5dc55c3 100644 --- a/src/cryptography/hazmat/primitives/kdf/concatkdf.py +++ b/src/cryptography/hazmat/primitives/kdf/concatkdf.py @@ -4,122 +4,13 @@ from __future__ import annotations -import typing -from collections.abc import Callable - -from cryptography import utils -from cryptography.exceptions import AlreadyFinalized, InvalidKey -from cryptography.hazmat.primitives import constant_time, hashes, hmac +from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.primitives.kdf import KeyDerivationFunction +ConcatKDFHash = rust_openssl.kdf.ConcatKDFHash +ConcatKDFHMAC = rust_openssl.kdf.ConcatKDFHMAC -def _int_to_u32be(n: int) -> bytes: - return n.to_bytes(length=4, byteorder="big") - - -def _common_args_checks( - algorithm: hashes.HashAlgorithm, - length: int, - otherinfo: bytes | None, -) -> None: - max_length = algorithm.digest_size * (2**32 - 1) - if length > max_length: - raise ValueError(f"Cannot derive keys larger than {max_length} bits.") - if otherinfo is not None: - utils._check_bytes("otherinfo", otherinfo) - - -def _concatkdf_derive( - key_material: utils.Buffer, - length: int, - auxfn: Callable[[], hashes.HashContext], - otherinfo: bytes, -) -> bytes: - utils._check_byteslike("key_material", key_material) - output = [b""] - outlen = 0 - counter = 1 - - while length > outlen: - h = auxfn() - h.update(_int_to_u32be(counter)) - h.update(key_material) - h.update(otherinfo) - output.append(h.finalize()) - outlen += len(output[-1]) - counter += 1 - - return b"".join(output)[:length] - - -class ConcatKDFHash(KeyDerivationFunction): - def __init__( - self, - algorithm: hashes.HashAlgorithm, - length: int, - otherinfo: bytes | None, - backend: typing.Any = None, - ): - _common_args_checks(algorithm, length, otherinfo) - self._algorithm = algorithm - self._length = length - self._otherinfo: bytes = otherinfo if otherinfo is not None else b"" - - self._used = False - - def _hash(self) -> hashes.Hash: - return hashes.Hash(self._algorithm) - - def derive(self, key_material: utils.Buffer) -> bytes: - if self._used: - raise AlreadyFinalized - self._used = True - return _concatkdf_derive( - key_material, self._length, self._hash, self._otherinfo - ) - - def verify(self, key_material: bytes, expected_key: bytes) -> None: - if not constant_time.bytes_eq(self.derive(key_material), expected_key): - raise InvalidKey - - -class ConcatKDFHMAC(KeyDerivationFunction): - def __init__( - self, - algorithm: hashes.HashAlgorithm, - length: int, - salt: bytes | None, - otherinfo: bytes | None, - backend: typing.Any = None, - ): - _common_args_checks(algorithm, length, otherinfo) - self._algorithm = algorithm - self._length = length - self._otherinfo: bytes = otherinfo if otherinfo is not None else b"" - - if algorithm.block_size is None: - raise TypeError(f"{algorithm.name} is unsupported for ConcatKDF") - - if salt is None: - salt = b"\x00" * algorithm.block_size - else: - utils._check_bytes("salt", salt) - - self._salt = salt - - self._used = False - - def _hmac(self) -> hmac.HMAC: - return hmac.HMAC(self._salt, self._algorithm) - - def derive(self, key_material: utils.Buffer) -> bytes: - if self._used: - raise AlreadyFinalized - self._used = True - return _concatkdf_derive( - key_material, self._length, self._hmac, self._otherinfo - ) +KeyDerivationFunction.register(ConcatKDFHash) +KeyDerivationFunction.register(ConcatKDFHMAC) - def verify(self, key_material: bytes, expected_key: bytes) -> None: - if not constant_time.bytes_eq(self.derive(key_material), expected_key): - raise InvalidKey +__all__ = ["ConcatKDFHMAC", "ConcatKDFHash"] diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index c91afd728af0..579a773a2df4 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -942,8 +942,247 @@ impl X963Kdf { } } +// NO-COVERAGE-START +#[pyo3::pyclass( + module = "cryptography.hazmat.primitives.kdf.concatkdf", + name = "ConcatKDFHash" +)] +// NO-COVERAGE-END +struct ConcatKdfHash { + algorithm: pyo3::Py, + length: usize, + otherinfo: Option>, + used: bool, +} + +#[pyo3::pymethods] +impl ConcatKdfHash { + #[new] + #[pyo3(signature = (algorithm, length, otherinfo, backend=None))] + fn new( + py: pyo3::Python<'_>, + algorithm: pyo3::Py, + length: usize, + otherinfo: Option>, + backend: Option>, + ) -> CryptographyResult { + _ = backend; + + let algorithm_bound = algorithm.bind(py); + let digest_size = algorithm_bound + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + + let max_len = digest_size.saturating_mul(u32::MAX as usize); + if length > max_len { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "Cannot derive keys larger than {max_len} bits." + )), + )); + } + + Ok(ConcatKdfHash { + algorithm, + length, + otherinfo, + used: false, + }) + } + + fn derive<'p>( + &mut self, + py: pyo3::Python<'p>, + key_material: CffiBuf<'_>, + ) -> CryptographyResult> { + if self.used { + return Err(exceptions::already_finalized_error()); + } + self.used = true; + + let algorithm_bound = self.algorithm.bind(py); + let digest_size = algorithm_bound + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + + Ok(pyo3::types::PyBytes::new_with(py, self.length, |output| { + let mut pos = 0usize; + let mut counter = 1u32; + + while pos < self.length { + let mut hash_obj = hashes::Hash::new(py, algorithm_bound, None)?; + hash_obj.update_bytes(&counter.to_be_bytes())?; + hash_obj.update_bytes(key_material.as_bytes())?; + if let Some(ref otherinfo) = self.otherinfo { + hash_obj.update_bytes(otherinfo.as_bytes(py))?; + } + let block = hash_obj.finalize(py)?; + let block_bytes = block.as_bytes(); + + let copy_len = (self.length - pos).min(digest_size); + output[pos..pos + copy_len].copy_from_slice(&block_bytes[..copy_len]); + pos += copy_len; + counter += 1; + } + + Ok(()) + })?) + } + + fn verify( + &mut self, + py: pyo3::Python<'_>, + key_material: CffiBuf<'_>, + expected_key: CffiBuf<'_>, + ) -> CryptographyResult<()> { + let actual = self.derive(py, key_material)?; + let actual_bytes = actual.as_bytes(); + let expected_bytes = expected_key.as_bytes(); + + if !constant_time::bytes_eq(actual_bytes, expected_bytes) { + return Err(CryptographyError::from(exceptions::InvalidKey::new_err( + "Keys do not match.", + ))); + } + + Ok(()) + } +} + +// NO-COVERAGE-START +#[pyo3::pyclass( + module = "cryptography.hazmat.primitives.kdf.concatkdf", + name = "ConcatKDFHMAC" +)] +// NO-COVERAGE-END +struct ConcatKdfHmac { + algorithm: pyo3::Py, + length: usize, + salt: pyo3::Py, + otherinfo: Option>, + used: bool, +} + +#[pyo3::pymethods] +impl ConcatKdfHmac { + #[new] + #[pyo3(signature = (algorithm, length, salt, otherinfo, backend=None))] + fn new( + py: pyo3::Python<'_>, + algorithm: pyo3::Py, + length: usize, + salt: Option>, + otherinfo: Option>, + backend: Option>, + ) -> CryptographyResult { + _ = backend; + + let algorithm_bound = algorithm.bind(py); + let digest_size = algorithm_bound + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + + let max_len = digest_size.saturating_mul(u32::MAX as usize); + if length > max_len { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "Cannot derive keys larger than {max_len} bits." + )), + )); + } + + let block_size = algorithm_bound.getattr(pyo3::intern!(py, "block_size"))?; + if block_size.is_none() { + let name = algorithm_bound + .getattr(pyo3::intern!(py, "name"))? + .extract::()?; + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err(format!( + "{name} is unsupported for ConcatKDF" + )), + )); + } + + let block_size_val = block_size.extract::()?; + + // Default salt to zeros of block_size length + let salt_bytes = if let Some(s) = salt { + s + } else { + pyo3::types::PyBytes::new_with(py, block_size_val, |_| Ok(()))?.into() + }; + + Ok(ConcatKdfHmac { + algorithm, + length, + salt: salt_bytes, + otherinfo, + used: false, + }) + } + + fn derive<'p>( + &mut self, + py: pyo3::Python<'p>, + key_material: CffiBuf<'_>, + ) -> CryptographyResult> { + if self.used { + return Err(exceptions::already_finalized_error()); + } + self.used = true; + + let algorithm_bound = self.algorithm.bind(py); + let digest_size = algorithm_bound + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + + Ok(pyo3::types::PyBytes::new_with(py, self.length, |output| { + let mut pos = 0usize; + let mut counter = 1u32; + + while pos < self.length { + let mut hmac = Hmac::new_bytes(py, self.salt.as_bytes(py), algorithm_bound)?; + hmac.update_bytes(&counter.to_be_bytes())?; + hmac.update_bytes(key_material.as_bytes())?; + if let Some(ref otherinfo) = self.otherinfo { + hmac.update_bytes(otherinfo.as_bytes(py))?; + } + let result = hmac.finalize_bytes()?; + + let copy_len = (self.length - pos).min(digest_size); + output[pos..pos + copy_len].copy_from_slice(&result[..copy_len]); + pos += copy_len; + counter += 1; + } + + Ok(()) + })?) + } + + fn verify( + &mut self, + py: pyo3::Python<'_>, + key_material: CffiBuf<'_>, + expected_key: CffiBuf<'_>, + ) -> CryptographyResult<()> { + let actual = self.derive(py, key_material)?; + let actual_bytes = actual.as_bytes(); + let expected_bytes = expected_key.as_bytes(); + + if !constant_time::bytes_eq(actual_bytes, expected_bytes) { + return Err(CryptographyError::from(exceptions::InvalidKey::new_err( + "Keys do not match.", + ))); + } + + Ok(()) + } +} + #[pyo3::pymodule(gil_used = false)] pub(crate) mod kdf { #[pymodule_export] - use super::{Argon2id, Hkdf, HkdfExpand, Pbkdf2Hmac, Scrypt, X963Kdf}; + use super::{ + Argon2id, ConcatKdfHash, ConcatKdfHmac, Hkdf, HkdfExpand, Pbkdf2Hmac, Scrypt, X963Kdf, + }; } diff --git a/tests/hazmat/primitives/test_concatkdf.py b/tests/hazmat/primitives/test_concatkdf.py index 3a6e994be304..f0c44fbba556 100644 --- a/tests/hazmat/primitives/test_concatkdf.py +++ b/tests/hazmat/primitives/test_concatkdf.py @@ -4,6 +4,7 @@ import binascii +import sys import pytest @@ -18,8 +19,9 @@ class TestConcatKDFHash: def test_length_limit(self, backend): big_length = hashes.SHA256().digest_size * (2**32 - 1) + 1 + error = OverflowError if sys.maxsize <= 2**31 else ValueError - with pytest.raises(ValueError): + with pytest.raises(error): ConcatKDFHash(hashes.SHA256(), big_length, None, backend) def test_already_finalized(self, backend): @@ -127,8 +129,9 @@ def test_unicode_typeerror(self, backend): class TestConcatKDFHMAC: def test_length_limit(self, backend): big_length = hashes.SHA256().digest_size * (2**32 - 1) + 1 + error = OverflowError if sys.maxsize <= 2**31 else ValueError - with pytest.raises(ValueError): + with pytest.raises(error): ConcatKDFHMAC(hashes.SHA256(), big_length, None, None, backend) def test_already_finalized(self, backend):