diff --git a/src/cryptography/hazmat/asn1/__init__.py b/src/cryptography/hazmat/asn1/__init__.py index 5b4bc48a4ee2..1126d14c10a9 100644 --- a/src/cryptography/hazmat/asn1/__init__.py +++ b/src/cryptography/hazmat/asn1/__init__.py @@ -2,9 +2,15 @@ # 2.0, and the BSD License. See the LICENSE file in the root of this repository # for complete details. -from cryptography.hazmat.asn1.asn1 import decode_der, encode_der, sequence +from cryptography.hazmat.asn1.asn1 import ( + PrintableString, + decode_der, + encode_der, + sequence, +) __all__ = [ + "PrintableString", "decode_der", "encode_der", "sequence", diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index b8fb8509995b..a5cd7a6074ff 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -115,3 +115,6 @@ def sequence(cls: type[U]) -> type[U]: )(cls) _register_asn1_sequence(dataclass_cls) return dataclass_cls + + +PrintableString = declarative_asn1.PrintableString diff --git a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi index 5bcc9bd49f1f..e4b2a99f5864 100644 --- a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi +++ b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi @@ -34,3 +34,9 @@ class AnnotatedTypeObject: def __new__( cls, annotated_type: AnnotatedType, value: typing.Any ) -> AnnotatedTypeObject: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + +class PrintableString: + def __new__(cls, inner: str) -> PrintableString: ... + def as_str(self) -> str: ... diff --git a/src/rust/src/declarative_asn1/decode.rs b/src/rust/src/declarative_asn1/decode.rs index d851dd43c67c..e13dfb9a8f5c 100644 --- a/src/rust/src/declarative_asn1/decode.rs +++ b/src/rust/src/declarative_asn1/decode.rs @@ -6,7 +6,7 @@ use asn1::Parser; use pyo3::types::PyAnyMethods; use crate::asn1::big_byte_slice_to_py_int; -use crate::declarative_asn1::types::{AnnotatedType, Type}; +use crate::declarative_asn1::types::{AnnotatedType, PrintableString, Type}; use crate::error::CryptographyError; type ParseResult = Result; @@ -48,6 +48,15 @@ fn decode_pystr<'a>( Ok(pyo3::types::PyString::new(py, value.as_str())) } +fn decode_printable_string<'a>( + py: pyo3::Python<'a>, + parser: &mut Parser<'a>, +) -> ParseResult> { + let value = parser.read_element::>()?.as_str(); + let inner = pyo3::types::PyString::new(py, value).unbind(); + Ok(pyo3::Bound::new(py, PrintableString { inner })?) +} + pub(crate) fn decode_annotated_type<'a>( py: pyo3::Python<'a>, parser: &mut Parser<'a>, @@ -78,5 +87,6 @@ pub(crate) fn decode_annotated_type<'a>( Type::PyInt() => Ok(decode_pyint(py, parser)?.into_any()), Type::PyBytes() => Ok(decode_pybytes(py, parser)?.into_any()), Type::PyStr() => Ok(decode_pystr(py, parser)?.into_any()), + Type::PrintableString() => Ok(decode_printable_string(py, parser)?.into_any()), } } diff --git a/src/rust/src/declarative_asn1/encode.rs b/src/rust/src/declarative_asn1/encode.rs index d9b5b19f4905..3e3afba57fb8 100644 --- a/src/rust/src/declarative_asn1/encode.rs +++ b/src/rust/src/declarative_asn1/encode.rs @@ -5,7 +5,7 @@ use asn1::{SimpleAsn1Writable, Writer}; use pyo3::types::PyAnyMethods; -use crate::declarative_asn1::types::{AnnotatedType, AnnotatedTypeObject, Type}; +use crate::declarative_asn1::types::{AnnotatedType, AnnotatedTypeObject, PrintableString, Type}; fn write_value( writer: &mut Writer<'_>, @@ -73,6 +73,20 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> { let asn1_string: asn1::Utf8String<'_> = asn1::Utf8String::new(&val); write_value(writer, &asn1_string) } + Type::PrintableString() => { + let val: &pyo3::Bound<'_, PrintableString> = value + .downcast() + .map_err(|_| asn1::WriteError::AllocationError)?; + let inner_str = val + .get() + .inner + .to_cow(py) + .map_err(|_| asn1::WriteError::AllocationError)?; + let printable_string: asn1::PrintableString<'_> = + asn1::PrintableString::new(&inner_str) + .ok_or(asn1::WriteError::AllocationError)?; + write_value(writer, &printable_string) + } } } } diff --git a/src/rust/src/declarative_asn1/types.rs b/src/rust/src/declarative_asn1/types.rs index e3fc9d3d5fca..9be3e2d8a576 100644 --- a/src/rust/src/declarative_asn1/types.rs +++ b/src/rust/src/declarative_asn1/types.rs @@ -2,6 +2,7 @@ // 2.0, and the BSD License. See the LICENSE file in the root of this repository // for complete details. +use asn1::PrintableString as Asn1PrintableString; use pyo3::types::PyAnyMethods; use pyo3::{IntoPyObject, PyTypeInfo}; @@ -32,6 +33,9 @@ pub enum Type { /// `str` -> `UTF8String` #[pyo3(constructor = ())] PyStr(), + /// PrintableString (`str`) + #[pyo3(constructor = ())] + PrintableString(), } /// A type that we know how to encode/decode, along with any @@ -70,6 +74,40 @@ impl Annotation { } } +#[derive(pyo3::FromPyObject)] +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.asn1")] +pub struct PrintableString { + pub(crate) inner: pyo3::Py, +} + +#[pyo3::pymethods] +impl PrintableString { + #[new] + #[pyo3(signature = (inner,))] + fn new(py: pyo3::Python<'_>, inner: pyo3::Py) -> pyo3::PyResult { + if Asn1PrintableString::new(&inner.to_cow(py)?).is_none() { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "invalid PrintableString: {inner}" + ))); + } + + Ok(PrintableString { inner }) + } + + #[pyo3(signature = ())] + pub fn as_str(&self, py: pyo3::Python<'_>) -> pyo3::PyResult> { + Ok(self.inner.clone_ref(py)) + } + + fn __eq__(&self, py: pyo3::Python<'_>, other: pyo3::PyRef<'_, Self>) -> pyo3::PyResult { + (**self.inner.bind(py)).eq(other.inner.bind(py)) + } + + pub fn __repr__(&self, py: pyo3::Python<'_>) -> pyo3::PyResult { + Ok(format!("PrintableString({})", self.inner.bind(py).repr()?)) + } +} + /// Utility function for converting builtin Python types /// to their Rust `Type` equivalent. #[pyo3::pyfunction] @@ -85,6 +123,8 @@ pub fn non_root_python_to_rust<'p>( Type::PyStr().into_pyobject(py) } else if class.is(pyo3::types::PyBytes::type_object(py)) { Type::PyBytes().into_pyobject(py) + } else if class.is(PrintableString::type_object(py)) { + Type::PrintableString().into_pyobject(py) } else { Err(pyo3::exceptions::PyTypeError::new_err(format!( "cannot handle type: {class:?}" @@ -131,5 +171,5 @@ pub(crate) fn python_class_to_annotated<'p>( #[pyo3::pymodule(gil_used = false)] pub(crate) mod types { #[pymodule_export] - use super::{AnnotatedType, Annotation, Type}; + use super::{AnnotatedType, Annotation, PrintableString, Type}; } diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 9bd62a740f6e..6286046bf419 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -124,7 +124,7 @@ mod _rust { #[pymodule_export] use crate::declarative_asn1::types::{ - non_root_python_to_rust, AnnotatedType, Annotation, Type, + non_root_python_to_rust, AnnotatedType, Annotation, PrintableString, Type, }; } diff --git a/tests/hazmat/asn1/test_api.py b/tests/hazmat/asn1/test_api.py index 23aaaddfca58..c9e1ed50e9ac 100644 --- a/tests/hazmat/asn1/test_api.py +++ b/tests/hazmat/asn1/test_api.py @@ -9,7 +9,20 @@ import cryptography.hazmat.asn1 as asn1 -class TestClassAPI: +class TestTypesAPI: + def test_repr_printable_string(self) -> None: + my_string = "MyString" + assert ( + repr(asn1.PrintableString(my_string)) + == f"PrintableString({my_string!r})" + ) + + def test_invalid_printable_string(self) -> None: + with pytest.raises(ValueError, match="invalid PrintableString: café"): + asn1.PrintableString("café") + + +class TestSequenceAPI: def test_fail_unsupported_field(self) -> None: # Not a sequence class Unsupported: diff --git a/tests/hazmat/asn1/test_serialization.py b/tests/hazmat/asn1/test_serialization.py index 80fed3706d2c..b88cc9b5d5bc 100644 --- a/tests/hazmat/asn1/test_serialization.py +++ b/tests/hazmat/asn1/test_serialization.py @@ -34,8 +34,10 @@ def _comparable_dataclass(cls: typing.Type[U]) -> typing.Type[U]: )(cls) +# Checks that the encoding-decoding roundtrip results +# in the expected values and is consistent. def assert_roundtrips( - test_cases: typing.List[typing.Tuple[typing.Any, bytes]], + test_cases: typing.List[typing.Tuple[U, bytes]], ) -> None: for obj, obj_bytes in test_cases: encoded = asn1.encode_der(obj) @@ -105,6 +107,17 @@ def test_string(self) -> None: ) +class TestPrintableString: + def test_ok_printable_string(self) -> None: + assert_roundtrips( + [ + (asn1.PrintableString(""), b"\x13\x00"), + (asn1.PrintableString("hello"), b"\x13\x05hello"), + (asn1.PrintableString("Test User 1"), b"\x13\x0bTest User 1"), + ] + ) + + class TestSequence: def test_ok_sequence_single_field(self) -> None: @asn1.sequence