diff --git a/src/cryptography/hazmat/asn1/__init__.py b/src/cryptography/hazmat/asn1/__init__.py index 2a68fc114f40..6496fbd730cb 100644 --- a/src/cryptography/hazmat/asn1/__init__.py +++ b/src/cryptography/hazmat/asn1/__init__.py @@ -3,6 +3,7 @@ # for complete details. from cryptography.hazmat.asn1.asn1 import ( + BitString, Default, Explicit, GeneralizedTime, @@ -16,6 +17,7 @@ ) __all__ = [ + "BitString", "Default", "Explicit", "GeneralizedTime", diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index b1ce1136dd6c..61ea615e92bc 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -240,3 +240,4 @@ class Default(typing.Generic[U]): PrintableString = declarative_asn1.PrintableString UtcTime = declarative_asn1.UtcTime GeneralizedTime = declarative_asn1.GeneralizedTime +BitString = declarative_asn1.BitString diff --git a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi index 60e2040c5f30..08c64a95bbae 100644 --- a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi +++ b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi @@ -77,3 +77,10 @@ class GeneralizedTime: def __repr__(self) -> str: ... def __eq__(self, other: object) -> bool: ... def as_datetime(self) -> datetime.datetime: ... + +class BitString: + def __new__(cls, data: bytes, padding_bits: int) -> BitString: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def as_bytes(self) -> bytes: ... + def padding_bits(self) -> int: ... diff --git a/src/rust/src/declarative_asn1/decode.rs b/src/rust/src/declarative_asn1/decode.rs index 1878417fd824..dbe07dec7144 100644 --- a/src/rust/src/declarative_asn1/decode.rs +++ b/src/rust/src/declarative_asn1/decode.rs @@ -8,7 +8,8 @@ use pyo3::types::PyListMethods; use crate::asn1::big_byte_slice_to_py_int; use crate::declarative_asn1::types::{ - type_to_tag, AnnotatedType, Encoding, GeneralizedTime, PrintableString, Type, UtcTime, + type_to_tag, AnnotatedType, BitString, Encoding, GeneralizedTime, PrintableString, Type, + UtcTime, }; use crate::error::CryptographyError; @@ -118,6 +119,22 @@ fn decode_generalized_time<'a>( Ok(pyo3::Bound::new(py, GeneralizedTime { inner })?) } +fn decode_bitstring<'a>( + py: pyo3::Python<'a>, + parser: &mut Parser<'a>, + encoding: &Option>, +) -> ParseResult> { + let value = read_value::>(parser, encoding)?; + let data = pyo3::types::PyBytes::new(py, value.as_bytes()).unbind(); + Ok(pyo3::Bound::new( + py, + BitString { + data, + padding_bits: value.padding_bits(), + }, + )?) +} + pub(crate) fn decode_annotated_type<'a>( py: pyo3::Python<'a>, parser: &mut Parser<'a>, @@ -201,6 +218,7 @@ pub(crate) fn decode_annotated_type<'a>( Type::PrintableString() => decode_printable_string(py, parser, encoding)?.into_any(), Type::UtcTime() => decode_utc_time(py, parser, encoding)?.into_any(), Type::GeneralizedTime() => decode_generalized_time(py, parser, encoding)?.into_any(), + Type::BitString() => decode_bitstring(py, parser, encoding)?.into_any(), }; match &ann_type.annotation.get().default { diff --git a/src/rust/src/declarative_asn1/encode.rs b/src/rust/src/declarative_asn1/encode.rs index b45364113c0c..9eb6b9932245 100644 --- a/src/rust/src/declarative_asn1/encode.rs +++ b/src/rust/src/declarative_asn1/encode.rs @@ -7,7 +7,8 @@ use pyo3::types::PyAnyMethods; use pyo3::types::PyListMethods; use crate::declarative_asn1::types::{ - AnnotatedType, AnnotatedTypeObject, Encoding, GeneralizedTime, PrintableString, Type, UtcTime, + AnnotatedType, AnnotatedTypeObject, BitString, Encoding, GeneralizedTime, PrintableString, + Type, UtcTime, }; fn write_value( @@ -171,6 +172,16 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> { .map_err(|_| asn1::WriteError::AllocationError)?; write_value(writer, &generalized_time, encoding) } + Type::BitString() => { + let val: &pyo3::Bound<'_, BitString> = value + .cast() + .map_err(|_| asn1::WriteError::AllocationError)?; + + let bitstring: asn1::BitString<'_> = + asn1::BitString::new(val.get().data.as_bytes(py), val.get().padding_bits) + .ok_or(asn1::WriteError::AllocationError)?; + write_value(writer, &bitstring, encoding) + } } } } diff --git a/src/rust/src/declarative_asn1/types.rs b/src/rust/src/declarative_asn1/types.rs index 3c07483309d5..d97b7ead8840 100644 --- a/src/rust/src/declarative_asn1/types.rs +++ b/src/rust/src/declarative_asn1/types.rs @@ -19,7 +19,7 @@ pub enum Type { /// The first element is the Python class that represents the sequence, /// the second element is a dict of the (already converted) fields of the class. Sequence(pyo3::Py, pyo3::Py), - /// SEQUENCEOF (`list[`T`]`) + /// SEQUENCE OF (`list[`T`]`) SequenceOf(pyo3::Py), /// OPTIONAL (`T | None`) Option(pyo3::Py), @@ -40,6 +40,8 @@ pub enum Type { UtcTime(), /// GeneralizedTime (`datetime`) GeneralizedTime(), + /// BIT STRING (`bytes`) + BitString(), } /// A type that we know how to encode/decode, along with any @@ -249,6 +251,54 @@ impl GeneralizedTime { } } +#[derive(pyo3::FromPyObject)] +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.asn1")] +pub struct BitString { + pub(crate) data: pyo3::Py, + pub(crate) padding_bits: u8, +} + +#[pyo3::pymethods] +impl BitString { + #[new] + #[pyo3(signature = (data, padding_bits,))] + fn new( + py: pyo3::Python<'_>, + data: pyo3::Py, + padding_bits: u8, + ) -> pyo3::PyResult { + if asn1::BitString::new(data.as_bytes(py), padding_bits).is_none() { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "invalid BIT STRING: data: {data}, padding_bits: {padding_bits}" + ))); + } + + Ok(BitString { data, padding_bits }) + } + + #[pyo3(signature = ())] + pub fn as_bytes(&self, py: pyo3::Python<'_>) -> pyo3::Py { + self.data.clone_ref(py) + } + + #[pyo3(signature = ())] + pub fn padding_bits(&self) -> u8 { + self.padding_bits + } + + fn __eq__(&self, py: pyo3::Python<'_>, other: pyo3::PyRef<'_, Self>) -> pyo3::PyResult { + Ok((**self.data.bind(py)).eq(other.data.bind(py))? + && self.padding_bits == other.padding_bits) + } + + pub fn __repr__(&self) -> pyo3::PyResult { + Ok(format!( + "BitString(data: {}, padding_bits: {})", + self.data, self.padding_bits, + )) + } +} + /// Utility function for converting builtin Python types /// to their Rust `Type` equivalent. #[pyo3::pyfunction] @@ -270,6 +320,8 @@ pub fn non_root_python_to_rust<'p>( Type::UtcTime().into_pyobject(py) } else if class.is(GeneralizedTime::type_object(py)) { Type::GeneralizedTime().into_pyobject(py) + } else if class.is(BitString::type_object(py)) { + Type::BitString().into_pyobject(py) } else { Err(pyo3::exceptions::PyTypeError::new_err(format!( "cannot handle type: {class:?}" @@ -326,6 +378,7 @@ pub(crate) fn type_to_tag(t: &Type, encoding: &Option>) -> as Type::PrintableString() => asn1::PrintableString::TAG, Type::UtcTime() => asn1::UtcTime::TAG, Type::GeneralizedTime() => asn1::GeneralizedTime::TAG, + Type::BitString() => asn1::BitString::TAG, }; match encoding { diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index ee6e6cbbdfe5..2eadd8c9ab35 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -151,8 +151,8 @@ mod _rust { #[pymodule_export] use crate::declarative_asn1::types::{ - non_root_python_to_rust, AnnotatedType, Annotation, Encoding, GeneralizedTime, - PrintableString, Size, Type, UtcTime, + non_root_python_to_rust, AnnotatedType, Annotation, BitString, Encoding, + GeneralizedTime, PrintableString, Size, Type, UtcTime, }; } diff --git a/tests/hazmat/asn1/test_api.py b/tests/hazmat/asn1/test_api.py index 4a35df2fe570..7889b7fefe38 100644 --- a/tests/hazmat/asn1/test_api.py +++ b/tests/hazmat/asn1/test_api.py @@ -107,6 +107,35 @@ def test_invalid_generalized_time(self) -> None: # We don't allow naive datetime objects asn1.GeneralizedTime(datetime.datetime(2000, 1, 1, 10, 10, 10)) + def test_bitstring_getters(self) -> None: + data = b"\x01\x02\x30" + bt = asn1.BitString(data=data, padding_bits=2) + + assert bt.as_bytes() == data + assert bt.padding_bits() == 2 + + def test_repr_bitstring(self) -> None: + data = b"\x01\x02\x30" + assert ( + repr(asn1.BitString(data, 2)) + == f"BitString(data: {data!r}, padding_bits: 2)" + ) + + def test_invalid_bitstring(self) -> None: + with pytest.raises( + ValueError, + match="invalid BIT STRING", + ): + # Padding bits cannot be > 7 + asn1.BitString(data=b"\x01\x02\x03", padding_bits=8) + + with pytest.raises( + ValueError, + match="invalid BIT STRING", + ): + # Padding bits have to be zero + asn1.BitString(data=b"\x01\x02\x03", padding_bits=2) + class TestSequenceAPI: def test_fail_unsupported_field(self) -> None: diff --git a/tests/hazmat/asn1/test_serialization.py b/tests/hazmat/asn1/test_serialization.py index 2e800f447972..12bcf0747194 100644 --- a/tests/hazmat/asn1/test_serialization.py +++ b/tests/hazmat/asn1/test_serialization.py @@ -224,6 +224,51 @@ def test_generalized_time(self) -> None: ) +class TestBitString: + def test_ok_bitstring(self) -> None: + assert_roundtrips( + [ + ( + asn1.BitString(data=b"\x6e\x5d\xc0", padding_bits=6), + b"\x03\x04\x06\x6e\x5d\xc0", + ), + ( + asn1.BitString(data=b"", padding_bits=0), + b"\x03\x01\x00", + ), + ( + asn1.BitString(data=b"\x00", padding_bits=7), + b"\x03\x02\x07\x00", + ), + ( + asn1.BitString(data=b"\x80", padding_bits=7), + b"\x03\x02\x07\x80", + ), + ( + asn1.BitString(data=b"\x81\xf0", padding_bits=4), + b"\x03\x03\x04\x81\xf0", + ), + ] + ) + + def test_fail_bitstring(self) -> None: + with pytest.raises(ValueError, match="error parsing asn1 value"): + # Prefix with number of padding bits missing + asn1.decode_der(asn1.BitString, b"\x03\x00") + + with pytest.raises(ValueError, match="error parsing asn1 value"): + # Non-zero padding bits + asn1.decode_der(asn1.BitString, b"\x03\x02\x07\x01") + + with pytest.raises(ValueError, match="error parsing asn1 value"): + # Non-zero padding bits + asn1.decode_der(asn1.BitString, b"\x03\x02\x07\x40") + + with pytest.raises(ValueError, match="error parsing asn1 value"): + # Padding bits > 7 + asn1.decode_der(asn1.BitString, b"\x03\x02\x08\x00") + + class TestSequence: def test_ok_sequence_single_field(self) -> None: @asn1.sequence @@ -492,12 +537,20 @@ class Example: e: typing.Union[asn1.UtcTime, None] f: typing.Union[asn1.GeneralizedTime, None] g: typing.Union[typing.List[int], None] + h: typing.Union[asn1.BitString, None] assert_roundtrips( [ ( Example( - a=None, b=None, c=None, d=None, e=None, f=None, g=None + a=None, + b=None, + c=None, + d=None, + e=None, + f=None, + g=None, + h=None, ), b"\x30\x00", )