From 8f0052d7bb702c7feb7c98fa9bb0c28ee15c8eff Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Sun, 2 Nov 2025 15:28:58 -0800 Subject: [PATCH 1/3] migrate concatkdf{hash,hmac} to rust --- .../hazmat/bindings/_rust/openssl/kdf.pyi | 23 ++ .../hazmat/primitives/kdf/concatkdf.py | 121 +-------- src/rust/src/backend/kdf.rs | 245 +++++++++++++++++- tests/hazmat/primitives/test_concatkdf.py | 7 +- 4 files changed, 278 insertions(+), 118 deletions(-) diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi index 3380c7a785ff..00b530bb241e 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi @@ -91,3 +91,26 @@ class X963KDF: ) -> None: ... def derive(self, key_material: Buffer) -> bytes: ... def verify(self, key_material: bytes, expected_key: bytes) -> None: ... + +class ConcatKDFHash: + def __init__( + self, + algorithm: HashAlgorithm, + length: int, + otherinfo: bytes | None, + backend: typing.Any = None, + ) -> None: ... + def derive(self, key_material: Buffer) -> bytes: ... + def verify(self, key_material: bytes, expected_key: bytes) -> None: ... + +class ConcatKDFHMAC: + def __init__( + self, + algorithm: HashAlgorithm, + length: int, + salt: bytes | None, + otherinfo: bytes | None, + backend: typing.Any = None, + ) -> None: ... + def derive(self, key_material: Buffer) -> bytes: ... + def verify(self, key_material: bytes, expected_key: bytes) -> None: ... diff --git a/src/cryptography/hazmat/primitives/kdf/concatkdf.py b/src/cryptography/hazmat/primitives/kdf/concatkdf.py index 1b928415c5c1..398dc5dc55c3 100644 --- a/src/cryptography/hazmat/primitives/kdf/concatkdf.py +++ b/src/cryptography/hazmat/primitives/kdf/concatkdf.py @@ -4,122 +4,13 @@ from __future__ import annotations -import typing -from collections.abc import Callable - -from cryptography import utils -from cryptography.exceptions import AlreadyFinalized, InvalidKey -from cryptography.hazmat.primitives import constant_time, hashes, hmac +from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.primitives.kdf import KeyDerivationFunction +ConcatKDFHash = rust_openssl.kdf.ConcatKDFHash +ConcatKDFHMAC = rust_openssl.kdf.ConcatKDFHMAC -def _int_to_u32be(n: int) -> bytes: - return n.to_bytes(length=4, byteorder="big") - - -def _common_args_checks( - algorithm: hashes.HashAlgorithm, - length: int, - otherinfo: bytes | None, -) -> None: - max_length = algorithm.digest_size * (2**32 - 1) - if length > max_length: - raise ValueError(f"Cannot derive keys larger than {max_length} bits.") - if otherinfo is not None: - utils._check_bytes("otherinfo", otherinfo) - - -def _concatkdf_derive( - key_material: utils.Buffer, - length: int, - auxfn: Callable[[], hashes.HashContext], - otherinfo: bytes, -) -> bytes: - utils._check_byteslike("key_material", key_material) - output = [b""] - outlen = 0 - counter = 1 - - while length > outlen: - h = auxfn() - h.update(_int_to_u32be(counter)) - h.update(key_material) - h.update(otherinfo) - output.append(h.finalize()) - outlen += len(output[-1]) - counter += 1 - - return b"".join(output)[:length] - - -class ConcatKDFHash(KeyDerivationFunction): - def __init__( - self, - algorithm: hashes.HashAlgorithm, - length: int, - otherinfo: bytes | None, - backend: typing.Any = None, - ): - _common_args_checks(algorithm, length, otherinfo) - self._algorithm = algorithm - self._length = length - self._otherinfo: bytes = otherinfo if otherinfo is not None else b"" - - self._used = False - - def _hash(self) -> hashes.Hash: - return hashes.Hash(self._algorithm) - - def derive(self, key_material: utils.Buffer) -> bytes: - if self._used: - raise AlreadyFinalized - self._used = True - return _concatkdf_derive( - key_material, self._length, self._hash, self._otherinfo - ) - - def verify(self, key_material: bytes, expected_key: bytes) -> None: - if not constant_time.bytes_eq(self.derive(key_material), expected_key): - raise InvalidKey - - -class ConcatKDFHMAC(KeyDerivationFunction): - def __init__( - self, - algorithm: hashes.HashAlgorithm, - length: int, - salt: bytes | None, - otherinfo: bytes | None, - backend: typing.Any = None, - ): - _common_args_checks(algorithm, length, otherinfo) - self._algorithm = algorithm - self._length = length - self._otherinfo: bytes = otherinfo if otherinfo is not None else b"" - - if algorithm.block_size is None: - raise TypeError(f"{algorithm.name} is unsupported for ConcatKDF") - - if salt is None: - salt = b"\x00" * algorithm.block_size - else: - utils._check_bytes("salt", salt) - - self._salt = salt - - self._used = False - - def _hmac(self) -> hmac.HMAC: - return hmac.HMAC(self._salt, self._algorithm) - - def derive(self, key_material: utils.Buffer) -> bytes: - if self._used: - raise AlreadyFinalized - self._used = True - return _concatkdf_derive( - key_material, self._length, self._hmac, self._otherinfo - ) +KeyDerivationFunction.register(ConcatKDFHash) +KeyDerivationFunction.register(ConcatKDFHMAC) - def verify(self, key_material: bytes, expected_key: bytes) -> None: - if not constant_time.bytes_eq(self.derive(key_material), expected_key): - raise InvalidKey +__all__ = ["ConcatKDFHMAC", "ConcatKDFHash"] diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index c91afd728af0..03bab6bc44ee 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -942,8 +942,251 @@ impl X963Kdf { } } +// NO-COVERAGE-START +#[pyo3::pyclass( + module = "cryptography.hazmat.primitives.kdf.concatkdf", + name = "ConcatKDFHash" +)] +// NO-COVERAGE-END +struct ConcatKdfHash { + algorithm: pyo3::Py, + length: usize, + otherinfo: pyo3::Py, + used: bool, +} + +#[pyo3::pymethods] +impl ConcatKdfHash { + #[new] + #[pyo3(signature = (algorithm, length, otherinfo, backend=None))] + fn new( + py: pyo3::Python<'_>, + algorithm: pyo3::Py, + length: usize, + otherinfo: Option>, + backend: Option>, + ) -> CryptographyResult { + _ = backend; + + let algorithm_bound = algorithm.bind(py); + let digest_size = algorithm_bound + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + + let max_len = digest_size.saturating_mul(u32::MAX as usize); + if length > max_len { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "Cannot derive keys larger than {max_len} bits." + )), + )); + } + + let otherinfo_bytes = + otherinfo.unwrap_or_else(|| pyo3::types::PyBytes::new(py, b"").into()); + + Ok(ConcatKdfHash { + algorithm, + length, + otherinfo: otherinfo_bytes, + used: false, + }) + } + + fn derive<'p>( + &mut self, + py: pyo3::Python<'p>, + key_material: CffiBuf<'_>, + ) -> CryptographyResult> { + if self.used { + return Err(exceptions::already_finalized_error()); + } + self.used = true; + + let algorithm_bound = self.algorithm.bind(py); + let digest_size = algorithm_bound + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + + Ok(pyo3::types::PyBytes::new_with(py, self.length, |output| { + let mut pos = 0usize; + let mut counter = 1u32; + + while pos < self.length { + let mut hash_obj = hashes::Hash::new(py, algorithm_bound, None)?; + hash_obj.update_bytes(&counter.to_be_bytes())?; + hash_obj.update_bytes(key_material.as_bytes())?; + hash_obj.update_bytes(self.otherinfo.as_bytes(py))?; + let block = hash_obj.finalize(py)?; + let block_bytes = block.as_bytes(); + + let copy_len = (self.length - pos).min(digest_size); + output[pos..pos + copy_len].copy_from_slice(&block_bytes[..copy_len]); + pos += copy_len; + counter += 1; + } + + Ok(()) + })?) + } + + fn verify( + &mut self, + py: pyo3::Python<'_>, + key_material: CffiBuf<'_>, + expected_key: CffiBuf<'_>, + ) -> CryptographyResult<()> { + let actual = self.derive(py, key_material)?; + let actual_bytes = actual.as_bytes(); + let expected_bytes = expected_key.as_bytes(); + + if !constant_time::bytes_eq(actual_bytes, expected_bytes) { + return Err(CryptographyError::from(exceptions::InvalidKey::new_err( + "Keys do not match.", + ))); + } + + Ok(()) + } +} + +// NO-COVERAGE-START +#[pyo3::pyclass( + module = "cryptography.hazmat.primitives.kdf.concatkdf", + name = "ConcatKDFHMAC" +)] +// NO-COVERAGE-END +struct ConcatKdfHmac { + algorithm: pyo3::Py, + length: usize, + salt: pyo3::Py, + otherinfo: pyo3::Py, + used: bool, +} + +#[pyo3::pymethods] +impl ConcatKdfHmac { + #[new] + #[pyo3(signature = (algorithm, length, salt, otherinfo, backend=None))] + fn new( + py: pyo3::Python<'_>, + algorithm: pyo3::Py, + length: usize, + salt: Option>, + otherinfo: Option>, + backend: Option>, + ) -> CryptographyResult { + _ = backend; + + let algorithm_bound = algorithm.bind(py); + let digest_size = algorithm_bound + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + + let max_len = digest_size.saturating_mul(u32::MAX as usize); + if length > max_len { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "Cannot derive keys larger than {max_len} bits." + )), + )); + } + + // Check for block_size (required for HMAC) + let block_size = algorithm_bound.getattr(pyo3::intern!(py, "block_size"))?; + + if block_size.is_none() { + let name = algorithm_bound + .getattr(pyo3::intern!(py, "name"))? + .extract::()?; + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err(format!( + "{name} is unsupported for ConcatKDF" + )), + )); + } + + let block_size_val = block_size.extract::()?; + + // Default salt to zeros of block_size length + let salt_bytes = if let Some(s) = salt { + s + } else { + pyo3::types::PyBytes::new(py, &vec![0u8; block_size_val]).into() + }; + + let otherinfo_bytes = + otherinfo.unwrap_or_else(|| pyo3::types::PyBytes::new(py, b"").into()); + + Ok(ConcatKdfHmac { + algorithm, + length, + salt: salt_bytes, + otherinfo: otherinfo_bytes, + used: false, + }) + } + + fn derive<'p>( + &mut self, + py: pyo3::Python<'p>, + key_material: CffiBuf<'_>, + ) -> CryptographyResult> { + if self.used { + return Err(exceptions::already_finalized_error()); + } + self.used = true; + + let algorithm_bound = self.algorithm.bind(py); + let digest_size = algorithm_bound + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + + Ok(pyo3::types::PyBytes::new_with(py, self.length, |output| { + let mut pos = 0usize; + let mut counter = 1u32; + + while pos < self.length { + let mut hmac = Hmac::new_bytes(py, self.salt.as_bytes(py), algorithm_bound)?; + hmac.update_bytes(&counter.to_be_bytes())?; + hmac.update_bytes(key_material.as_bytes())?; + hmac.update_bytes(self.otherinfo.as_bytes(py))?; + let result = hmac.finalize_bytes()?; + + let copy_len = (self.length - pos).min(digest_size); + output[pos..pos + copy_len].copy_from_slice(&result[..copy_len]); + pos += copy_len; + counter += 1; + } + + Ok(()) + })?) + } + + fn verify( + &mut self, + py: pyo3::Python<'_>, + key_material: CffiBuf<'_>, + expected_key: CffiBuf<'_>, + ) -> CryptographyResult<()> { + let actual = self.derive(py, key_material)?; + let actual_bytes = actual.as_bytes(); + let expected_bytes = expected_key.as_bytes(); + + if !constant_time::bytes_eq(actual_bytes, expected_bytes) { + return Err(CryptographyError::from(exceptions::InvalidKey::new_err( + "Keys do not match.", + ))); + } + + Ok(()) + } +} + #[pyo3::pymodule(gil_used = false)] pub(crate) mod kdf { #[pymodule_export] - use super::{Argon2id, Hkdf, HkdfExpand, Pbkdf2Hmac, Scrypt, X963Kdf}; + use super::{ + Argon2id, ConcatKdfHash, ConcatKdfHmac, Hkdf, HkdfExpand, Pbkdf2Hmac, Scrypt, X963Kdf, + }; } diff --git a/tests/hazmat/primitives/test_concatkdf.py b/tests/hazmat/primitives/test_concatkdf.py index 3a6e994be304..f0c44fbba556 100644 --- a/tests/hazmat/primitives/test_concatkdf.py +++ b/tests/hazmat/primitives/test_concatkdf.py @@ -4,6 +4,7 @@ import binascii +import sys import pytest @@ -18,8 +19,9 @@ class TestConcatKDFHash: def test_length_limit(self, backend): big_length = hashes.SHA256().digest_size * (2**32 - 1) + 1 + error = OverflowError if sys.maxsize <= 2**31 else ValueError - with pytest.raises(ValueError): + with pytest.raises(error): ConcatKDFHash(hashes.SHA256(), big_length, None, backend) def test_already_finalized(self, backend): @@ -127,8 +129,9 @@ def test_unicode_typeerror(self, backend): class TestConcatKDFHMAC: def test_length_limit(self, backend): big_length = hashes.SHA256().digest_size * (2**32 - 1) + 1 + error = OverflowError if sys.maxsize <= 2**31 else ValueError - with pytest.raises(ValueError): + with pytest.raises(error): ConcatKDFHMAC(hashes.SHA256(), big_length, None, None, backend) def test_already_finalized(self, backend): From fff177b54f598f051da4b9da423065cbff3ec02d Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Sun, 2 Nov 2025 17:23:48 -0800 Subject: [PATCH 2/3] use an option --- src/rust/src/backend/kdf.rs | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index 03bab6bc44ee..a7c04278c9b4 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -951,7 +951,7 @@ impl X963Kdf { struct ConcatKdfHash { algorithm: pyo3::Py, length: usize, - otherinfo: pyo3::Py, + otherinfo: Option>, used: bool, } @@ -982,13 +982,10 @@ impl ConcatKdfHash { )); } - let otherinfo_bytes = - otherinfo.unwrap_or_else(|| pyo3::types::PyBytes::new(py, b"").into()); - Ok(ConcatKdfHash { algorithm, length, - otherinfo: otherinfo_bytes, + otherinfo, used: false, }) } @@ -1016,7 +1013,9 @@ impl ConcatKdfHash { let mut hash_obj = hashes::Hash::new(py, algorithm_bound, None)?; hash_obj.update_bytes(&counter.to_be_bytes())?; hash_obj.update_bytes(key_material.as_bytes())?; - hash_obj.update_bytes(self.otherinfo.as_bytes(py))?; + if let Some(ref otherinfo) = self.otherinfo { + hash_obj.update_bytes(otherinfo.as_bytes(py))?; + } let block = hash_obj.finalize(py)?; let block_bytes = block.as_bytes(); @@ -1060,7 +1059,7 @@ struct ConcatKdfHmac { algorithm: pyo3::Py, length: usize, salt: pyo3::Py, - otherinfo: pyo3::Py, + otherinfo: Option>, used: bool, } @@ -1115,14 +1114,11 @@ impl ConcatKdfHmac { pyo3::types::PyBytes::new(py, &vec![0u8; block_size_val]).into() }; - let otherinfo_bytes = - otherinfo.unwrap_or_else(|| pyo3::types::PyBytes::new(py, b"").into()); - Ok(ConcatKdfHmac { algorithm, length, salt: salt_bytes, - otherinfo: otherinfo_bytes, + otherinfo, used: false, }) } @@ -1150,7 +1146,9 @@ impl ConcatKdfHmac { let mut hmac = Hmac::new_bytes(py, self.salt.as_bytes(py), algorithm_bound)?; hmac.update_bytes(&counter.to_be_bytes())?; hmac.update_bytes(key_material.as_bytes())?; - hmac.update_bytes(self.otherinfo.as_bytes(py))?; + if let Some(ref otherinfo) = self.otherinfo { + hmac.update_bytes(otherinfo.as_bytes(py))?; + } let result = hmac.finalize_bytes()?; let copy_len = (self.length - pos).min(digest_size); From 149bb635ba80607d728b7df8fd59527bf440e3e1 Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Sun, 2 Nov 2025 17:43:59 -0800 Subject: [PATCH 3/3] code review --- src/rust/src/backend/kdf.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index a7c04278c9b4..579a773a2df4 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -1091,9 +1091,7 @@ impl ConcatKdfHmac { )); } - // Check for block_size (required for HMAC) let block_size = algorithm_bound.getattr(pyo3::intern!(py, "block_size"))?; - if block_size.is_none() { let name = algorithm_bound .getattr(pyo3::intern!(py, "name"))? @@ -1111,7 +1109,7 @@ impl ConcatKdfHmac { let salt_bytes = if let Some(s) = salt { s } else { - pyo3::types::PyBytes::new(py, &vec![0u8; block_size_val]).into() + pyo3::types::PyBytes::new_with(py, block_size_val, |_| Ok(()))?.into() }; Ok(ConcatKdfHmac {