From 870bf35b419968f8206b6bc8ea46b59265c375f3 Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Mon, 3 Nov 2025 08:04:44 -0800 Subject: [PATCH 1/8] migrate kbkdfhmac to rust --- .../hazmat/bindings/_rust/openssl/kdf.pyi | 20 ++ .../hazmat/primitives/kdf/kbkdf.py | 67 +--- src/rust/src/backend/kdf.rs | 299 +++++++++++++++++- src/rust/src/types.rs | 7 + tests/hazmat/primitives/test_kbkdf.py | 54 +--- 5 files changed, 345 insertions(+), 102 deletions(-) diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi index 29d380ab214f..e8fab2eda288 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi @@ -5,6 +5,7 @@ import typing from cryptography.hazmat.primitives.hashes import HashAlgorithm +from cryptography.hazmat.primitives.kdf.kbkdf import CounterLocation, Mode from cryptography.utils import Buffer class PBKDF2HMAC: @@ -162,3 +163,22 @@ class ConcatKDFHMAC: def derive(self, key_material: Buffer) -> bytes: ... def derive_into(self, key_material: Buffer, buffer: Buffer) -> int: ... def verify(self, key_material: bytes, expected_key: bytes) -> None: ... + +class KBKDFHMAC: + def __init__( + self, + algorithm: HashAlgorithm, + mode: Mode, + length: int, + rlen: int, + llen: int | None, + location: CounterLocation, + label: bytes | None, + context: bytes | None, + fixed: bytes | None, + backend: typing.Any = None, + *, + break_location: int | None = None, + ) -> None: ... + def derive(self, key_material: Buffer) -> bytes: ... + def verify(self, key_material: bytes, expected_key: bytes) -> None: ... diff --git a/src/cryptography/hazmat/primitives/kdf/kbkdf.py b/src/cryptography/hazmat/primitives/kdf/kbkdf.py index 5b4713761679..a00b00008dd6 100644 --- a/src/cryptography/hazmat/primitives/kdf/kbkdf.py +++ b/src/cryptography/hazmat/primitives/kdf/kbkdf.py @@ -14,13 +14,8 @@ UnsupportedAlgorithm, _Reasons, ) -from cryptography.hazmat.primitives import ( - ciphers, - cmac, - constant_time, - hashes, - hmac, -) +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import ciphers, cmac, constant_time from cryptography.hazmat.primitives.kdf import KeyDerivationFunction @@ -178,62 +173,8 @@ def _generate_fixed_input(self) -> bytes: return b"".join([self._label, b"\x00", self._context, l_val]) -class KBKDFHMAC(KeyDerivationFunction): - def __init__( - self, - algorithm: hashes.HashAlgorithm, - mode: Mode, - length: int, - rlen: int, - llen: int | None, - location: CounterLocation, - label: bytes | None, - context: bytes | None, - fixed: bytes | None, - backend: typing.Any = None, - *, - break_location: int | None = None, - ): - if not isinstance(algorithm, hashes.HashAlgorithm): - raise UnsupportedAlgorithm( - "Algorithm supplied is not a supported hash algorithm.", - _Reasons.UNSUPPORTED_HASH, - ) - - from cryptography.hazmat.backends.openssl.backend import ( - backend as ossl, - ) - - if not ossl.hmac_supported(algorithm): - raise UnsupportedAlgorithm( - "Algorithm supplied is not a supported hmac algorithm.", - _Reasons.UNSUPPORTED_HASH, - ) - - self._algorithm = algorithm - - self._deriver = _KBKDFDeriver( - self._prf, - mode, - length, - rlen, - llen, - location, - break_location, - label, - context, - fixed, - ) - - def _prf(self, key_material: bytes) -> hmac.HMAC: - return hmac.HMAC(key_material, self._algorithm) - - def derive(self, key_material: utils.Buffer) -> bytes: - return self._deriver.derive(key_material, self._algorithm.digest_size) - - def verify(self, key_material: bytes, expected_key: bytes) -> None: - if not constant_time.bytes_eq(self.derive(key_material), expected_key): - raise InvalidKey +KBKDFHMAC = rust_openssl.kdf.KBKDFHMAC +KeyDerivationFunction.register(KBKDFHMAC) class KBKDFCMAC(KeyDerivationFunction): diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index 68e4af677a70..2942cc485072 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -1696,11 +1696,306 @@ impl ConcatKdfHmac { } } +// NO-COVERAGE-START +#[pyo3::pyclass( + module = "cryptography.hazmat.primitives.kdf.kbkdf", + name = "KBKDFHMAC" +)] +// NO-COVERAGE-END +struct KbkdfHmac { + algorithm: pyo3::Py, + length: usize, + params: KbkdfValidatedParams, + used: bool, +} + +fn int_to_bytes(value: usize, length: usize) -> Vec { + let mut bytes = Vec::with_capacity(length); + for i in (0..length).rev() { + bytes.push(((value >> (i * 8)) & 0xff) as u8); + } + bytes +} + +struct KbkdfValidatedParams { + rlen: usize, + llen: Option, + location: pyo3::Py, + label: Option>, + context: Option>, + fixed: Option>, + break_location: Option, +} + +#[allow(clippy::too_many_arguments)] +fn validate_kbkdf_parameters( + py: pyo3::Python<'_>, + mode: pyo3::Py, + rlen: usize, + llen: Option, + location: pyo3::Py, + label: Option>, + context: Option>, + fixed: Option>, + break_location: Option, +) -> CryptographyResult { + let mode_bound = mode.bind(py); + let mode_type = crate::types::KBKDF_MODE.get(py)?; + if !mode_bound.is_instance(&mode_type)? { + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err("mode must be of type Mode"), + )); + } + + let location_bound = location.bind(py); + let counter_location_type = crate::types::KBKDF_COUNTER_LOCATION.get(py)?; + if !location_bound.is_instance(&counter_location_type)? { + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err("location must be of type CounterLocation"), + )); + } + + let counter_location_middle_fixed = crate::types::KBKDF_COUNTER_LOCATION + .get(py)? + .getattr(pyo3::intern!(py, "MiddleFixed"))?; + if location_bound.eq(&counter_location_middle_fixed)? && break_location.is_none() { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Please specify a break_location"), + )); + } + + if break_location.is_some() && !location_bound.eq(&counter_location_middle_fixed)? { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "break_location is ignored when location is not CounterLocation.MiddleFixed", + ), + )); + } + + if (label.is_some() || context.is_some()) && fixed.is_some() { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "When supplying fixed data, label and context are ignored.", + ), + )); + } + + if !(1..=4).contains(&rlen) { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("rlen must be between 1 and 4"), + )); + } + + if fixed.is_none() && llen.is_none() { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Please specify an llen"), + )); + } + + if let Some(l) = llen { + if l == 0 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("llen must be non-zero"), + )); + } + } + + Ok(KbkdfValidatedParams { + rlen, + llen, + location, + label, + context, + fixed, + break_location, + }) +} + +impl KbkdfHmac { + fn derive_into_buffer( + &mut self, + py: pyo3::Python<'_>, + key_material: &[u8], + output: &mut [u8], + ) -> CryptographyResult { + if self.used { + return Err(exceptions::already_finalized_error()); + } + self.used = true; + + if output.len() != self.length { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "buffer must be {} bytes", + self.length + )), + )); + } + + let algorithm_bound = self.algorithm.bind(py); + let digest_size = algorithm_bound + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + + let fixed = self.generate_fixed_input(py)?; + + let counter_location = self.params.location.bind(py); + let counter_location_before_fixed = crate::types::KBKDF_COUNTER_LOCATION + .get(py)? + .getattr(pyo3::intern!(py, "BeforeFixed"))?; + let counter_location_after_fixed = crate::types::KBKDF_COUNTER_LOCATION + .get(py)? + .getattr(pyo3::intern!(py, "AfterFixed"))?; + + let (data_before_ctr, data_after_ctr) = if counter_location + .eq(&counter_location_before_fixed)? + { + (&b""[..], &fixed[..]) + } else if counter_location.eq(&counter_location_after_fixed)? { + (&fixed[..], &b""[..]) + } else { + // There are only 3 counter locations so this is MiddleFixed + // We validate break_location is Some when counter_location is MiddleFixed + // in the validate function + let break_loc = self.params.break_location.unwrap(); + if break_loc > fixed.len() { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("break_location offset > len(fixed)"), + )); + } + (&fixed[..break_loc], &fixed[break_loc..]) + }; + + let mut pos = 0usize; + let rounds = self.length.div_ceil(digest_size); + for i in 1..=rounds { + let mut hmac = Hmac::new_bytes(py, key_material, algorithm_bound)?; + + let counter = int_to_bytes(i, self.params.rlen); + hmac.update_bytes(data_before_ctr)?; + hmac.update_bytes(&counter)?; + hmac.update_bytes(data_after_ctr)?; + + let result = hmac.finalize_bytes()?; + + let copy_len = (self.length - pos).min(digest_size); + output[pos..pos + copy_len].copy_from_slice(&result[..copy_len]); + pos += copy_len; + } + + Ok(self.length) + } + + fn generate_fixed_input(&self, py: pyo3::Python<'_>) -> CryptographyResult> { + if let Some(ref fixed_data) = self.params.fixed { + return Ok(fixed_data.as_bytes(py).to_vec()); + } + + // llen will exist if fixed data is not provided + let l_val = int_to_bytes(self.length * 8, self.params.llen.unwrap()); + + let mut result = Vec::new(); + let label: &[u8] = self.params.label.as_ref().map_or(b"", |l| l.as_bytes(py)); + result.extend_from_slice(label); + result.push(0x00); + let context: &[u8] = self.params.context.as_ref().map_or(b"", |l| l.as_bytes(py)); + result.extend_from_slice(context); + result.extend_from_slice(&l_val); + + Ok(result) + } +} + +#[pyo3::pymethods] +impl KbkdfHmac { + #[new] + #[pyo3(signature = (algorithm, mode, length, rlen, llen, location, label, context, fixed, backend=None, *, break_location=None))] + #[allow(clippy::too_many_arguments)] + fn new( + py: pyo3::Python<'_>, + algorithm: pyo3::Py, + mode: pyo3::Py, + length: usize, + rlen: usize, + llen: Option, + location: pyo3::Py, + label: Option>, + context: Option>, + fixed: Option>, + backend: Option>, + break_location: Option, + ) -> CryptographyResult { + _ = backend; + + // Validate common KBKDF parameters + let params = validate_kbkdf_parameters( + py, + mode, + rlen, + llen, + location, + label, + context, + fixed, + break_location, + )?; + + let algorithm_bound = algorithm.bind(py); + let _md = hashes::message_digest_from_algorithm(py, algorithm_bound)?; + let digest_size = algorithm_bound + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + let rounds = length.div_ceil(digest_size); + if rounds as u64 > (1u64 << (params.rlen * 8)) - 1 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("There are too many iterations."), + )); + } + + Ok(KbkdfHmac { + algorithm, + length, + params, + used: false, + }) + } + + fn derive<'p>( + &mut self, + py: pyo3::Python<'p>, + key_material: CffiBuf<'_>, + ) -> CryptographyResult> { + Ok(pyo3::types::PyBytes::new_with(py, self.length, |output| { + self.derive_into_buffer(py, key_material.as_bytes(), output)?; + Ok(()) + })?) + } + + fn verify( + &mut self, + py: pyo3::Python<'_>, + key_material: CffiBuf<'_>, + expected_key: CffiBuf<'_>, + ) -> CryptographyResult<()> { + let actual = self.derive(py, key_material)?; + let actual_bytes = actual.as_bytes(); + let expected_bytes = expected_key.as_bytes(); + + if !constant_time::bytes_eq(actual_bytes, expected_bytes) { + return Err(CryptographyError::from(exceptions::InvalidKey::new_err( + "Keys do not match.", + ))); + } + + Ok(()) + } +} + #[pyo3::pymodule(gil_used = false)] pub(crate) mod kdf { #[pymodule_export] use super::{ - Argon2d, Argon2i, Argon2id, ConcatKdfHash, ConcatKdfHmac, Hkdf, HkdfExpand, Pbkdf2Hmac, - Scrypt, X963Kdf, + Argon2d, Argon2i, Argon2id, ConcatKdfHash, ConcatKdfHmac, Hkdf, HkdfExpand, KbkdfHmac, + Pbkdf2Hmac, Scrypt, X963Kdf, }; } diff --git a/src/rust/src/types.rs b/src/rust/src/types.rs index a1330da6baa2..30c5376f0f2e 100644 --- a/src/rust/src/types.rs +++ b/src/rust/src/types.rs @@ -602,6 +602,13 @@ pub static LEGACY_PROVIDER_LOADED: LazyPyImport = LazyPyImport::new( &["openssl", "_legacy_provider_loaded"], ); +pub static KBKDF_MODE: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.kdf.kbkdf", &["Mode"]); +pub static KBKDF_COUNTER_LOCATION: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.kdf.kbkdf", + &["CounterLocation"], +); + #[cfg(test)] mod tests { use super::LazyPyImport; diff --git a/tests/hazmat/primitives/test_kbkdf.py b/tests/hazmat/primitives/test_kbkdf.py index 900c7664bb4b..415db97cb81e 100644 --- a/tests/hazmat/primitives/test_kbkdf.py +++ b/tests/hazmat/primitives/test_kbkdf.py @@ -4,6 +4,7 @@ import re +import sys import pytest @@ -113,21 +114,20 @@ def test_already_finalized(self, backend): kdf.verify(b"material", key) def test_key_length(self, backend): - kdf = KBKDFHMAC( - hashes.SHA1(), - Mode.CounterMode, - 85899345920, - 4, - 4, - CounterLocation.BeforeFixed, - b"label", - b"context", - None, - backend=backend, - ) - - with pytest.raises(ValueError): - kdf.derive(b"material") + error = OverflowError if sys.maxsize <= 2**31 else ValueError + with pytest.raises(error): + KBKDFHMAC( + hashes.SHA1(), + Mode.CounterMode, + 85899345920, + 4, + 4, + CounterLocation.BeforeFixed, + b"label", + b"context", + None, + backend=backend, + ) def test_rlen(self, backend): with pytest.raises(ValueError): @@ -302,27 +302,7 @@ def test_keyword_only_break_location(self, backend): ) def test_invalid_break_location(self, backend): - with pytest.raises( - TypeError, match=re.escape("break_location must be an integer") - ): - KBKDFHMAC( - hashes.SHA256(), - Mode.CounterMode, - 32, - 4, - 4, - CounterLocation.MiddleFixed, - b"label", - b"context", - None, - backend=backend, - break_location="0", # type: ignore[arg-type] - ) - - with pytest.raises( - ValueError, - match=re.escape("break_location must be a positive integer"), - ): + with pytest.raises(OverflowError): KBKDFHMAC( hashes.SHA256(), Mode.CounterMode, @@ -402,7 +382,7 @@ def test_ignored_break_location_after(self, backend): def test_unsupported_hash(self, backend): with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH): KBKDFHMAC( - object(), # type: ignore[arg-type] + DummyHashAlgorithm(), Mode.CounterMode, 32, 4, From ec2ca6c0bf92dd8298ce11111e7161101367e059 Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Tue, 4 Nov 2025 07:29:50 -0800 Subject: [PATCH 2/8] implement derive_into for kbkdfhmac --- CHANGELOG.rst | 1 + .../primitives/key-derivation-functions.rst | 32 ++++++++- .../hazmat/bindings/_rust/openssl/kdf.pyi | 1 + src/rust/src/backend/kdf.rs | 9 +++ tests/hazmat/primitives/test_kbkdf.py | 68 +++++++++++++++++++ 5 files changed, 109 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 580a777bfdd0..96a60c2106da 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -55,6 +55,7 @@ Changelog :class:`~cryptography.hazmat.primitives.kdf.concatkdf.ConcatKDFHMAC`, :class:`~cryptography.hazmat.primitives.kdf.argon2.Argon2id`, :class:`~cryptography.hazmat.primitives.kdf.pbkdf2.PBKDF2HMAC`, + :class:`~cryptography.hazmat.primitives.kdf.kbkdf.KBKDFHMAC`, :class:`~cryptography.hazmat.primitives.kdf.scrypt.Scrypt`, and :class:`~cryptography.hazmat.primitives.kdf.x963kdf.X963KDF` to allow deriving keys directly into pre-allocated buffers. diff --git a/docs/hazmat/primitives/key-derivation-functions.rst b/docs/hazmat/primitives/key-derivation-functions.rst index f8228c6d9586..4115624d6022 100644 --- a/docs/hazmat/primitives/key-derivation-functions.rst +++ b/docs/hazmat/primitives/key-derivation-functions.rst @@ -1090,13 +1090,40 @@ KBKDF :raises TypeError: This exception is raised if ``key_material`` is not ``bytes``. :raises cryptography.exceptions.AlreadyFinalized: This is raised when - :meth:`derive` or + :meth:`derive`, + :meth:`derive_into`, or :meth:`verify` is called more than once. Derives a new key from the input key material. + .. method:: derive_into(key_material, buffer) + + .. versionadded:: 47.0.0 + + :param key_material: The input key material. + :type key_material: :term:`bytes-like` + :param buffer: A writable buffer to write the derived key into. The + buffer must be equal to the length supplied in the + constructor. + :type buffer: :term:`bytes-like` + :return int: the number of bytes written to the buffer. + :raises ValueError: This exception is raised if the buffer length does + not match the specified ``length``. + :raises TypeError: This exception is raised if ``key_material`` or + ``buffer`` is not ``bytes``. + :raises cryptography.exceptions.AlreadyFinalized: This is raised when + :meth:`derive`, + :meth:`derive_into`, or + :meth:`verify` is + called more than + once. + + Derives a new key from the input key material and writes it into + the provided buffer. This is useful when you want to avoid allocating + new memory for the derived key. + .. method:: verify(key_material, expected_key) :param bytes key_material: The input key material. This is the same as @@ -1108,7 +1135,8 @@ KBKDF derived key does not match the expected key. :raises cryptography.exceptions.AlreadyFinalized: This is raised when - :meth:`derive` or + :meth:`derive`, + :meth:`derive_into`, or :meth:`verify` is called more than once. diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi index e8fab2eda288..d807755559ab 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi @@ -181,4 +181,5 @@ class KBKDFHMAC: break_location: int | None = None, ) -> None: ... def derive(self, key_material: Buffer) -> bytes: ... + def derive_into(self, key_material: Buffer, buffer: Buffer) -> int: ... def verify(self, key_material: bytes, expected_key: bytes) -> None: ... diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index 2942cc485072..67a54c6f2371 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -1971,6 +1971,15 @@ impl KbkdfHmac { })?) } + fn derive_into( + &mut self, + py: pyo3::Python<'_>, + key_material: CffiBuf<'_>, + mut buf: CffiMutBuf<'_>, + ) -> CryptographyResult { + self.derive_into_buffer(py, key_material.as_bytes(), buf.as_mut_bytes()) + } + fn verify( &mut self, py: pyo3::Python<'_>, diff --git a/tests/hazmat/primitives/test_kbkdf.py b/tests/hazmat/primitives/test_kbkdf.py index 415db97cb81e..f3af568da065 100644 --- a/tests/hazmat/primitives/test_kbkdf.py +++ b/tests/hazmat/primitives/test_kbkdf.py @@ -113,6 +113,74 @@ def test_already_finalized(self, backend): with pytest.raises(AlreadyFinalized): kdf.verify(b"material", key) + def test_derive_into(self, backend): + kdf = KBKDFHMAC( + hashes.SHA256(), + Mode.CounterMode, + 32, + 4, + 4, + CounterLocation.BeforeFixed, + b"label", + b"context", + None, + backend=backend, + ) + buf = bytearray(32) + n = kdf.derive_into(b"material", buf) + assert n == 32 + # Verify the output matches what derive would produce + kdf2 = KBKDFHMAC( + hashes.SHA256(), + Mode.CounterMode, + 32, + 4, + 4, + CounterLocation.BeforeFixed, + b"label", + b"context", + None, + backend=backend, + ) + expected = kdf2.derive(b"material") + assert buf == expected + + @pytest.mark.parametrize(("buflen", "outlen"), [(31, 32), (33, 32)]) + def test_derive_into_buffer_incorrect_size(self, buflen, outlen, backend): + kdf = KBKDFHMAC( + hashes.SHA256(), + Mode.CounterMode, + outlen, + 4, + 4, + CounterLocation.BeforeFixed, + b"label", + b"context", + None, + backend=backend, + ) + buf = bytearray(buflen) + with pytest.raises(ValueError, match="buffer must be"): + kdf.derive_into(b"material", buf) + + def test_derive_into_already_finalized(self, backend): + kdf = KBKDFHMAC( + hashes.SHA256(), + Mode.CounterMode, + 32, + 4, + 4, + CounterLocation.BeforeFixed, + b"label", + b"context", + None, + backend=backend, + ) + buf = bytearray(32) + kdf.derive_into(b"material", buf) + with pytest.raises(AlreadyFinalized): + kdf.derive_into(b"material2", buf) + def test_key_length(self, backend): error = OverflowError if sys.maxsize <= 2**31 else ValueError with pytest.raises(error): From 65a4e0f8bcc989ff4d14628b93f921d68a56988b Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Tue, 4 Nov 2025 10:49:01 -0800 Subject: [PATCH 3/8] code review --- src/rust/src/backend/kdf.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index 67a54c6f2371..f46c787125d2 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -1705,7 +1705,7 @@ impl ConcatKdfHmac { struct KbkdfHmac { algorithm: pyo3::Py, length: usize, - params: KbkdfValidatedParams, + params: KbkdfParams, used: bool, } @@ -1717,7 +1717,7 @@ fn int_to_bytes(value: usize, length: usize) -> Vec { bytes } -struct KbkdfValidatedParams { +struct KbkdfParams { rlen: usize, llen: Option, location: pyo3::Py, @@ -1738,7 +1738,7 @@ fn validate_kbkdf_parameters( context: Option>, fixed: Option>, break_location: Option, -) -> CryptographyResult { +) -> CryptographyResult { let mode_bound = mode.bind(py); let mode_type = crate::types::KBKDF_MODE.get(py)?; if !mode_bound.is_instance(&mode_type)? { @@ -1792,15 +1792,13 @@ fn validate_kbkdf_parameters( )); } - if let Some(l) = llen { - if l == 0 { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("llen must be non-zero"), - )); - } + if llen == Some(0) { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("llen must be non-zero"), + )); } - Ok(KbkdfValidatedParams { + Ok(KbkdfParams { rlen, llen, location, From 0e3599ddae7a5b34211eba8a37b532e8a149aedb Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Tue, 4 Nov 2025 11:03:06 -0800 Subject: [PATCH 4/8] internal enum --- src/rust/src/backend/kdf.rs | 74 ++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 33 deletions(-) diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index f46c787125d2..7ff3095d8a8b 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -1717,10 +1717,18 @@ fn int_to_bytes(value: usize, length: usize) -> Vec { bytes } +#[allow(clippy::enum_variant_names)] +#[derive(PartialEq)] +enum CounterLocation { + BeforeFixed, + AfterFixed, + MiddleFixed, +} + struct KbkdfParams { rlen: usize, llen: Option, - location: pyo3::Py, + location: CounterLocation, label: Option>, context: Option>, fixed: Option>, @@ -1748,23 +1756,32 @@ fn validate_kbkdf_parameters( } let location_bound = location.bind(py); - let counter_location_type = crate::types::KBKDF_COUNTER_LOCATION.get(py)?; - if !location_bound.is_instance(&counter_location_type)? { + let counter_location = crate::types::KBKDF_COUNTER_LOCATION.get(py)?; + if !location_bound.is_instance(&counter_location)? { return Err(CryptographyError::from( pyo3::exceptions::PyTypeError::new_err("location must be of type CounterLocation"), )); } - let counter_location_middle_fixed = crate::types::KBKDF_COUNTER_LOCATION - .get(py)? - .getattr(pyo3::intern!(py, "MiddleFixed"))?; - if location_bound.eq(&counter_location_middle_fixed)? && break_location.is_none() { + let counter_location_before_fixed = + counter_location.getattr(pyo3::intern!(py, "BeforeFixed"))?; + let counter_location_after_fixed = counter_location.getattr(pyo3::intern!(py, "AfterFixed"))?; + let rust_location = if location_bound.eq(&counter_location_before_fixed)? { + CounterLocation::BeforeFixed + } else if location_bound.eq(&counter_location_after_fixed)? { + CounterLocation::AfterFixed + } else { + // There are only 3 options so this is MiddleFixed + CounterLocation::MiddleFixed + }; + + if rust_location == CounterLocation::MiddleFixed && break_location.is_none() { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("Please specify a break_location"), )); } - if break_location.is_some() && !location_bound.eq(&counter_location_middle_fixed)? { + if break_location.is_some() && rust_location != CounterLocation::MiddleFixed { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err( "break_location is ignored when location is not CounterLocation.MiddleFixed", @@ -1801,7 +1818,7 @@ fn validate_kbkdf_parameters( Ok(KbkdfParams { rlen, llen, - location, + location: rust_location, label, context, fixed, @@ -1837,31 +1854,22 @@ impl KbkdfHmac { let fixed = self.generate_fixed_input(py)?; - let counter_location = self.params.location.bind(py); - let counter_location_before_fixed = crate::types::KBKDF_COUNTER_LOCATION - .get(py)? - .getattr(pyo3::intern!(py, "BeforeFixed"))?; - let counter_location_after_fixed = crate::types::KBKDF_COUNTER_LOCATION - .get(py)? - .getattr(pyo3::intern!(py, "AfterFixed"))?; - - let (data_before_ctr, data_after_ctr) = if counter_location - .eq(&counter_location_before_fixed)? - { - (&b""[..], &fixed[..]) - } else if counter_location.eq(&counter_location_after_fixed)? { - (&fixed[..], &b""[..]) - } else { - // There are only 3 counter locations so this is MiddleFixed - // We validate break_location is Some when counter_location is MiddleFixed - // in the validate function - let break_loc = self.params.break_location.unwrap(); - if break_loc > fixed.len() { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("break_location offset > len(fixed)"), - )); + let (data_before_ctr, data_after_ctr) = match &self.params.location { + CounterLocation::BeforeFixed => (&b""[..], &fixed[..]), + CounterLocation::AfterFixed => (&fixed[..], &b""[..]), + CounterLocation::MiddleFixed => { + // We validate break_location is Some when counter_location is MiddleFixed + // in the validate function + let break_loc = self.params.break_location.unwrap(); + if break_loc > fixed.len() { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "break_location offset > len(fixed)", + ), + )); + } + (&fixed[..break_loc], &fixed[break_loc..]) } - (&fixed[..break_loc], &fixed[break_loc..]) }; let mut pos = 0usize; From cb183f040fcdbe0f6a98a7a80d4b4d8eceda2eb0 Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Sat, 8 Nov 2025 12:46:53 -0800 Subject: [PATCH 5/8] use python int_to_bytes to handle arbitrary llen sizes properly --- src/rust/src/backend/kdf.rs | 21 +++++++++------------ src/rust/src/types.rs | 2 ++ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index 7ff3095d8a8b..e01194650168 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -14,6 +14,7 @@ use crate::backend::hmac::Hmac; use crate::buf::{CffiBuf, CffiMutBuf}; use crate::error::{CryptographyError, CryptographyResult}; use crate::exceptions; +use crate::types; // NO-COVERAGE-START #[pyo3::pyclass( @@ -1709,14 +1710,6 @@ struct KbkdfHmac { used: bool, } -fn int_to_bytes(value: usize, length: usize) -> Vec { - let mut bytes = Vec::with_capacity(length); - for i in (0..length).rev() { - bytes.push(((value >> (i * 8)) & 0xff) as u8); - } - bytes -} - #[allow(clippy::enum_variant_names)] #[derive(PartialEq)] enum CounterLocation { @@ -1877,9 +1870,10 @@ impl KbkdfHmac { for i in 1..=rounds { let mut hmac = Hmac::new_bytes(py, key_material, algorithm_bound)?; - let counter = int_to_bytes(i, self.params.rlen); + let py_counter = types::INT_TO_BYTES.get(py)?.call1((i, self.params.rlen))?; + let counter = py_counter.extract::<&[u8]>()?; hmac.update_bytes(data_before_ctr)?; - hmac.update_bytes(&counter)?; + hmac.update_bytes(counter)?; hmac.update_bytes(data_after_ctr)?; let result = hmac.finalize_bytes()?; @@ -1898,7 +1892,10 @@ impl KbkdfHmac { } // llen will exist if fixed data is not provided - let l_val = int_to_bytes(self.length * 8, self.params.llen.unwrap()); + let py_l_val = types::INT_TO_BYTES + .get(py)? + .call1((self.length * 8, self.params.llen.unwrap()))?; + let l_val = py_l_val.extract::<&[u8]>()?; let mut result = Vec::new(); let label: &[u8] = self.params.label.as_ref().map_or(b"", |l| l.as_bytes(py)); @@ -1906,7 +1903,7 @@ impl KbkdfHmac { result.push(0x00); let context: &[u8] = self.params.context.as_ref().map_or(b"", |l| l.as_bytes(py)); result.extend_from_slice(context); - result.extend_from_slice(&l_val); + result.extend_from_slice(l_val); Ok(result) } diff --git a/src/rust/src/types.rs b/src/rust/src/types.rs index 30c5376f0f2e..d522bf1eeeeb 100644 --- a/src/rust/src/types.rs +++ b/src/rust/src/types.rs @@ -47,6 +47,8 @@ pub static DEPRECATED_IN_42: LazyPyImport = pub static DEPRECATED_IN_43: LazyPyImport = LazyPyImport::new("cryptography.utils", &["DeprecatedIn43"]); +pub static INT_TO_BYTES: LazyPyImport = LazyPyImport::new("cryptography.utils", &["int_to_bytes"]); + pub static ENCODING: LazyPyImport = LazyPyImport::new( "cryptography.hazmat.primitives.serialization", &["Encoding"], From 4427698f6c3a8e57d40489c31043baf8f7d3b970 Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Sat, 8 Nov 2025 14:31:09 -0800 Subject: [PATCH 6/8] alternate solution --- src/rust/src/asn1.rs | 34 +++++++++++++++++++++++++--------- src/rust/src/backend/kdf.rs | 18 +++++++++--------- src/rust/src/types.rs | 2 -- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/src/rust/src/asn1.rs b/src/rust/src/asn1.rs index cb71dccbe62a..6caeee3de27f 100644 --- a/src/rust/src/asn1.rs +++ b/src/rust/src/asn1.rs @@ -71,22 +71,38 @@ pub(crate) fn py_uint_to_big_endian_bytes<'p>( py: pyo3::Python<'p>, v: pyo3::Bound<'p, pyo3::types::PyInt>, ) -> pyo3::PyResult { - if v.lt(0)? { - return Err(pyo3::exceptions::PyValueError::new_err( - "Negative integers are not supported", - )); - } - + reject_negative_integer(&v)?; // Round the length up so that we prefix an extra \x00. This ensures that // integers that'd have the high bit set in their first octet are not // encoded as negative in DER. - let n = v + let length = v .call_method0(pyo3::intern!(py, "bit_length"))? .extract::()? / 8 + 1; - Ok(v.call_method1(pyo3::intern!(py, "to_bytes"), (n, "big"))? - .extract()?) + py_uint_to_be_bytes_with_length(py, v, length) +} + +fn reject_negative_integer(v: &pyo3::Bound<'_, pyo3::types::PyInt>) -> pyo3::PyResult<()> { + if v.lt(0)? { + Err(pyo3::exceptions::PyValueError::new_err( + "Negative integers are not supported", + )) + } else { + Ok(()) + } +} + +pub(crate) fn py_uint_to_be_bytes_with_length<'p>( + py: pyo3::Python<'p>, + v: pyo3::Bound<'p, pyo3::types::PyInt>, + length: usize, +) -> pyo3::PyResult { + reject_negative_integer(&v)?; + Ok( + v.call_method1(pyo3::intern!(py, "to_bytes"), (length, "big"))? + .extract()?, + ) } pub(crate) fn encode_der_data<'p>( diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index e01194650168..a96bb2c71cf9 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -9,12 +9,12 @@ use base64::engine::Engine; use cryptography_crypto::constant_time; use pyo3::types::{PyAnyMethods, PyBytesMethods}; +use crate::asn1::py_uint_to_be_bytes_with_length; use crate::backend::hashes; use crate::backend::hmac::Hmac; use crate::buf::{CffiBuf, CffiMutBuf}; use crate::error::{CryptographyError, CryptographyResult}; use crate::exceptions; -use crate::types; // NO-COVERAGE-START #[pyo3::pyclass( @@ -1870,10 +1870,10 @@ impl KbkdfHmac { for i in 1..=rounds { let mut hmac = Hmac::new_bytes(py, key_material, algorithm_bound)?; - let py_counter = types::INT_TO_BYTES.get(py)?.call1((i, self.params.rlen))?; - let counter = py_counter.extract::<&[u8]>()?; + let py_i = pyo3::types::PyInt::new(py, i); + let counter = py_uint_to_be_bytes_with_length(py, py_i, self.params.rlen)?; hmac.update_bytes(data_before_ctr)?; - hmac.update_bytes(counter)?; + hmac.update_bytes(counter.as_ref())?; hmac.update_bytes(data_after_ctr)?; let result = hmac.finalize_bytes()?; @@ -1892,10 +1892,10 @@ impl KbkdfHmac { } // llen will exist if fixed data is not provided - let py_l_val = types::INT_TO_BYTES - .get(py)? - .call1((self.length * 8, self.params.llen.unwrap()))?; - let l_val = py_l_val.extract::<&[u8]>()?; + let py_bitlength = pyo3::types::PyInt::new(py, self.length) + .mul(8)? + .extract::>()?; + let l_val = py_uint_to_be_bytes_with_length(py, py_bitlength, self.params.llen.unwrap())?; let mut result = Vec::new(); let label: &[u8] = self.params.label.as_ref().map_or(b"", |l| l.as_bytes(py)); @@ -1903,7 +1903,7 @@ impl KbkdfHmac { result.push(0x00); let context: &[u8] = self.params.context.as_ref().map_or(b"", |l| l.as_bytes(py)); result.extend_from_slice(context); - result.extend_from_slice(l_val); + result.extend_from_slice(l_val.as_ref()); Ok(result) } diff --git a/src/rust/src/types.rs b/src/rust/src/types.rs index d522bf1eeeeb..30c5376f0f2e 100644 --- a/src/rust/src/types.rs +++ b/src/rust/src/types.rs @@ -47,8 +47,6 @@ pub static DEPRECATED_IN_42: LazyPyImport = pub static DEPRECATED_IN_43: LazyPyImport = LazyPyImport::new("cryptography.utils", &["DeprecatedIn43"]); -pub static INT_TO_BYTES: LazyPyImport = LazyPyImport::new("cryptography.utils", &["int_to_bytes"]); - pub static ENCODING: LazyPyImport = LazyPyImport::new( "cryptography.hazmat.primitives.serialization", &["Encoding"], From ac46d2ae7e9ea85b303e3da7dc2b963327b82dcc Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Sat, 8 Nov 2025 20:33:05 -0800 Subject: [PATCH 7/8] feedback --- src/rust/src/asn1.rs | 17 +++++----------- src/rust/src/backend/kdf.rs | 39 +++++++++++++++++++------------------ 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/src/rust/src/asn1.rs b/src/rust/src/asn1.rs index 6caeee3de27f..f99dec1864d1 100644 --- a/src/rust/src/asn1.rs +++ b/src/rust/src/asn1.rs @@ -71,7 +71,6 @@ pub(crate) fn py_uint_to_big_endian_bytes<'p>( py: pyo3::Python<'p>, v: pyo3::Bound<'p, pyo3::types::PyInt>, ) -> pyo3::PyResult { - reject_negative_integer(&v)?; // Round the length up so that we prefix an extra \x00. This ensures that // integers that'd have the high bit set in their first octet are not // encoded as negative in DER. @@ -83,22 +82,16 @@ pub(crate) fn py_uint_to_big_endian_bytes<'p>( py_uint_to_be_bytes_with_length(py, v, length) } -fn reject_negative_integer(v: &pyo3::Bound<'_, pyo3::types::PyInt>) -> pyo3::PyResult<()> { - if v.lt(0)? { - Err(pyo3::exceptions::PyValueError::new_err( - "Negative integers are not supported", - )) - } else { - Ok(()) - } -} - pub(crate) fn py_uint_to_be_bytes_with_length<'p>( py: pyo3::Python<'p>, v: pyo3::Bound<'p, pyo3::types::PyInt>, length: usize, ) -> pyo3::PyResult { - reject_negative_integer(&v)?; + if v.lt(0)? { + return Err(pyo3::exceptions::PyValueError::new_err( + "Negative integers are not supported", + )); + } Ok( v.call_method1(pyo3::intern!(py, "to_bytes"), (length, "big"))? .extract()?, diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index a96bb2c71cf9..55663cd2b97b 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -1715,7 +1715,7 @@ struct KbkdfHmac { enum CounterLocation { BeforeFixed, AfterFixed, - MiddleFixed, + MiddleFixed(usize), } struct KbkdfParams { @@ -1725,7 +1725,6 @@ struct KbkdfParams { label: Option>, context: Option>, fixed: Option>, - break_location: Option, } #[allow(clippy::too_many_arguments)] @@ -1765,16 +1764,15 @@ fn validate_kbkdf_parameters( CounterLocation::AfterFixed } else { // There are only 3 options so this is MiddleFixed - CounterLocation::MiddleFixed + if break_location.is_none() { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Please specify a break_location"), + )); + } + CounterLocation::MiddleFixed(break_location.unwrap()) }; - if rust_location == CounterLocation::MiddleFixed && break_location.is_none() { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("Please specify a break_location"), - )); - } - - if break_location.is_some() && rust_location != CounterLocation::MiddleFixed { + if break_location.is_some() && !matches!(rust_location, CounterLocation::MiddleFixed(_)) { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err( "break_location is ignored when location is not CounterLocation.MiddleFixed", @@ -1815,7 +1813,6 @@ fn validate_kbkdf_parameters( label, context, fixed, - break_location, }) } @@ -1847,21 +1844,20 @@ impl KbkdfHmac { let fixed = self.generate_fixed_input(py)?; - let (data_before_ctr, data_after_ctr) = match &self.params.location { + let (data_before_ctr, data_after_ctr) = match self.params.location { CounterLocation::BeforeFixed => (&b""[..], &fixed[..]), CounterLocation::AfterFixed => (&fixed[..], &b""[..]), - CounterLocation::MiddleFixed => { + CounterLocation::MiddleFixed(break_location) => { // We validate break_location is Some when counter_location is MiddleFixed // in the validate function - let break_loc = self.params.break_location.unwrap(); - if break_loc > fixed.len() { + if break_location > fixed.len() { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err( "break_location offset > len(fixed)", ), )); } - (&fixed[..break_loc], &fixed[break_loc..]) + (&fixed[..break_location], &fixed[break_location..]) } }; @@ -1892,9 +1888,14 @@ impl KbkdfHmac { } // llen will exist if fixed data is not provided - let py_bitlength = pyo3::types::PyInt::new(py, self.length) - .mul(8)? - .extract::>()?; + let py_bitlength = pyo3::types::PyInt::new( + py, + self.length + .checked_mul(8) + .ok_or(pyo3::exceptions::PyOverflowError::new_err( + "Length too large, would cause overflow in bit length calculation", + ))?, + ); let l_val = py_uint_to_be_bytes_with_length(py, py_bitlength, self.params.llen.unwrap())?; let mut result = Vec::new(); From be87ae283c76097eb6220ed437f6d027aa1cae4b Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Sat, 8 Nov 2025 20:44:51 -0800 Subject: [PATCH 8/8] ellipsis --- src/rust/src/backend/kdf.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index 55663cd2b97b..3101543cc430 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -1764,12 +1764,14 @@ fn validate_kbkdf_parameters( CounterLocation::AfterFixed } else { // There are only 3 options so this is MiddleFixed - if break_location.is_none() { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("Please specify a break_location"), - )); + match break_location { + Some(break_location) => CounterLocation::MiddleFixed(break_location), + None => { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Please specify a break_location"), + )) + } } - CounterLocation::MiddleFixed(break_location.unwrap()) }; if break_location.is_some() && !matches!(rust_location, CounterLocation::MiddleFixed(_)) { @@ -1848,8 +1850,6 @@ impl KbkdfHmac { CounterLocation::BeforeFixed => (&b""[..], &fixed[..]), CounterLocation::AfterFixed => (&fixed[..], &b""[..]), CounterLocation::MiddleFixed(break_location) => { - // We validate break_location is Some when counter_location is MiddleFixed - // in the validate function if break_location > fixed.len() { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err(