Skip to content

Commit f84ec6c

Browse files
committed
add __eq__ for padding classes
1 parent 22ef165 commit f84ec6c

File tree

4 files changed

+156
-0
lines changed

4 files changed

+156
-0
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ Changelog
1818
* Removed the deprecated ``CAST5``, ``SEED``, ``IDEA``, and ``Blowfish``
1919
classes from the cipher module. These are still available in
2020
:doc:`/hazmat/decrepit/index`.
21+
* Make instances of
22+
:class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm` as well as
23+
instances of classes in
24+
:mod:`~cryptography.hazmat.primitives.asymmetric.padding`
25+
comparable.
2126

2227
.. _v45-0-6:
2328

src/cryptography/hazmat/primitives/asymmetric/padding.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import abc
8+
import typing
89

910
from cryptography.hazmat.primitives import hashes
1011
from cryptography.hazmat.primitives._asymmetric import (
@@ -16,6 +17,9 @@
1617
class PKCS1v15(AsymmetricPadding):
1718
name = "EMSA-PKCS1-v1_5"
1819

20+
def __eq__(self, other: typing.Any) -> bool:
21+
return isinstance(other, PKCS1v15)
22+
1923

2024
class _MaxLength:
2125
"Sentinel value for `MAX_LENGTH`."
@@ -56,6 +60,18 @@ def __init__(
5660

5761
self._salt_length = salt_length
5862

63+
def __eq__(self, other: typing.Any) -> bool:
64+
if isinstance(self._salt_length, int):
65+
eq_salt_length = self._salt_length == other._salt_length
66+
else:
67+
eq_salt_length = self._salt_length is other._salt_length
68+
69+
return (
70+
isinstance(other, PSS)
71+
and eq_salt_length
72+
and self._mgf == other._mgf
73+
)
74+
5975
@property
6076
def mgf(self) -> MGF:
6177
return self._mgf
@@ -77,6 +93,14 @@ def __init__(
7793
self._algorithm = algorithm
7894
self._label = label
7995

96+
def __eq__(self, other: typing.Any) -> bool:
97+
return (
98+
isinstance(other, OAEP)
99+
and self._mgf == other._mgf
100+
and self._algorithm == other._algorithm
101+
and self._label == other._label
102+
)
103+
80104
@property
81105
def algorithm(self) -> hashes.HashAlgorithm:
82106
return self._algorithm
@@ -89,6 +113,13 @@ def mgf(self) -> MGF:
89113
class MGF(metaclass=abc.ABCMeta):
90114
_algorithm: hashes.HashAlgorithm
91115

116+
@abc.abstractmethod
117+
def __eq__(self, other: typing.Any) -> bool:
118+
"""
119+
Implement equality checking.
120+
"""
121+
...
122+
92123

93124
class MGF1(MGF):
94125
def __init__(self, algorithm: hashes.HashAlgorithm):
@@ -97,6 +128,9 @@ def __init__(self, algorithm: hashes.HashAlgorithm):
97128

98129
self._algorithm = algorithm
99130

131+
def __eq__(self, other: typing.Any) -> bool:
132+
return isinstance(other, MGF1) and self._algorithm == other._algorithm
133+
100134

101135
def calculate_max_pss_salt_length(
102136
key: rsa.RSAPrivateKey | rsa.RSAPublicKey,

tests/hazmat/backends/test_openssl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55

66
import itertools
7+
import typing
78

89
import pytest
910

@@ -32,6 +33,9 @@ class DummyMGF(padding.MGF):
3233
_salt_length = 0
3334
_algorithm = hashes.SHA1()
3435

36+
def __eq__(self, other: typing.Any) -> bool:
37+
return isinstance(other, DummyMGF)
38+
3539

3640
class TestOpenSSL:
3741
def test_backend_exists(self):

tests/hazmat/primitives/test_rsa.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import copy
88
import itertools
99
import os
10+
import typing
1011

1112
import pytest
1213

@@ -70,6 +71,9 @@ class DummyMGF(padding.MGF):
7071
_salt_length = 0
7172
_algorithm = hashes.SHA256()
7273

74+
def __eq__(self, other: typing.Any) -> bool:
75+
return isinstance(other, DummyMGF)
76+
7377

7478
def _check_fips_key_length(backend, private_key):
7579
if (
@@ -1603,6 +1607,14 @@ class TestRSAPKCS1Verification:
16031607
)
16041608

16051609

1610+
class TestPKCS1v15:
1611+
def test_eq(self):
1612+
assert padding.PKCS1v15() == padding.PKCS1v15()
1613+
assert padding.PKCS1v15() != padding.PSS(
1614+
mgf=padding.MGF1(hashes.SHA256()), salt_length=32
1615+
)
1616+
1617+
16061618
class TestPSS:
16071619
def test_calculate_max_pss_salt_length(self):
16081620
with pytest.raises(TypeError):
@@ -1644,8 +1656,68 @@ def test_mgf_property(self):
16441656
assert pss.mgf == mgf
16451657
assert pss.mgf == pss._mgf
16461658

1659+
@pytest.mark.parametrize("xof", [hashes.SHA256(), hashes.SHA512()])
1660+
@pytest.mark.parametrize(
1661+
"salt_length",
1662+
[
1663+
1,
1664+
32,
1665+
padding.PSS.MAX_LENGTH,
1666+
padding.PSS.AUTO,
1667+
padding.PSS.DIGEST_LENGTH,
1668+
],
1669+
)
1670+
def test_eq(
1671+
self, xof: hashes.HashAlgorithm, salt_length: typing.Any
1672+
) -> None:
1673+
assert padding.PSS(
1674+
salt_length=salt_length, mgf=padding.MGF1(algorithm=xof)
1675+
) == padding.PSS(
1676+
salt_length=salt_length, mgf=padding.MGF1(algorithm=xof)
1677+
)
1678+
1679+
@pytest.mark.parametrize(
1680+
"salt_length",
1681+
[
1682+
1,
1683+
32,
1684+
padding.PSS.MAX_LENGTH,
1685+
padding.PSS.AUTO,
1686+
padding.PSS.DIGEST_LENGTH,
1687+
],
1688+
)
1689+
def test_not_eq_with_different_salt_length(
1690+
self, salt_length: typing.Any
1691+
) -> None:
1692+
xof = hashes.SHA256()
1693+
assert padding.PSS(
1694+
salt_length=salt_length, mgf=padding.MGF1(algorithm=xof)
1695+
) != padding.PSS(salt_length=64, mgf=padding.MGF1(algorithm=xof))
1696+
1697+
def test_not_eq_with_salt_length_object_identity(self) -> None:
1698+
xof = hashes.SHA256()
1699+
assert padding.PSS(
1700+
salt_length=padding.PSS.AUTO, mgf=padding.MGF1(algorithm=xof)
1701+
) != padding.PSS(
1702+
salt_length=padding.PSS.DIGEST_LENGTH,
1703+
mgf=padding.MGF1(algorithm=xof),
1704+
)
1705+
1706+
def test_not_eq_with_different_mgf(self) -> None:
1707+
assert padding.PSS(
1708+
salt_length=padding.PSS.AUTO,
1709+
mgf=padding.MGF1(algorithm=hashes.SHA256()),
1710+
) != padding.PSS(
1711+
salt_length=padding.PSS.AUTO,
1712+
mgf=padding.MGF1(algorithm=hashes.SHA512()),
1713+
)
1714+
16471715

16481716
class TestMGF1:
1717+
def test_eq(self) -> None:
1718+
assert padding.MGF1(hashes.SHA256()) == padding.MGF1(hashes.SHA256())
1719+
assert padding.MGF1(hashes.SHA256()) != padding.MGF1(hashes.SHA512())
1720+
16491721
def test_invalid_hash_algorithm(self):
16501722
with pytest.raises(TypeError):
16511723
padding.MGF1(b"not_a_hash") # type:ignore[arg-type]
@@ -1680,6 +1752,47 @@ def test_mgf_property(self):
16801752
assert oaep.mgf == mgf
16811753
assert oaep.mgf == oaep._mgf
16821754

1755+
@pytest.mark.parametrize("xof", [hashes.SHA256(), hashes.SHA512()])
1756+
@pytest.mark.parametrize("label", [None, b"", b"foo"])
1757+
def test_eq(self, xof: hashes.HashAlgorithm, label: bytes | None) -> None:
1758+
mgf = padding.MGF1(algorithm=xof)
1759+
assert padding.OAEP(
1760+
mgf=mgf, algorithm=xof, label=label
1761+
) == padding.OAEP(mgf=mgf, algorithm=xof, label=label)
1762+
1763+
def test_not_eq_with_different_mgf(self) -> None:
1764+
assert padding.OAEP(
1765+
mgf=padding.MGF1(algorithm=hashes.SHA256()),
1766+
algorithm=hashes.SHA256(),
1767+
label=None,
1768+
) != padding.OAEP(
1769+
mgf=padding.MGF1(algorithm=hashes.SHA512()),
1770+
algorithm=hashes.SHA256(),
1771+
label=None,
1772+
)
1773+
1774+
def test_not_eq_with_different_algorithm(self) -> None:
1775+
assert padding.OAEP(
1776+
mgf=padding.MGF1(algorithm=hashes.SHA512()),
1777+
algorithm=hashes.SHA512(),
1778+
label=None,
1779+
) != padding.OAEP(
1780+
mgf=padding.MGF1(algorithm=hashes.SHA512()),
1781+
algorithm=hashes.SHA256(),
1782+
label=None,
1783+
)
1784+
1785+
def test_not_eq_with_different_label(self) -> None:
1786+
assert padding.OAEP(
1787+
mgf=padding.MGF1(algorithm=hashes.SHA512()),
1788+
algorithm=hashes.SHA256(),
1789+
label=None,
1790+
) != padding.OAEP(
1791+
mgf=padding.MGF1(algorithm=hashes.SHA512()),
1792+
algorithm=hashes.SHA256(),
1793+
label=b"",
1794+
)
1795+
16831796

16841797
class TestRSADecryption:
16851798
@pytest.mark.supported(

0 commit comments

Comments
 (0)