|
6 | 6 | import sys |
7 | 7 | import typing |
8 | 8 |
|
| 9 | +import pytest |
| 10 | + |
9 | 11 | import cryptography.hazmat.asn1 as asn1 |
10 | 12 |
|
11 | 13 | U = typing.TypeVar("U") |
@@ -34,16 +36,26 @@ def _comparable_dataclass(cls: typing.Type[U]) -> typing.Type[U]: |
34 | 36 | )(cls) |
35 | 37 |
|
36 | 38 |
|
| 39 | +# Checks that the encoding-decoding roundtrip results |
| 40 | +# in the expected values and is consistent. |
| 41 | +# |
| 42 | +# The `decoded_eq` argument is the equality function to use |
| 43 | +# for the decoded values. It's useful for types that aren't |
| 44 | +# directly comparable, like `PrintableString`. |
37 | 45 | def assert_roundtrips( |
38 | | - test_cases: typing.List[typing.Tuple[typing.Any, bytes]], |
| 46 | + test_cases: typing.List[typing.Tuple[U, bytes]], |
| 47 | + decoded_eq: typing.Optional[typing.Callable[[U, U], bool]] = None, |
39 | 48 | ) -> None: |
40 | 49 | for obj, obj_bytes in test_cases: |
41 | 50 | encoded = asn1.encode_der(obj) |
42 | 51 | assert encoded == obj_bytes |
43 | 52 |
|
44 | 53 | decoded = asn1.decode_der(type(obj), encoded) |
45 | 54 | assert isinstance(decoded, type(obj)) |
46 | | - assert decoded == obj |
| 55 | + if decoded_eq: |
| 56 | + assert decoded_eq(decoded, obj) |
| 57 | + else: |
| 58 | + assert decoded == obj |
47 | 59 |
|
48 | 60 |
|
49 | 61 | class TestBool: |
@@ -105,6 +117,28 @@ def test_string(self) -> None: |
105 | 117 | ) |
106 | 118 |
|
107 | 119 |
|
| 120 | +class TestPrintableString: |
| 121 | + def test_ok_printable_string(self) -> None: |
| 122 | + def decoded_eq(a: asn1.PrintableString, b: asn1.PrintableString): |
| 123 | + return a.as_str() == b.as_str() |
| 124 | + |
| 125 | + assert_roundtrips( |
| 126 | + [ |
| 127 | + (asn1.PrintableString(""), b"\x13\x00"), |
| 128 | + (asn1.PrintableString("hello"), b"\x13\x05hello"), |
| 129 | + (asn1.PrintableString("Test User 1"), b"\x13\x0bTest User 1"), |
| 130 | + ], |
| 131 | + decoded_eq, |
| 132 | + ) |
| 133 | + |
| 134 | + def test_invalid_printable_string(self) -> None: |
| 135 | + with pytest.raises(ValueError, match="allocation error"): |
| 136 | + asn1.encode_der(asn1.PrintableString("café")) |
| 137 | + |
| 138 | + with pytest.raises(ValueError, match="error parsing asn1 value"): |
| 139 | + asn1.decode_der(asn1.PrintableString, b"\x0c\x05caf\xc3\xa9") |
| 140 | + |
| 141 | + |
108 | 142 | class TestSequence: |
109 | 143 | def test_ok_sequence_single_field(self) -> None: |
110 | 144 | @asn1.sequence |
|
0 commit comments