diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9894665ab412..929cc043147a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,11 @@ Changelog * Removed the deprecated ``CAST5``, ``SEED``, ``IDEA``, and ``Blowfish`` classes from the cipher module. These are still available in :doc:`/hazmat/decrepit/index`. +* Make instances of + :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm` as well as + instances of classes in + :mod:`~cryptography.hazmat.primitives.asymmetric.padding` + comparable. .. _v45-0-6: diff --git a/docs/hazmat/primitives/asymmetric/cloudhsm.rst b/docs/hazmat/primitives/asymmetric/cloudhsm.rst index 8934133a228a..09ae30f6fc87 100644 --- a/docs/hazmat/primitives/asymmetric/cloudhsm.rst +++ b/docs/hazmat/primitives/asymmetric/cloudhsm.rst @@ -88,7 +88,7 @@ if you only need a subset of functionality. ... Maps the cryptography padding and algorithm to the corresponding KMS signing algorithm. ... This is specific to your implementation. ... """ - ... if isinstance(padding, PKCS1v15) and isinstance(algorithm, hashes.SHA256): + ... if padding == PKCS1v15() and algorithm == hashes.SHA256(): ... return b"RSA_PKCS1_V1_5_SHA_256" ... else: ... raise NotImplementedError() diff --git a/docs/x509/reference.rst b/docs/x509/reference.rst index 74d6da68bad4..6acc63a0f4bc 100644 --- a/docs/x509/reference.rst +++ b/docs/x509/reference.rst @@ -248,7 +248,7 @@ Loading Certificate Revocation Lists >>> from cryptography import x509 >>> from cryptography.hazmat.primitives import hashes >>> crl = x509.load_pem_x509_crl(pem_crl_data) - >>> isinstance(crl.signature_hash_algorithm, hashes.SHA256) + >>> crl.signature_hash_algorithm == hashes.SHA256() True .. function:: load_der_x509_crl(data) @@ -287,7 +287,7 @@ Loading Certificate Signing Requests >>> from cryptography import x509 >>> from cryptography.hazmat.primitives import hashes >>> csr = x509.load_pem_x509_csr(pem_req_data) - >>> isinstance(csr.signature_hash_algorithm, hashes.SHA256) + >>> csr.signature_hash_algorithm == hashes.SHA256() True .. function:: load_der_x509_csr(data) @@ -477,7 +477,7 @@ X.509 Certificate Object .. doctest:: >>> from cryptography.hazmat.primitives import hashes - >>> isinstance(cert.signature_hash_algorithm, hashes.SHA256) + >>> cert.signature_hash_algorithm == hashes.SHA256() True .. attribute:: signature_algorithm_oid @@ -716,7 +716,7 @@ X.509 CRL (Certificate Revocation List) Object .. doctest:: >>> from cryptography.hazmat.primitives import hashes - >>> isinstance(crl.signature_hash_algorithm, hashes.SHA256) + >>> crl.signature_hash_algorithm == hashes.SHA256() True .. attribute:: signature_algorithm_oid @@ -1119,7 +1119,7 @@ X.509 CSR (Certificate Signing Request) Object .. doctest:: >>> from cryptography.hazmat.primitives import hashes - >>> isinstance(csr.signature_hash_algorithm, hashes.SHA256) + >>> csr.signature_hash_algorithm == hashes.SHA256() True .. attribute:: signature_algorithm_oid diff --git a/src/cryptography/hazmat/primitives/asymmetric/padding.py b/src/cryptography/hazmat/primitives/asymmetric/padding.py index 5121a288fcc7..6448227ee298 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/padding.py +++ b/src/cryptography/hazmat/primitives/asymmetric/padding.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +import typing from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives._asymmetric import ( @@ -16,6 +17,9 @@ class PKCS1v15(AsymmetricPadding): name = "EMSA-PKCS1-v1_5" + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, PKCS1v15) + class _MaxLength: "Sentinel value for `MAX_LENGTH`." @@ -56,6 +60,18 @@ def __init__( self._salt_length = salt_length + def __eq__(self, other: typing.Any) -> bool: + if isinstance(self._salt_length, int): + eq_salt_length = self._salt_length == other._salt_length + else: + eq_salt_length = self._salt_length is other._salt_length + + return ( + isinstance(other, PSS) + and eq_salt_length + and self._mgf == other._mgf + ) + @property def mgf(self) -> MGF: return self._mgf @@ -77,6 +93,14 @@ def __init__( self._algorithm = algorithm self._label = label + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, OAEP) + and self._mgf == other._mgf + and self._algorithm == other._algorithm + and self._label == other._label + ) + @property def algorithm(self) -> hashes.HashAlgorithm: return self._algorithm @@ -89,6 +113,13 @@ def mgf(self) -> MGF: class MGF(metaclass=abc.ABCMeta): _algorithm: hashes.HashAlgorithm + @abc.abstractmethod + def __eq__(self, other: typing.Any) -> bool: + """ + Implement equality checking. + """ + ... + class MGF1(MGF): def __init__(self, algorithm: hashes.HashAlgorithm): @@ -97,6 +128,9 @@ def __init__(self, algorithm: hashes.HashAlgorithm): self._algorithm = algorithm + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, MGF1) and self._algorithm == other._algorithm + def calculate_max_pss_salt_length( key: rsa.RSAPrivateKey | rsa.RSAPublicKey, diff --git a/src/cryptography/hazmat/primitives/hashes.py b/src/cryptography/hazmat/primitives/hashes.py index 4b55ec33dbff..d3e6fbbc64fc 100644 --- a/src/cryptography/hazmat/primitives/hashes.py +++ b/src/cryptography/hazmat/primitives/hashes.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +import typing from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.utils import Buffer @@ -36,6 +37,13 @@ class HashAlgorithm(metaclass=abc.ABCMeta): + @abc.abstractmethod + def __eq__(self, other: typing.Any) -> bool: + """ + Implement equality checking. + """ + ... + @property @abc.abstractmethod def name(self) -> str: @@ -103,66 +111,99 @@ class SHA1(HashAlgorithm): digest_size = 20 block_size = 64 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA1) + class SHA512_224(HashAlgorithm): # noqa: N801 name = "sha512-224" digest_size = 28 block_size = 128 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA512_224) + class SHA512_256(HashAlgorithm): # noqa: N801 name = "sha512-256" digest_size = 32 block_size = 128 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA512_256) + class SHA224(HashAlgorithm): name = "sha224" digest_size = 28 block_size = 64 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA224) + class SHA256(HashAlgorithm): name = "sha256" digest_size = 32 block_size = 64 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA256) + class SHA384(HashAlgorithm): name = "sha384" digest_size = 48 block_size = 128 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA384) + class SHA512(HashAlgorithm): name = "sha512" digest_size = 64 block_size = 128 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA512) + class SHA3_224(HashAlgorithm): # noqa: N801 name = "sha3-224" digest_size = 28 block_size = None + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA3_224) + class SHA3_256(HashAlgorithm): # noqa: N801 name = "sha3-256" digest_size = 32 block_size = None + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA3_256) + class SHA3_384(HashAlgorithm): # noqa: N801 name = "sha3-384" digest_size = 48 block_size = None + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA3_384) + class SHA3_512(HashAlgorithm): # noqa: N801 name = "sha3-512" digest_size = 64 block_size = None + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SHA3_512) + class SHAKE128(HashAlgorithm, ExtendableOutputFunction): name = "shake128" @@ -177,6 +218,12 @@ def __init__(self, digest_size: int): self._digest_size = digest_size + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, SHAKE128) + and self._digest_size == other._digest_size + ) + @property def digest_size(self) -> int: return self._digest_size @@ -195,6 +242,12 @@ def __init__(self, digest_size: int): self._digest_size = digest_size + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, SHAKE256) + and self._digest_size == other._digest_size + ) + @property def digest_size(self) -> int: return self._digest_size @@ -205,6 +258,9 @@ class MD5(HashAlgorithm): digest_size = 16 block_size = 64 + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, MD5) + class BLAKE2b(HashAlgorithm): name = "blake2b" @@ -218,6 +274,12 @@ def __init__(self, digest_size: int): self._digest_size = digest_size + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, BLAKE2b) + and self._digest_size == other._digest_size + ) + @property def digest_size(self) -> int: return self._digest_size @@ -235,6 +297,12 @@ def __init__(self, digest_size: int): self._digest_size = digest_size + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, BLAKE2s) + and self._digest_size == other._digest_size + ) + @property def digest_size(self) -> int: return self._digest_size @@ -244,3 +312,6 @@ class SM3(HashAlgorithm): name = "sm3" digest_size = 32 block_size = 64 + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, SM3) diff --git a/tests/doubles.py b/tests/doubles.py index cf2c96a3e83c..760fc1ba7c49 100644 --- a/tests/doubles.py +++ b/tests/doubles.py @@ -2,6 +2,7 @@ # 2.0, and the BSD License. See the LICENSE file in the root of this repository # for complete details. +import typing from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding @@ -40,6 +41,12 @@ class DummyHashAlgorithm(hashes.HashAlgorithm): def __init__(self, digest_size: int = 32) -> None: self._digest_size = digest_size + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(self, DummyHashAlgorithm) + and self._digest_size == other._digest_size + ) + @property def digest_size(self) -> int: return self._digest_size diff --git a/tests/hazmat/backends/test_openssl.py b/tests/hazmat/backends/test_openssl.py index a48dc653f033..e8e89efb3f9d 100644 --- a/tests/hazmat/backends/test_openssl.py +++ b/tests/hazmat/backends/test_openssl.py @@ -4,6 +4,7 @@ import itertools +import typing import pytest @@ -32,6 +33,9 @@ class DummyMGF(padding.MGF): _salt_length = 0 _algorithm = hashes.SHA1() + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, DummyMGF) + class TestOpenSSL: def test_backend_exists(self): diff --git a/tests/hazmat/primitives/test_hashes.py b/tests/hazmat/primitives/test_hashes.py index 092ba9af41d4..0076d439961b 100644 --- a/tests/hazmat/primitives/test_hashes.py +++ b/tests/hazmat/primitives/test_hashes.py @@ -12,7 +12,7 @@ from ...doubles import DummyHashAlgorithm from ...utils import raises_unsupported_algorithm -from .utils import generate_base_hash_test +from .utils import generate_base_hash_test, generate_eq_hash_test class TestHashContext: @@ -52,6 +52,7 @@ class TestSHA1: hashes.SHA1(), digest_size=20, ) + test_sha1_eq = generate_eq_hash_test(hashes.SHA1()) @pytest.mark.supported( @@ -63,6 +64,7 @@ class TestSHA224: hashes.SHA224(), digest_size=28, ) + test_sha224_eq = generate_eq_hash_test(hashes.SHA224()) @pytest.mark.supported( @@ -74,6 +76,7 @@ class TestSHA256: hashes.SHA256(), digest_size=32, ) + test_sha256_eq = generate_eq_hash_test(hashes.SHA256()) @pytest.mark.supported( @@ -85,6 +88,7 @@ class TestSHA384: hashes.SHA384(), digest_size=48, ) + test_sha384_eq = generate_eq_hash_test(hashes.SHA384()) @pytest.mark.supported( @@ -96,6 +100,79 @@ class TestSHA512: hashes.SHA512(), digest_size=64, ) + test_sha512_eq = generate_eq_hash_test(hashes.SHA512()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA512_224()), + skip_message="Does not support SHA512 224", +) +class TestSHA512224: + test_sha512_224 = generate_base_hash_test( + hashes.SHA512_224(), + digest_size=28, + ) + test_sha512_224_eq = generate_eq_hash_test(hashes.SHA512_224()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA512_256()), + skip_message="Does not support SHA512 256", +) +class TestSHA512256: + test_sha512_256 = generate_base_hash_test( + hashes.SHA512_256(), + digest_size=32, + ) + test_sha512_256_eq = generate_eq_hash_test(hashes.SHA512_256()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA3_224()), + skip_message="Does not support SHA3 224", +) +class TestSHA3224: + test_sha3_224 = generate_base_hash_test( + hashes.SHA3_224(), + digest_size=28, + ) + test_sha3_224_eq = generate_eq_hash_test(hashes.SHA3_224()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA3_256()), + skip_message="Does not support SHA3 256", +) +class TestSHA3256: + test_sha3_256 = generate_base_hash_test( + hashes.SHA3_256(), + digest_size=32, + ) + test_sha3_256_eq = generate_eq_hash_test(hashes.SHA3_256()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA3_384()), + skip_message="Does not support SHA3 384", +) +class TestSHA3384: + test_sha3_384 = generate_base_hash_test( + hashes.SHA3_384(), + digest_size=48, + ) + test_sha3_384_eq = generate_eq_hash_test(hashes.SHA3_384()) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported(hashes.SHA3_512()), + skip_message="Does not support SHA3 512", +) +class TestSHA3512: + test_sha3_512 = generate_base_hash_test( + hashes.SHA3_512(), + digest_size=64, + ) + test_sha3_512_eq = generate_eq_hash_test(hashes.SHA3_512()) @pytest.mark.supported( @@ -107,6 +184,7 @@ class TestMD5: hashes.MD5(), digest_size=16, ) + test_md5_eq = generate_eq_hash_test(hashes.MD5()) @pytest.mark.supported( @@ -120,6 +198,7 @@ class TestBLAKE2b: hashes.BLAKE2b(digest_size=64), digest_size=64, ) + test_blake2b_eq = generate_eq_hash_test(hashes.BLAKE2b(digest_size=64)) def test_invalid_digest_size(self, backend): with pytest.raises(ValueError): @@ -143,6 +222,7 @@ class TestBLAKE2s: hashes.BLAKE2s(digest_size=32), digest_size=32, ) + test_blake2s_eq = generate_eq_hash_test(hashes.BLAKE2s(digest_size=32)) def test_invalid_digest_size(self, backend): with pytest.raises(ValueError): @@ -165,6 +245,14 @@ def test_buffer_protocol_hash(backend): class TestSHAKE: + @pytest.mark.parametrize("xof", [hashes.SHAKE128, hashes.SHAKE256]) + def test_eq(self, xof): + value_one = xof(digest_size=32) + value_two = xof(digest_size=32) # identical + value_three = xof(digest_size=64) + assert value_one == value_two + assert value_one != value_three + @pytest.mark.parametrize("xof", [hashes.SHAKE128, hashes.SHAKE256]) def test_invalid_digest_type(self, xof): with pytest.raises(TypeError): @@ -188,3 +276,4 @@ class TestSM3: hashes.SM3(), digest_size=32, ) + test_sm3_eq = generate_eq_hash_test(hashes.SM3()) diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index 25edfb07592c..3011d4de7822 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -7,6 +7,7 @@ import copy import itertools import os +import typing import pytest @@ -70,6 +71,9 @@ class DummyMGF(padding.MGF): _salt_length = 0 _algorithm = hashes.SHA256() + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, DummyMGF) + def _check_fips_key_length(backend, private_key): if ( @@ -1603,6 +1607,14 @@ class TestRSAPKCS1Verification: ) +class TestPKCS1v15: + def test_eq(self): + assert padding.PKCS1v15() == padding.PKCS1v15() + assert padding.PKCS1v15() != padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), salt_length=32 + ) + + class TestPSS: def test_calculate_max_pss_salt_length(self): with pytest.raises(TypeError): @@ -1644,8 +1656,68 @@ def test_mgf_property(self): assert pss.mgf == mgf assert pss.mgf == pss._mgf + @pytest.mark.parametrize("xof", [hashes.SHA256(), hashes.SHA512()]) + @pytest.mark.parametrize( + "salt_length", + [ + 1, + 32, + padding.PSS.MAX_LENGTH, + padding.PSS.AUTO, + padding.PSS.DIGEST_LENGTH, + ], + ) + def test_eq( + self, xof: hashes.HashAlgorithm, salt_length: typing.Any + ) -> None: + assert padding.PSS( + salt_length=salt_length, mgf=padding.MGF1(algorithm=xof) + ) == padding.PSS( + salt_length=salt_length, mgf=padding.MGF1(algorithm=xof) + ) + + @pytest.mark.parametrize( + "salt_length", + [ + 1, + 32, + padding.PSS.MAX_LENGTH, + padding.PSS.AUTO, + padding.PSS.DIGEST_LENGTH, + ], + ) + def test_not_eq_with_different_salt_length( + self, salt_length: typing.Any + ) -> None: + xof = hashes.SHA256() + assert padding.PSS( + salt_length=salt_length, mgf=padding.MGF1(algorithm=xof) + ) != padding.PSS(salt_length=64, mgf=padding.MGF1(algorithm=xof)) + + def test_not_eq_with_salt_length_object_identity(self) -> None: + xof = hashes.SHA256() + assert padding.PSS( + salt_length=padding.PSS.AUTO, mgf=padding.MGF1(algorithm=xof) + ) != padding.PSS( + salt_length=padding.PSS.DIGEST_LENGTH, + mgf=padding.MGF1(algorithm=xof), + ) + + def test_not_eq_with_different_mgf(self) -> None: + assert padding.PSS( + salt_length=padding.PSS.AUTO, + mgf=padding.MGF1(algorithm=hashes.SHA256()), + ) != padding.PSS( + salt_length=padding.PSS.AUTO, + mgf=padding.MGF1(algorithm=hashes.SHA512()), + ) + class TestMGF1: + def test_eq(self) -> None: + assert padding.MGF1(hashes.SHA256()) == padding.MGF1(hashes.SHA256()) + assert padding.MGF1(hashes.SHA256()) != padding.MGF1(hashes.SHA512()) + def test_invalid_hash_algorithm(self): with pytest.raises(TypeError): padding.MGF1(b"not_a_hash") # type:ignore[arg-type] @@ -1680,6 +1752,49 @@ def test_mgf_property(self): assert oaep.mgf == mgf assert oaep.mgf == oaep._mgf + @pytest.mark.parametrize("xof", [hashes.SHA256(), hashes.SHA512()]) + @pytest.mark.parametrize("label", [None, b"", b"foo"]) + def test_eq( + self, xof: hashes.HashAlgorithm, label: typing.Optional[bytes] + ) -> None: + mgf = padding.MGF1(algorithm=xof) + assert padding.OAEP( + mgf=mgf, algorithm=xof, label=label + ) == padding.OAEP(mgf=mgf, algorithm=xof, label=label) + + def test_not_eq_with_different_mgf(self) -> None: + assert padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ) != padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA512()), + algorithm=hashes.SHA256(), + label=None, + ) + + def test_not_eq_with_different_algorithm(self) -> None: + assert padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA512()), + algorithm=hashes.SHA512(), + label=None, + ) != padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA512()), + algorithm=hashes.SHA256(), + label=None, + ) + + def test_not_eq_with_different_label(self) -> None: + assert padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA512()), + algorithm=hashes.SHA256(), + label=None, + ) != padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA512()), + algorithm=hashes.SHA256(), + label=b"", + ) + class TestRSADecryption: @pytest.mark.supported( diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index aad324683a81..af53155410aa 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -35,6 +35,7 @@ Mode, ) +from ...doubles import DummyHashAlgorithm from ...utils import load_vectors_from_file @@ -207,6 +208,14 @@ def test_base_hash(self, backend): return test_base_hash +def generate_eq_hash_test(algorithm): + def test_eq(self): + assert algorithm == algorithm + assert algorithm != DummyHashAlgorithm() + + return test_eq + + def base_hash_test(backend, algorithm, digest_size): m = hashes.Hash(algorithm, backend=backend) assert m.algorithm.digest_size == digest_size