Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/cryptography/hazmat/asn1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

from cryptography.hazmat.asn1.asn1 import decode_der, encode_der, sequence
from cryptography.hazmat.asn1.asn1 import (
PrintableString,
decode_der,
encode_der,
sequence,
)

__all__ = [
"PrintableString",
"decode_der",
"encode_der",
"sequence",
Expand Down
3 changes: 3 additions & 0 deletions src/cryptography/hazmat/asn1/asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,6 @@ def sequence(cls: type[U]) -> type[U]:
)(cls)
_register_asn1_sequence(dataclass_cls)
return dataclass_cls


PrintableString = declarative_asn1.PrintableString
6 changes: 6 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,9 @@ class AnnotatedTypeObject:
def __new__(
cls, annotated_type: AnnotatedType, value: typing.Any
) -> AnnotatedTypeObject: ...
def __repr__(self) -> str: ...
def __eq__(self, other: object) -> bool: ...

class PrintableString:
def __new__(cls, inner: str) -> PrintableString: ...
def as_str(self) -> str: ...
12 changes: 11 additions & 1 deletion src/rust/src/declarative_asn1/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use asn1::Parser;
use pyo3::types::PyAnyMethods;

use crate::asn1::big_byte_slice_to_py_int;
use crate::declarative_asn1::types::{AnnotatedType, Type};
use crate::declarative_asn1::types::{AnnotatedType, PrintableString, Type};
use crate::error::CryptographyError;

type ParseResult<T> = Result<T, CryptographyError>;
Expand Down Expand Up @@ -48,6 +48,15 @@ fn decode_pystr<'a>(
Ok(pyo3::types::PyString::new(py, value.as_str()))
}

fn decode_printable_string<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
) -> ParseResult<pyo3::Bound<'a, PrintableString>> {
let value = parser.read_element::<asn1::PrintableString<'a>>()?.as_str();
let inner = pyo3::types::PyString::new(py, value).unbind();
Ok(pyo3::Bound::new(py, PrintableString { inner })?)
}

pub(crate) fn decode_annotated_type<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
Expand Down Expand Up @@ -78,5 +87,6 @@ pub(crate) fn decode_annotated_type<'a>(
Type::PyInt() => Ok(decode_pyint(py, parser)?.into_any()),
Type::PyBytes() => Ok(decode_pybytes(py, parser)?.into_any()),
Type::PyStr() => Ok(decode_pystr(py, parser)?.into_any()),
Type::PrintableString() => Ok(decode_printable_string(py, parser)?.into_any()),
}
}
16 changes: 15 additions & 1 deletion src/rust/src/declarative_asn1/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use asn1::{SimpleAsn1Writable, Writer};
use pyo3::types::PyAnyMethods;

use crate::declarative_asn1::types::{AnnotatedType, AnnotatedTypeObject, Type};
use crate::declarative_asn1::types::{AnnotatedType, AnnotatedTypeObject, PrintableString, Type};

fn write_value<T: SimpleAsn1Writable>(
writer: &mut Writer<'_>,
Expand Down Expand Up @@ -73,6 +73,20 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
let asn1_string: asn1::Utf8String<'_> = asn1::Utf8String::new(&val);
write_value(writer, &asn1_string)
}
Type::PrintableString() => {
let val: &pyo3::Bound<'_, PrintableString> = value
.downcast()
.map_err(|_| asn1::WriteError::AllocationError)?;
let inner_str = val
.get()
.inner
.to_cow(py)
.map_err(|_| asn1::WriteError::AllocationError)?;
let printable_string: asn1::PrintableString<'_> =
asn1::PrintableString::new(&inner_str)
.ok_or(asn1::WriteError::AllocationError)?;
write_value(writer, &printable_string)
}
}
}
}
42 changes: 41 additions & 1 deletion src/rust/src/declarative_asn1/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// 2.0, and the BSD License. See the LICENSE file in the root of this repository
// for complete details.

use asn1::PrintableString as Asn1PrintableString;
use pyo3::types::PyAnyMethods;
use pyo3::{IntoPyObject, PyTypeInfo};

Expand Down Expand Up @@ -32,6 +33,9 @@ pub enum Type {
/// `str` -> `UTF8String`
#[pyo3(constructor = ())]
PyStr(),
/// PrintableString (`str`)
#[pyo3(constructor = ())]
PrintableString(),
}

/// A type that we know how to encode/decode, along with any
Expand Down Expand Up @@ -70,6 +74,40 @@ impl Annotation {
}
}

#[derive(pyo3::FromPyObject)]
#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.asn1")]
pub struct PrintableString {
pub(crate) inner: pyo3::Py<pyo3::types::PyString>,
}

#[pyo3::pymethods]
impl PrintableString {
#[new]
#[pyo3(signature = (inner,))]
fn new(py: pyo3::Python<'_>, inner: pyo3::Py<pyo3::types::PyString>) -> pyo3::PyResult<Self> {
if Asn1PrintableString::new(&inner.to_cow(py)?).is_none() {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"invalid PrintableString: {inner}"
)));
}

Ok(PrintableString { inner })
}

#[pyo3(signature = ())]
pub fn as_str(&self, py: pyo3::Python<'_>) -> pyo3::PyResult<pyo3::Py<pyo3::types::PyString>> {
Ok(self.inner.clone_ref(py))
}

fn __eq__(&self, py: pyo3::Python<'_>, other: pyo3::PyRef<'_, Self>) -> pyo3::PyResult<bool> {
(**self.inner.bind(py)).eq(other.inner.bind(py))
}

pub fn __repr__(&self, py: pyo3::Python<'_>) -> pyo3::PyResult<String> {
Ok(format!("PrintableString({})", self.inner.bind(py).repr()?))
}
}

/// Utility function for converting builtin Python types
/// to their Rust `Type` equivalent.
#[pyo3::pyfunction]
Expand All @@ -85,6 +123,8 @@ pub fn non_root_python_to_rust<'p>(
Type::PyStr().into_pyobject(py)
} else if class.is(pyo3::types::PyBytes::type_object(py)) {
Type::PyBytes().into_pyobject(py)
} else if class.is(PrintableString::type_object(py)) {
Type::PrintableString().into_pyobject(py)
} else {
Err(pyo3::exceptions::PyTypeError::new_err(format!(
"cannot handle type: {class:?}"
Expand Down Expand Up @@ -131,5 +171,5 @@ pub(crate) fn python_class_to_annotated<'p>(
#[pyo3::pymodule(gil_used = false)]
pub(crate) mod types {
#[pymodule_export]
use super::{AnnotatedType, Annotation, Type};
use super::{AnnotatedType, Annotation, PrintableString, Type};
}
2 changes: 1 addition & 1 deletion src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ mod _rust {

#[pymodule_export]
use crate::declarative_asn1::types::{
non_root_python_to_rust, AnnotatedType, Annotation, Type,
non_root_python_to_rust, AnnotatedType, Annotation, PrintableString, Type,
};
}

Expand Down
15 changes: 14 additions & 1 deletion tests/hazmat/asn1/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,20 @@
import cryptography.hazmat.asn1 as asn1


class TestClassAPI:
class TestTypesAPI:
def test_repr_printable_string(self) -> None:
my_string = "MyString"
assert (
repr(asn1.PrintableString(my_string))
== f"PrintableString({my_string!r})"
)

def test_invalid_printable_string(self) -> None:
with pytest.raises(ValueError, match="invalid PrintableString: café"):
asn1.PrintableString("café")


class TestSequenceAPI:
def test_fail_unsupported_field(self) -> None:
# Not a sequence
class Unsupported:
Expand Down
15 changes: 14 additions & 1 deletion tests/hazmat/asn1/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ def _comparable_dataclass(cls: typing.Type[U]) -> typing.Type[U]:
)(cls)


# Checks that the encoding-decoding roundtrip results
# in the expected values and is consistent.
def assert_roundtrips(
test_cases: typing.List[typing.Tuple[typing.Any, bytes]],
test_cases: typing.List[typing.Tuple[U, bytes]],
) -> None:
for obj, obj_bytes in test_cases:
encoded = asn1.encode_der(obj)
Expand Down Expand Up @@ -105,6 +107,17 @@ def test_string(self) -> None:
)


class TestPrintableString:
def test_ok_printable_string(self) -> None:
assert_roundtrips(
[
(asn1.PrintableString(""), b"\x13\x00"),
(asn1.PrintableString("hello"), b"\x13\x05hello"),
(asn1.PrintableString("Test User 1"), b"\x13\x0bTest User 1"),
]
)


class TestSequence:
def test_ok_sequence_single_field(self) -> None:
@asn1.sequence
Expand Down
Loading