Skip to content

Commit ccfe7a2

Browse files
committed
migrate concatkdf{hash,hmac} to rust
1 parent 6ac1384 commit ccfe7a2

File tree

4 files changed

+263
-118
lines changed

4 files changed

+263
-118
lines changed

src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,26 @@ class X963KDF:
9191
) -> None: ...
9292
def derive(self, key_material: Buffer) -> bytes: ...
9393
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...
94+
95+
class ConcatKDFHash:
96+
def __init__(
97+
self,
98+
algorithm: HashAlgorithm,
99+
length: int,
100+
otherinfo: bytes | None,
101+
backend: typing.Any = None,
102+
) -> None: ...
103+
def derive(self, key_material: Buffer) -> bytes: ...
104+
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...
105+
106+
class ConcatKDFHMAC:
107+
def __init__(
108+
self,
109+
algorithm: HashAlgorithm,
110+
length: int,
111+
salt: bytes | None,
112+
otherinfo: bytes | None,
113+
backend: typing.Any = None,
114+
) -> None: ...
115+
def derive(self, key_material: Buffer) -> bytes: ...
116+
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...

src/cryptography/hazmat/primitives/kdf/concatkdf.py

Lines changed: 6 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -4,122 +4,13 @@
44

55
from __future__ import annotations
66

7-
import typing
8-
from collections.abc import Callable
9-
10-
from cryptography import utils
11-
from cryptography.exceptions import AlreadyFinalized, InvalidKey
12-
from cryptography.hazmat.primitives import constant_time, hashes, hmac
7+
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
138
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
149

10+
ConcatKDFHash = rust_openssl.kdf.ConcatKDFHash
11+
ConcatKDFHMAC = rust_openssl.kdf.ConcatKDFHMAC
1512

16-
def _int_to_u32be(n: int) -> bytes:
17-
return n.to_bytes(length=4, byteorder="big")
18-
19-
20-
def _common_args_checks(
21-
algorithm: hashes.HashAlgorithm,
22-
length: int,
23-
otherinfo: bytes | None,
24-
) -> None:
25-
max_length = algorithm.digest_size * (2**32 - 1)
26-
if length > max_length:
27-
raise ValueError(f"Cannot derive keys larger than {max_length} bits.")
28-
if otherinfo is not None:
29-
utils._check_bytes("otherinfo", otherinfo)
30-
31-
32-
def _concatkdf_derive(
33-
key_material: utils.Buffer,
34-
length: int,
35-
auxfn: Callable[[], hashes.HashContext],
36-
otherinfo: bytes,
37-
) -> bytes:
38-
utils._check_byteslike("key_material", key_material)
39-
output = [b""]
40-
outlen = 0
41-
counter = 1
42-
43-
while length > outlen:
44-
h = auxfn()
45-
h.update(_int_to_u32be(counter))
46-
h.update(key_material)
47-
h.update(otherinfo)
48-
output.append(h.finalize())
49-
outlen += len(output[-1])
50-
counter += 1
51-
52-
return b"".join(output)[:length]
53-
54-
55-
class ConcatKDFHash(KeyDerivationFunction):
56-
def __init__(
57-
self,
58-
algorithm: hashes.HashAlgorithm,
59-
length: int,
60-
otherinfo: bytes | None,
61-
backend: typing.Any = None,
62-
):
63-
_common_args_checks(algorithm, length, otherinfo)
64-
self._algorithm = algorithm
65-
self._length = length
66-
self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
67-
68-
self._used = False
69-
70-
def _hash(self) -> hashes.Hash:
71-
return hashes.Hash(self._algorithm)
72-
73-
def derive(self, key_material: utils.Buffer) -> bytes:
74-
if self._used:
75-
raise AlreadyFinalized
76-
self._used = True
77-
return _concatkdf_derive(
78-
key_material, self._length, self._hash, self._otherinfo
79-
)
80-
81-
def verify(self, key_material: bytes, expected_key: bytes) -> None:
82-
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
83-
raise InvalidKey
84-
85-
86-
class ConcatKDFHMAC(KeyDerivationFunction):
87-
def __init__(
88-
self,
89-
algorithm: hashes.HashAlgorithm,
90-
length: int,
91-
salt: bytes | None,
92-
otherinfo: bytes | None,
93-
backend: typing.Any = None,
94-
):
95-
_common_args_checks(algorithm, length, otherinfo)
96-
self._algorithm = algorithm
97-
self._length = length
98-
self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
99-
100-
if algorithm.block_size is None:
101-
raise TypeError(f"{algorithm.name} is unsupported for ConcatKDF")
102-
103-
if salt is None:
104-
salt = b"\x00" * algorithm.block_size
105-
else:
106-
utils._check_bytes("salt", salt)
107-
108-
self._salt = salt
109-
110-
self._used = False
111-
112-
def _hmac(self) -> hmac.HMAC:
113-
return hmac.HMAC(self._salt, self._algorithm)
114-
115-
def derive(self, key_material: utils.Buffer) -> bytes:
116-
if self._used:
117-
raise AlreadyFinalized
118-
self._used = True
119-
return _concatkdf_derive(
120-
key_material, self._length, self._hmac, self._otherinfo
121-
)
13+
KeyDerivationFunction.register(ConcatKDFHash)
14+
KeyDerivationFunction.register(ConcatKDFHMAC)
12215

123-
def verify(self, key_material: bytes, expected_key: bytes) -> None:
124-
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
125-
raise InvalidKey
16+
__all__ = ["ConcatKDFHMAC", "ConcatKDFHash"]

src/rust/src/backend/kdf.rs

Lines changed: 229 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,8 +942,236 @@ impl X963Kdf {
942942
}
943943
}
944944

945+
// NO-COVERAGE-START
946+
#[pyo3::pyclass(
947+
module = "cryptography.hazmat.primitives.kdf.concatkdf",
948+
name = "ConcatKDFHash"
949+
)]
950+
// NO-COVERAGE-END
951+
struct ConcatKdfHash {
952+
algorithm: pyo3::Py<pyo3::PyAny>,
953+
length: usize,
954+
otherinfo: pyo3::Py<pyo3::types::PyBytes>,
955+
used: bool,
956+
}
957+
958+
#[pyo3::pymethods]
959+
impl ConcatKdfHash {
960+
#[new]
961+
#[pyo3(signature = (algorithm, length, otherinfo, backend=None))]
962+
fn new(
963+
py: pyo3::Python<'_>,
964+
algorithm: pyo3::Py<pyo3::PyAny>,
965+
length: usize,
966+
otherinfo: Option<pyo3::Py<pyo3::types::PyBytes>>,
967+
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
968+
) -> CryptographyResult<Self> {
969+
_ = backend;
970+
971+
let algorithm_bound = algorithm.bind(py);
972+
let digest_size = algorithm_bound
973+
.getattr(pyo3::intern!(py, "digest_size"))?
974+
.extract::<usize>()?;
975+
976+
let max_len = digest_size.saturating_mul(u32::MAX as usize);
977+
if length > max_len {
978+
return Err(CryptographyError::from(
979+
pyo3::exceptions::PyValueError::new_err(format!(
980+
"Cannot derive keys larger than {max_len} bits."
981+
)),
982+
));
983+
}
984+
985+
let otherinfo_bytes =
986+
otherinfo.unwrap_or_else(|| pyo3::types::PyBytes::new(py, b"").into());
987+
988+
Ok(ConcatKdfHash {
989+
algorithm,
990+
length,
991+
otherinfo: otherinfo_bytes,
992+
used: false,
993+
})
994+
}
995+
996+
fn derive<'p>(
997+
&mut self,
998+
py: pyo3::Python<'p>,
999+
key_material: CffiBuf<'_>,
1000+
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
1001+
if self.used {
1002+
return Err(exceptions::already_finalized_error());
1003+
}
1004+
self.used = true;
1005+
1006+
let algorithm_bound = self.algorithm.bind(py);
1007+
1008+
let mut output = Vec::with_capacity(self.length);
1009+
let mut counter = 1u32;
1010+
1011+
while output.len() < self.length {
1012+
let mut hash_obj = hashes::Hash::new(py, algorithm_bound, None)?;
1013+
hash_obj.update_bytes(&counter.to_be_bytes())?;
1014+
hash_obj.update_bytes(key_material.as_bytes())?;
1015+
hash_obj.update_bytes(self.otherinfo.as_bytes(py))?;
1016+
let block = hash_obj.finalize(py)?;
1017+
output.extend_from_slice(block.as_bytes());
1018+
counter += 1;
1019+
}
1020+
1021+
output.truncate(self.length);
1022+
Ok(pyo3::types::PyBytes::new(py, &output))
1023+
}
1024+
1025+
fn verify(
1026+
&mut self,
1027+
py: pyo3::Python<'_>,
1028+
key_material: CffiBuf<'_>,
1029+
expected_key: CffiBuf<'_>,
1030+
) -> CryptographyResult<()> {
1031+
let actual = self.derive(py, key_material)?;
1032+
let actual_bytes = actual.as_bytes();
1033+
let expected_bytes = expected_key.as_bytes();
1034+
1035+
if !constant_time::bytes_eq(actual_bytes, expected_bytes) {
1036+
return Err(CryptographyError::from(exceptions::InvalidKey::new_err(
1037+
"Keys do not match.",
1038+
)));
1039+
}
1040+
1041+
Ok(())
1042+
}
1043+
}
1044+
1045+
// NO-COVERAGE-START
1046+
#[pyo3::pyclass(
1047+
module = "cryptography.hazmat.primitives.kdf.concatkdf",
1048+
name = "ConcatKDFHMAC"
1049+
)]
1050+
// NO-COVERAGE-END
1051+
struct ConcatKdfHmac {
1052+
algorithm: pyo3::Py<pyo3::PyAny>,
1053+
length: usize,
1054+
salt: pyo3::Py<pyo3::types::PyBytes>,
1055+
otherinfo: pyo3::Py<pyo3::types::PyBytes>,
1056+
used: bool,
1057+
}
1058+
1059+
#[pyo3::pymethods]
1060+
impl ConcatKdfHmac {
1061+
#[new]
1062+
#[pyo3(signature = (algorithm, length, salt, otherinfo, backend=None))]
1063+
fn new(
1064+
py: pyo3::Python<'_>,
1065+
algorithm: pyo3::Py<pyo3::PyAny>,
1066+
length: usize,
1067+
salt: Option<pyo3::Py<pyo3::types::PyBytes>>,
1068+
otherinfo: Option<pyo3::Py<pyo3::types::PyBytes>>,
1069+
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
1070+
) -> CryptographyResult<Self> {
1071+
_ = backend;
1072+
1073+
let algorithm_bound = algorithm.bind(py);
1074+
let digest_size = algorithm_bound
1075+
.getattr(pyo3::intern!(py, "digest_size"))?
1076+
.extract::<usize>()?;
1077+
1078+
let max_len = digest_size.saturating_mul(u32::MAX as usize);
1079+
if length > max_len {
1080+
return Err(CryptographyError::from(
1081+
pyo3::exceptions::PyValueError::new_err(format!(
1082+
"Cannot derive keys larger than {max_len} bits."
1083+
)),
1084+
));
1085+
}
1086+
1087+
// Check for block_size (required for HMAC)
1088+
let block_size = algorithm_bound.getattr(pyo3::intern!(py, "block_size"))?;
1089+
1090+
if block_size.is_none() {
1091+
let name = algorithm_bound
1092+
.getattr(pyo3::intern!(py, "name"))?
1093+
.extract::<String>()?;
1094+
return Err(CryptographyError::from(
1095+
pyo3::exceptions::PyTypeError::new_err(format!(
1096+
"{name} is unsupported for ConcatKDF"
1097+
)),
1098+
));
1099+
}
1100+
1101+
let block_size_val = block_size.extract::<usize>()?;
1102+
1103+
// Default salt to zeros of block_size length
1104+
let salt_bytes = if let Some(s) = salt {
1105+
s
1106+
} else {
1107+
pyo3::types::PyBytes::new(py, &vec![0u8; block_size_val]).into()
1108+
};
1109+
1110+
let otherinfo_bytes =
1111+
otherinfo.unwrap_or_else(|| pyo3::types::PyBytes::new(py, b"").into());
1112+
1113+
Ok(ConcatKdfHmac {
1114+
algorithm,
1115+
length,
1116+
salt: salt_bytes,
1117+
otherinfo: otherinfo_bytes,
1118+
used: false,
1119+
})
1120+
}
1121+
1122+
fn derive<'p>(
1123+
&mut self,
1124+
py: pyo3::Python<'p>,
1125+
key_material: CffiBuf<'_>,
1126+
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
1127+
if self.used {
1128+
return Err(exceptions::already_finalized_error());
1129+
}
1130+
self.used = true;
1131+
1132+
let algorithm_bound = self.algorithm.bind(py);
1133+
1134+
let mut output = Vec::with_capacity(self.length);
1135+
let mut counter = 1u32;
1136+
1137+
while output.len() < self.length {
1138+
let mut hmac = Hmac::new_bytes(py, self.salt.as_bytes(py), algorithm_bound)?;
1139+
hmac.update_bytes(&counter.to_be_bytes())?;
1140+
hmac.update_bytes(key_material.as_bytes())?;
1141+
hmac.update_bytes(self.otherinfo.as_bytes(py))?;
1142+
let result = hmac.finalize_bytes()?;
1143+
output.extend_from_slice(&result);
1144+
counter += 1;
1145+
}
1146+
1147+
output.truncate(self.length);
1148+
Ok(pyo3::types::PyBytes::new(py, &output))
1149+
}
1150+
1151+
fn verify(
1152+
&mut self,
1153+
py: pyo3::Python<'_>,
1154+
key_material: CffiBuf<'_>,
1155+
expected_key: CffiBuf<'_>,
1156+
) -> CryptographyResult<()> {
1157+
let actual = self.derive(py, key_material)?;
1158+
let actual_bytes = actual.as_bytes();
1159+
let expected_bytes = expected_key.as_bytes();
1160+
1161+
if !constant_time::bytes_eq(actual_bytes, expected_bytes) {
1162+
return Err(CryptographyError::from(exceptions::InvalidKey::new_err(
1163+
"Keys do not match.",
1164+
)));
1165+
}
1166+
1167+
Ok(())
1168+
}
1169+
}
1170+
9451171
#[pyo3::pymodule(gil_used = false)]
9461172
pub(crate) mod kdf {
9471173
#[pymodule_export]
948-
use super::{Argon2id, Hkdf, HkdfExpand, Pbkdf2Hmac, Scrypt, X963Kdf};
1174+
use super::{
1175+
Argon2id, ConcatKdfHash, ConcatKdfHmac, Hkdf, HkdfExpand, Pbkdf2Hmac, Scrypt, X963Kdf,
1176+
};
9491177
}

0 commit comments

Comments
 (0)