Skip to content

Commit 5848c3c

Browse files
authored
asn1: Add support for bytes, str and bool (#13482)
* asn1: Add support for bytes, str and bool Signed-off-by: Facundo Tuesca <[email protected]> * Use `PyBackedStr` Signed-off-by: Facundo Tuesca <[email protected]> --------- Signed-off-by: Facundo Tuesca <[email protected]>
1 parent 032888f commit 5848c3c

File tree

5 files changed

+125
-2
lines changed

5 files changed

+125
-2
lines changed

src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ def non_root_python_to_rust(cls: type) -> Type: ...
1111
# annotations like this:
1212
class Type:
1313
Sequence: typing.ClassVar[type]
14+
PyBool: typing.ClassVar[type]
1415
PyInt: typing.ClassVar[type]
16+
PyBytes: typing.ClassVar[type]
17+
PyStr: typing.ClassVar[type]
1518

1619
class Annotation:
1720
def __new__(

src/rust/src/declarative_asn1/decode.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ use crate::error::CryptographyError;
1111

1212
type ParseResult<T> = Result<T, CryptographyError>;
1313

14+
fn decode_pybool<'a>(
15+
py: pyo3::Python<'a>,
16+
parser: &mut Parser<'a>,
17+
) -> ParseResult<pyo3::Bound<'a, pyo3::types::PyBool>> {
18+
let value = parser.read_element::<bool>()?;
19+
Ok(pyo3::types::PyBool::new(py, value).to_owned())
20+
}
21+
1422
fn decode_pyint<'a>(
1523
py: pyo3::Python<'a>,
1624
parser: &mut Parser<'a>,
@@ -24,6 +32,22 @@ fn decode_pyint<'a>(
2432
Ok(pyint)
2533
}
2634

35+
fn decode_pybytes<'a>(
36+
py: pyo3::Python<'a>,
37+
parser: &mut Parser<'a>,
38+
) -> ParseResult<pyo3::Bound<'a, pyo3::types::PyBytes>> {
39+
let value = parser.read_element::<&[u8]>()?;
40+
Ok(pyo3::types::PyBytes::new(py, value))
41+
}
42+
43+
fn decode_pystr<'a>(
44+
py: pyo3::Python<'a>,
45+
parser: &mut Parser<'a>,
46+
) -> ParseResult<pyo3::Bound<'a, pyo3::types::PyString>> {
47+
let value = parser.read_element::<asn1::Utf8String<'a>>()?;
48+
Ok(pyo3::types::PyString::new(py, value.as_str()))
49+
}
50+
2751
pub(crate) fn decode_annotated_type<'a>(
2852
py: pyo3::Python<'a>,
2953
parser: &mut Parser<'a>,
@@ -50,6 +74,9 @@ pub(crate) fn decode_annotated_type<'a>(
5074
Ok(val)
5175
})
5276
}
77+
Type::PyBool() => Ok(decode_pybool(py, parser)?.into_any()),
5378
Type::PyInt() => Ok(decode_pyint(py, parser)?.into_any()),
79+
Type::PyBytes() => Ok(decode_pybytes(py, parser)?.into_any()),
80+
Type::PyStr() => Ok(decode_pystr(py, parser)?.into_any()),
5481
}
5582
}

src/rust/src/declarative_asn1/encode.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,31 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
4848
Ok(())
4949
}),
5050
),
51+
Type::PyBool() => {
52+
let val: bool = value
53+
.extract()
54+
.map_err(|_| asn1::WriteError::AllocationError)?;
55+
write_value(writer, &val)
56+
}
5157
Type::PyInt() => {
5258
let val: i64 = value
5359
.extract()
5460
.map_err(|_| asn1::WriteError::AllocationError)?;
5561
write_value(writer, &val)
5662
}
63+
Type::PyBytes() => {
64+
let val: &[u8] = value
65+
.extract()
66+
.map_err(|_| asn1::WriteError::AllocationError)?;
67+
write_value(writer, &val)
68+
}
69+
Type::PyStr() => {
70+
let val: pyo3::pybacked::PyBackedStr = value
71+
.extract()
72+
.map_err(|_| asn1::WriteError::AllocationError)?;
73+
let asn1_string: asn1::Utf8String<'_> = asn1::Utf8String::new(&val);
74+
write_value(writer, &asn1_string)
75+
}
5776
}
5877
}
5978
}

src/rust/src/declarative_asn1/types.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,18 @@ pub enum Type {
2020

2121
// Python types that we map to canonical ASN.1 types
2222
//
23+
/// `bool` -> `Boolean`
24+
#[pyo3(constructor = ())]
25+
PyBool(),
2326
/// `int` -> `Integer`
2427
#[pyo3(constructor = ())]
2528
PyInt(),
29+
/// `bytes` -> `Octet String`
30+
#[pyo3(constructor = ())]
31+
PyBytes(),
32+
/// `str` -> `UTF8String`
33+
#[pyo3(constructor = ())]
34+
PyStr(),
2635
}
2736

2837
/// A type that we know how to encode/decode, along with any
@@ -70,6 +79,12 @@ pub fn non_root_python_to_rust<'p>(
7079
) -> pyo3::PyResult<pyo3::Bound<'p, Type>> {
7180
if class.is(pyo3::types::PyInt::type_object(py)) {
7281
Type::PyInt().into_pyobject(py)
82+
} else if class.is(pyo3::types::PyBool::type_object(py)) {
83+
Type::PyBool().into_pyobject(py)
84+
} else if class.is(pyo3::types::PyString::type_object(py)) {
85+
Type::PyStr().into_pyobject(py)
86+
} else if class.is(pyo3::types::PyBytes::type_object(py)) {
87+
Type::PyBytes().into_pyobject(py)
7388
} else {
7489
Err(pyo3::exceptions::PyTypeError::new_err(format!(
7590
"cannot handle type: {class:?}"

tests/hazmat/asn1/test_serialization.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ def assert_roundtrips(
4646
assert decoded == obj
4747

4848

49+
class TestBool:
50+
def test_bool(self) -> None:
51+
assert_roundtrips(
52+
[
53+
(True, b"\x01\x01\xff"),
54+
(False, b"\x01\x01\x00"),
55+
],
56+
)
57+
58+
4959
class TestInteger:
5060
def test_int(self) -> None:
5161
assert_roundtrips(
@@ -64,6 +74,37 @@ def test_int(self) -> None:
6474
)
6575

6676

77+
class TestBytes:
78+
def test_bytes(self) -> None:
79+
assert_roundtrips(
80+
[
81+
(b"", b"\x04\x00"),
82+
(b"hello", b"\x04\x05hello"),
83+
(b"\x01\x02\x03", b"\x04\x03\x01\x02\x03"),
84+
(
85+
b"\x00\xff\x80\x7f",
86+
b"\x04\x04\x00\xff\x80\x7f",
87+
),
88+
]
89+
)
90+
91+
92+
class TestString:
93+
def test_string(self) -> None:
94+
assert_roundtrips(
95+
[
96+
("", b"\x0c\x00"),
97+
("hello", b"\x0c\x05hello"),
98+
("Test User 1", b"\x0c\x0bTest User 1"),
99+
(
100+
"café",
101+
b"\x0c\x05caf\xc3\xa9",
102+
), # UTF-8 string with non-ASCII
103+
("🚀", b"\x0c\x04\xf0\x9f\x9a\x80"), # UTF-8 emoji
104+
]
105+
)
106+
107+
67108
class TestSequence:
68109
def test_ok_sequence_single_field(self) -> None:
69110
@asn1.sequence
@@ -73,7 +114,7 @@ class Example:
73114

74115
assert_roundtrips([(Example(foo=9), b"\x30\x03\x02\x01\x09")])
75116

76-
def test_encode_ok_sequence_multiple_fields(self) -> None:
117+
def test_ok_sequence_multiple_fields(self) -> None:
77118
@asn1.sequence
78119
@_comparable_dataclass
79120
class Example:
@@ -84,7 +125,7 @@ class Example:
84125
[(Example(foo=9, bar=6), b"\x30\x06\x02\x01\x09\x02\x01\x06")]
85126
)
86127

87-
def test_encode_ok_nested_sequence(self) -> None:
128+
def test_ok_nested_sequence(self) -> None:
88129
@asn1.sequence
89130
@_comparable_dataclass
90131
class Child:
@@ -98,3 +139,21 @@ class Parent:
98139
assert_roundtrips(
99140
[(Parent(foo=Child(foo=9)), b"\x30\x05\x30\x03\x02\x01\x09")]
100141
)
142+
143+
def test_ok_sequence_multiple_types(self) -> None:
144+
@asn1.sequence
145+
@_comparable_dataclass
146+
class Example:
147+
a: bool
148+
b: int
149+
c: bytes
150+
d: str
151+
152+
assert_roundtrips(
153+
[
154+
(
155+
Example(a=True, b=9, c=b"c", d="d"),
156+
b"\x30\x0c\x01\x01\xff\x02\x01\x09\x04\x01c\x0c\x01d",
157+
)
158+
]
159+
)

0 commit comments

Comments
 (0)