Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,26 @@ class X963KDF:
) -> None: ...
def derive(self, key_material: Buffer) -> bytes: ...
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...

class ConcatKDFHash:
def __init__(
self,
algorithm: HashAlgorithm,
length: int,
otherinfo: bytes | None,
backend: typing.Any = None,
) -> None: ...
def derive(self, key_material: Buffer) -> bytes: ...
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...

class ConcatKDFHMAC:
def __init__(
self,
algorithm: HashAlgorithm,
length: int,
salt: bytes | None,
otherinfo: bytes | None,
backend: typing.Any = None,
) -> None: ...
def derive(self, key_material: Buffer) -> bytes: ...
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...
121 changes: 6 additions & 115 deletions src/cryptography/hazmat/primitives/kdf/concatkdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,122 +4,13 @@

from __future__ import annotations

import typing
from collections.abc import Callable

from cryptography import utils
from cryptography.exceptions import AlreadyFinalized, InvalidKey
from cryptography.hazmat.primitives import constant_time, hashes, hmac
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction

ConcatKDFHash = rust_openssl.kdf.ConcatKDFHash
ConcatKDFHMAC = rust_openssl.kdf.ConcatKDFHMAC

def _int_to_u32be(n: int) -> bytes:
return n.to_bytes(length=4, byteorder="big")


def _common_args_checks(
algorithm: hashes.HashAlgorithm,
length: int,
otherinfo: bytes | None,
) -> None:
max_length = algorithm.digest_size * (2**32 - 1)
if length > max_length:
raise ValueError(f"Cannot derive keys larger than {max_length} bits.")
if otherinfo is not None:
utils._check_bytes("otherinfo", otherinfo)


def _concatkdf_derive(
key_material: utils.Buffer,
length: int,
auxfn: Callable[[], hashes.HashContext],
otherinfo: bytes,
) -> bytes:
utils._check_byteslike("key_material", key_material)
output = [b""]
outlen = 0
counter = 1

while length > outlen:
h = auxfn()
h.update(_int_to_u32be(counter))
h.update(key_material)
h.update(otherinfo)
output.append(h.finalize())
outlen += len(output[-1])
counter += 1

return b"".join(output)[:length]


class ConcatKDFHash(KeyDerivationFunction):
def __init__(
self,
algorithm: hashes.HashAlgorithm,
length: int,
otherinfo: bytes | None,
backend: typing.Any = None,
):
_common_args_checks(algorithm, length, otherinfo)
self._algorithm = algorithm
self._length = length
self._otherinfo: bytes = otherinfo if otherinfo is not None else b""

self._used = False

def _hash(self) -> hashes.Hash:
return hashes.Hash(self._algorithm)

def derive(self, key_material: utils.Buffer) -> bytes:
if self._used:
raise AlreadyFinalized
self._used = True
return _concatkdf_derive(
key_material, self._length, self._hash, self._otherinfo
)

def verify(self, key_material: bytes, expected_key: bytes) -> None:
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
raise InvalidKey


class ConcatKDFHMAC(KeyDerivationFunction):
def __init__(
self,
algorithm: hashes.HashAlgorithm,
length: int,
salt: bytes | None,
otherinfo: bytes | None,
backend: typing.Any = None,
):
_common_args_checks(algorithm, length, otherinfo)
self._algorithm = algorithm
self._length = length
self._otherinfo: bytes = otherinfo if otherinfo is not None else b""

if algorithm.block_size is None:
raise TypeError(f"{algorithm.name} is unsupported for ConcatKDF")

if salt is None:
salt = b"\x00" * algorithm.block_size
else:
utils._check_bytes("salt", salt)

self._salt = salt

self._used = False

def _hmac(self) -> hmac.HMAC:
return hmac.HMAC(self._salt, self._algorithm)

def derive(self, key_material: utils.Buffer) -> bytes:
if self._used:
raise AlreadyFinalized
self._used = True
return _concatkdf_derive(
key_material, self._length, self._hmac, self._otherinfo
)
KeyDerivationFunction.register(ConcatKDFHash)
KeyDerivationFunction.register(ConcatKDFHMAC)

def verify(self, key_material: bytes, expected_key: bytes) -> None:
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
raise InvalidKey
__all__ = ["ConcatKDFHMAC", "ConcatKDFHash"]
Loading
Loading