Skip to content

Commit 6ac1384

Browse files
authored
migrate X963KDF to rust (#13789)
* migrate X963KDF to rust * fix test * don't skip test on 32-bit
1 parent d8975e2 commit 6ac1384

File tree

4 files changed

+127
-54
lines changed

4 files changed

+127
-54
lines changed

src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,14 @@ class HKDFExpand:
8080
def derive(self, key_material: Buffer) -> bytes: ...
8181
def derive_into(self, key_material: Buffer, buffer: Buffer) -> int: ...
8282
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...
83+
84+
class X963KDF:
85+
def __init__(
86+
self,
87+
algorithm: HashAlgorithm,
88+
length: int,
89+
sharedinfo: bytes | None,
90+
backend: typing.Any = None,
91+
) -> None: ...
92+
def derive(self, key_material: Buffer) -> bytes: ...
93+
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...

src/cryptography/hazmat/primitives/kdf/x963kdf.py

Lines changed: 4 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,58 +4,10 @@
44

55
from __future__ import annotations
66

7-
import typing
8-
9-
from cryptography import utils
10-
from cryptography.exceptions import AlreadyFinalized, InvalidKey
11-
from cryptography.hazmat.primitives import constant_time, hashes
7+
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
128
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
139

10+
X963KDF = rust_openssl.kdf.X963KDF
11+
KeyDerivationFunction.register(X963KDF)
1412

15-
def _int_to_u32be(n: int) -> bytes:
16-
return n.to_bytes(length=4, byteorder="big")
17-
18-
19-
class X963KDF(KeyDerivationFunction):
20-
def __init__(
21-
self,
22-
algorithm: hashes.HashAlgorithm,
23-
length: int,
24-
sharedinfo: bytes | None,
25-
backend: typing.Any = None,
26-
):
27-
max_len = algorithm.digest_size * (2**32 - 1)
28-
if length > max_len:
29-
raise ValueError(f"Cannot derive keys larger than {max_len} bits.")
30-
if sharedinfo is not None:
31-
utils._check_bytes("sharedinfo", sharedinfo)
32-
33-
self._algorithm = algorithm
34-
self._length = length
35-
self._sharedinfo = sharedinfo
36-
self._used = False
37-
38-
def derive(self, key_material: utils.Buffer) -> bytes:
39-
if self._used:
40-
raise AlreadyFinalized
41-
self._used = True
42-
utils._check_byteslike("key_material", key_material)
43-
output = [b""]
44-
outlen = 0
45-
counter = 1
46-
47-
while self._length > outlen:
48-
h = hashes.Hash(self._algorithm)
49-
h.update(key_material)
50-
h.update(_int_to_u32be(counter))
51-
if self._sharedinfo is not None:
52-
h.update(self._sharedinfo)
53-
output.append(h.finalize())
54-
outlen += len(output[-1])
55-
counter += 1
56-
57-
return b"".join(output)[: self._length]
58-
59-
def verify(self, key_material: bytes, expected_key: bytes) -> None:
60-
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
61-
raise InvalidKey
13+
__all__ = ["X963KDF"]

src/rust/src/backend/kdf.rs

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,8 +834,116 @@ impl HkdfExpand {
834834
}
835835
}
836836

837+
// NO-COVERAGE-START
838+
#[pyo3::pyclass(
839+
module = "cryptography.hazmat.primitives.kdf.x963kdf",
840+
name = "X963KDF"
841+
)]
842+
// NO-COVERAGE-END
843+
struct X963Kdf {
844+
algorithm: pyo3::Py<pyo3::PyAny>,
845+
length: usize,
846+
sharedinfo: Option<pyo3::Py<pyo3::types::PyBytes>>,
847+
used: bool,
848+
}
849+
850+
#[pyo3::pymethods]
851+
impl X963Kdf {
852+
#[new]
853+
#[pyo3(signature = (algorithm, length, sharedinfo, backend=None))]
854+
fn new(
855+
py: pyo3::Python<'_>,
856+
algorithm: pyo3::Py<pyo3::PyAny>,
857+
length: usize,
858+
sharedinfo: Option<pyo3::Py<pyo3::types::PyBytes>>,
859+
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
860+
) -> CryptographyResult<Self> {
861+
_ = backend;
862+
863+
let digest_size = algorithm
864+
.bind(py)
865+
.getattr(pyo3::intern!(py, "digest_size"))?
866+
.extract::<usize>()?;
867+
868+
let max_len = digest_size.saturating_mul(u32::MAX as usize);
869+
870+
if length > max_len {
871+
return Err(CryptographyError::from(
872+
pyo3::exceptions::PyValueError::new_err(format!(
873+
"Cannot derive keys larger than {max_len} bits."
874+
)),
875+
));
876+
}
877+
878+
Ok(X963Kdf {
879+
algorithm,
880+
length,
881+
sharedinfo,
882+
used: false,
883+
})
884+
}
885+
886+
fn derive<'p>(
887+
&mut self,
888+
py: pyo3::Python<'p>,
889+
key_material: CffiBuf<'_>,
890+
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
891+
if self.used {
892+
return Err(exceptions::already_finalized_error());
893+
}
894+
self.used = true;
895+
896+
let algorithm_bound = self.algorithm.bind(py);
897+
let digest_size = algorithm_bound
898+
.getattr(pyo3::intern!(py, "digest_size"))?
899+
.extract::<usize>()?;
900+
901+
Ok(pyo3::types::PyBytes::new_with(py, self.length, |output| {
902+
let mut pos = 0usize;
903+
let mut counter = 1u32;
904+
905+
while pos < self.length {
906+
let mut hash_obj = hashes::Hash::new(py, algorithm_bound, None)?;
907+
hash_obj.update_bytes(key_material.as_bytes())?;
908+
hash_obj.update_bytes(&counter.to_be_bytes())?;
909+
if let Some(ref sharedinfo) = self.sharedinfo {
910+
hash_obj.update_bytes(sharedinfo.as_bytes(py))?;
911+
}
912+
let block = hash_obj.finalize(py)?;
913+
let block_bytes = block.as_bytes();
914+
915+
let copy_len = (self.length - pos).min(digest_size);
916+
output[pos..pos + copy_len].copy_from_slice(&block_bytes[..copy_len]);
917+
pos += copy_len;
918+
counter += 1;
919+
}
920+
921+
Ok(())
922+
})?)
923+
}
924+
925+
fn verify(
926+
&mut self,
927+
py: pyo3::Python<'_>,
928+
key_material: CffiBuf<'_>,
929+
expected_key: CffiBuf<'_>,
930+
) -> CryptographyResult<()> {
931+
let actual = self.derive(py, key_material)?;
932+
let actual_bytes = actual.as_bytes();
933+
let expected_bytes = expected_key.as_bytes();
934+
935+
if !constant_time::bytes_eq(actual_bytes, expected_bytes) {
936+
return Err(CryptographyError::from(exceptions::InvalidKey::new_err(
937+
"Keys do not match.",
938+
)));
939+
}
940+
941+
Ok(())
942+
}
943+
}
944+
837945
#[pyo3::pymodule(gil_used = false)]
838946
pub(crate) mod kdf {
839947
#[pymodule_export]
840-
use super::{Argon2id, Hkdf, HkdfExpand, Pbkdf2Hmac, Scrypt};
948+
use super::{Argon2id, Hkdf, HkdfExpand, Pbkdf2Hmac, Scrypt, X963Kdf};
841949
}

tests/hazmat/primitives/test_x963kdf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55

66
import binascii
7+
import sys
78

89
import pytest
910

@@ -15,8 +16,9 @@
1516
class TestX963KDF:
1617
def test_length_limit(self, backend):
1718
big_length = hashes.SHA256().digest_size * (2**32 - 1) + 1
19+
error = OverflowError if sys.maxsize <= 2**31 else ValueError
1820

19-
with pytest.raises(ValueError):
21+
with pytest.raises(error):
2022
X963KDF(hashes.SHA256(), big_length, None, backend)
2123

2224
def test_already_finalized(self, backend):

0 commit comments

Comments
 (0)