diff --git a/src/cryptography/hazmat/asn1/__init__.py b/src/cryptography/hazmat/asn1/__init__.py index 0ba1cd73e8fe..147d1d088b55 100644 --- a/src/cryptography/hazmat/asn1/__init__.py +++ b/src/cryptography/hazmat/asn1/__init__.py @@ -12,6 +12,7 @@ PrintableString, Size, UtcTime, + Variant, decode_der, encode_der, sequence, @@ -27,6 +28,7 @@ "PrintableString", "Size", "UtcTime", + "Variant", "decode_der", "encode_der", "sequence", diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index 5d8b245ac56a..9f846d0c927d 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -13,6 +13,8 @@ if sys.version_info < (3, 11): import typing_extensions + LiteralString = typing_extensions.LiteralString + # We use the `include_extras` parameter of `get_type_hints`, which was # added in Python 3.9. This can be replaced by the `typing` version # once the min version is >= 3.9 @@ -31,6 +33,7 @@ get_type_args = typing.get_args get_type_origin = typing.get_origin Annotated = typing.Annotated + LiteralString = typing.LiteralString if sys.version_info < (3, 10): NoneType = type(None) @@ -41,6 +44,30 @@ T = typing.TypeVar("T", covariant=True) U = typing.TypeVar("U") +Tag = typing.TypeVar("Tag", bound=LiteralString) + + +@dataclasses.dataclass(frozen=True) +class Variant(typing.Generic[U, Tag]): + """ + A tagged variant for CHOICE fields with the same underlying type. + + Use this when you have multiple CHOICE alternatives with the same type + and need to distinguish between them: + + foo: ( + Annotated[Variant[int, typing.Literal["IntA"]], Implicit(0)] + | Annotated[Variant[int, typing.Literal["IntB"]], Implicit(1)] + ) + + Usage: + example = Example(foo=Variant(5, "IntA")) + decoded.foo.value # The int value + decoded.foo.tag # "IntA" or "IntB" + """ + + value: U + tag: str decode_der = declarative_asn1.decode_der @@ -101,7 +128,7 @@ def _normalize_field_type( # from it if it exists. if get_type_origin(field_type) is Annotated: annotation = _extract_annotation(field_type.__metadata__, field_name) - field_type, _ = get_type_args(field_type) + field_type, *_ = get_type_args(field_type) else: annotation = declarative_asn1.Annotation() @@ -150,10 +177,53 @@ def _normalize_field_type( ) rust_field_type = declarative_asn1.Type.Option(annotated_type) + else: - raise TypeError( - "union types other than `X | None` are currently not supported" + # Otherwise, the Union is a CHOICE + if isinstance(annotation.encoding, Implicit): + # CHOICEs cannot be IMPLICIT. See X.680 section 31.2.9. + raise TypeError( + "CHOICE (`X | Y | ...`) types should not have an IMPLICIT " + "annotation" + ) + variants = [ + _type_to_variant(arg, field_name) + for arg in union_args + if arg is not type(None) + ] + + # Union types should either be all Variants + # (`Variant[..] | Variant[..] | etc`) or all non Variants + are_union_types_tagged = variants[0].tag_name is not None + if any( + (v.tag_name is not None) != are_union_types_tagged + for v in variants + ): + raise TypeError( + "When using `asn1.Variant` in a union, all the other " + "types in the union must also be `asn1.Variant`" + ) + + if are_union_types_tagged: + tags = {v.tag_name for v in variants} + if len(variants) != len(tags): + raise TypeError( + "When using `asn1.Variant` in a union, the tags used " + "must be unique" + ) + + rust_choice_type = declarative_asn1.Type.Choice(variants) + # If None is part of the union types, this is an OPTIONAL CHOICE + rust_field_type = ( + declarative_asn1.Type.Option( + declarative_asn1.AnnotatedType( + rust_choice_type, declarative_asn1.Annotation() + ) + ) + if NoneType in union_args + else rust_choice_type ) + elif get_type_origin(field_type) is builtins.list: inner_type = _normalize_field_type( get_type_args(field_type)[0], field_name @@ -165,6 +235,51 @@ def _normalize_field_type( return declarative_asn1.AnnotatedType(rust_field_type, annotation) +# Convert a type to a Variant. Used with types inside Union +# annotations (T1, T2, etc in `Union[T1, T2, ...]`). +def _type_to_variant( + t: typing.Any, field_name: str +) -> declarative_asn1.Variant: + is_annotated = get_type_origin(t) is Annotated + inner_type = get_type_args(t)[0] if is_annotated else t + + # Check if this is a Variant[T, Tag] type + if get_type_origin(inner_type) is Variant: + value_type, tag_literal = get_type_args(inner_type) + if get_type_origin(tag_literal) is not typing.Literal: + raise TypeError( + "When using `asn1.Variant` in a type annotation, the second " + "type parameter must be a `typing.Literal` type. E.g: " + '`Variant[int, typing.Literal["MyInt"]]`.' + ) + tag_name = get_type_args(tag_literal)[0] + + if hasattr(value_type, "__asn1_root__"): + rust_type = value_type.__asn1_root__.inner + else: + rust_type = declarative_asn1.non_root_python_to_rust(value_type) + + if is_annotated: + ann_type = declarative_asn1.AnnotatedType( + rust_type, + _extract_annotation(t.__metadata__, field_name), + ) + else: + ann_type = declarative_asn1.AnnotatedType( + rust_type, + declarative_asn1.Annotation(), + ) + + return declarative_asn1.Variant(Variant, ann_type, tag_name) + else: + # Plain type (not a tagged Variant) + return declarative_asn1.Variant( + inner_type, + _normalize_field_type(t, field_name), + None, + ) + + def _annotate_fields( raw_fields: dict[str, type], ) -> dict[str, declarative_asn1.AnnotatedType]: diff --git a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi index 4281d0c617bb..6777ef71241c 100644 --- a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi +++ b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi @@ -15,6 +15,7 @@ class Type: Sequence: typing.ClassVar[type] SequenceOf: typing.ClassVar[type] Option: typing.ClassVar[type] + Choice: typing.ClassVar[type] PyBool: typing.ClassVar[type] PyInt: typing.ClassVar[type] PyBytes: typing.ClassVar[type] @@ -60,6 +61,18 @@ class AnnotatedTypeObject: cls, annotated_type: AnnotatedType, value: typing.Any ) -> AnnotatedTypeObject: ... +class Variant: + python_class: type + ann_type: AnnotatedType + tag_name: str | None + + def __new__( + cls, + python_class: type, + ann_type: AnnotatedType, + tag_name: str | None, + ) -> Variant: ... + class PrintableString: def __new__(cls, inner: str) -> PrintableString: ... def __repr__(self) -> str: ... diff --git a/src/rust/src/declarative_asn1/decode.rs b/src/rust/src/declarative_asn1/decode.rs index c9408e1a1034..f17c3bc75bd9 100644 --- a/src/rust/src/declarative_asn1/decode.rs +++ b/src/rust/src/declarative_asn1/decode.rs @@ -7,8 +7,9 @@ use pyo3::types::{PyAnyMethods, PyListMethods}; use crate::asn1::big_byte_slice_to_py_int; use crate::declarative_asn1::types::{ - check_size_constraint, is_tag_valid_for_type, AnnotatedType, Annotation, BitString, Encoding, - GeneralizedTime, IA5String, PrintableString, Type, UtcTime, + check_size_constraint, is_tag_valid_for_type, is_tag_valid_for_variant, AnnotatedType, + Annotation, BitString, Encoding, GeneralizedTime, IA5String, PrintableString, Type, UtcTime, + Variant, }; use crate::error::CryptographyError; @@ -160,6 +161,47 @@ fn decode_bitstring<'a>( )?) } +// Utility function to handle explicit encoding when parsing +// CHOICE fields. +fn decode_choice_with_encoding<'a>( + py: pyo3::Python<'a>, + parser: &mut Parser<'a>, + ann_type: &AnnotatedType, + encoding: &Encoding, +) -> ParseResult> { + match encoding { + Encoding::Implicit(_) => Err(CryptographyError::Py( + // CHOICEs cannot be IMPLICIT. See X.680 section 31.2.9. + pyo3::exceptions::PyValueError::new_err( + "invalid type definition: CHOICE fields cannot be implicitly encoded".to_string(), + ), + ))?, + Encoding::Explicit(n) => { + // Since we don't know which of the variants is present for this + // CHOICE field, we'll parse this as a generic TLV encoded with + // EXPLICIT, so `read_explicit_element` will consume the EXPLICIT + // wrapper tag, and the TLV data will contain the variant. + let tlv = parser.read_explicit_element::>(*n)?; + let type_without_explicit = AnnotatedType { + inner: ann_type.inner.clone_ref(py), + annotation: pyo3::Py::new( + py, + Annotation { + default: None, + encoding: None, + size: None, + }, + )?, + }; + // Parse the TLV data (which contains the field without the EXPLICIT + // wrapper) + asn1::parse(tlv.full_data(), |d| { + decode_annotated_type(py, d, &type_without_explicit) + }) + } + } +} + pub(crate) fn decode_annotated_type<'a>( py: pyo3::Python<'a>, parser: &mut Parser<'a>, @@ -173,7 +215,7 @@ pub(crate) fn decode_annotated_type<'a>( // returning the default value) if let Some(default) = &ann_type.annotation.get().default { match parser.peek_tag() { - Some(next_tag) if is_tag_valid_for_type(next_tag, inner, encoding) => (), + Some(next_tag) if is_tag_valid_for_type(py, next_tag, inner, encoding) => (), _ => return Ok(default.clone_ref(py).into_bound(py)), } } @@ -210,7 +252,7 @@ pub(crate) fn decode_annotated_type<'a>( } Type::Option(cls) => { match parser.peek_tag() { - Some(t) if is_tag_valid_for_type(t, cls.get().inner.get(), encoding) => { + Some(t) if is_tag_valid_for_type(py, t, cls.get().inner.get(), encoding) => { // For optional types, annotations will always be associated to the `Optional` type // i.e: `Annotated[Optional[T], annotation]`, as opposed to the inner `T` type. // Therefore, when decoding the inner type `T` we must pass the annotation of the `Optional` @@ -223,6 +265,33 @@ pub(crate) fn decode_annotated_type<'a>( _ => pyo3::types::PyNone::get(py).to_owned().into_any(), } } + Type::Choice(ts) => match encoding { + Some(e) => decode_choice_with_encoding(py, parser, ann_type, e.get())?, + None => { + for t in ts.bind(py) { + let variant = t.cast::()?.get(); + match parser.peek_tag() { + Some(tag) if is_tag_valid_for_variant(py, tag, variant, encoding) => { + let decoded_value = + decode_annotated_type(py, parser, variant.ann_type.get())?; + return match &variant.tag_name { + Some(tag_name) => Ok(variant + .python_class + .call1(py, (decoded_value, tag_name))? + .into_bound(py)), + None => Ok(decoded_value), + }; + } + _ => continue, + } + } + Err(CryptographyError::Py( + pyo3::exceptions::PyValueError::new_err( + "could not find matching variant when parsing CHOICE field".to_string(), + ), + ))? + } + }, Type::PyBool() => decode_pybool(py, parser, encoding)?.into_any(), Type::PyInt() => decode_pyint(py, parser, encoding)?.into_any(), Type::PyBytes() => decode_pybytes(py, parser, annotation)?.into_any(), @@ -244,3 +313,33 @@ pub(crate) fn decode_annotated_type<'a>( _ => Ok(decoded), } } + +#[cfg(test)] +mod tests { + use crate::declarative_asn1::types::{AnnotatedType, Annotation, Encoding, Type, Variant}; + #[test] + fn test_decode_implicit_choice() { + pyo3::Python::initialize(); + pyo3::Python::attach(|py| { + let result = asn1::parse(&[], |parser| { + let variants: Vec = vec![]; + let choice = Type::Choice(pyo3::types::PyList::new(py, variants)?.unbind()); + let annotation = Annotation { + default: None, + encoding: None, + size: None, + }; + let ann_type = AnnotatedType { + inner: pyo3::Py::new(py, choice)?, + annotation: pyo3::Py::new(py, annotation)?, + }; + let encoding = Encoding::Implicit(0); + super::decode_choice_with_encoding(py, parser, &ann_type, &encoding) + }); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(format!("{error}") + .contains("invalid type definition: CHOICE fields cannot be implicitly encoded")); + }); + } +} diff --git a/src/rust/src/declarative_asn1/encode.rs b/src/rust/src/declarative_asn1/encode.rs index 3d666e7a5c38..57abc95b5883 100644 --- a/src/rust/src/declarative_asn1/encode.rs +++ b/src/rust/src/declarative_asn1/encode.rs @@ -7,7 +7,7 @@ use pyo3::types::{PyAnyMethods, PyListMethods}; use crate::declarative_asn1::types::{ check_size_constraint, AnnotatedType, AnnotatedTypeObject, BitString, Encoding, - GeneralizedTime, IA5String, PrintableString, Type, UtcTime, + GeneralizedTime, IA5String, PrintableString, Type, UtcTime, Variant, }; fn write_value( @@ -105,6 +105,59 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> { Ok(()) } } + Type::Choice(ts) => { + for t in ts.bind(py) { + let variant = t + .cast::() + .map_err(|_| asn1::WriteError::AllocationError)? + .get(); + + if !value.is_exact_instance(variant.python_class.bind(py)) { + continue; + } + + // Check if this variant matches the value + let matches = match &variant.tag_name { + Some(expected_tag) => { + let value_tag: String = value + .getattr("tag") + .map_err(|_| asn1::WriteError::AllocationError)? + .extract() + .map_err(|_| asn1::WriteError::AllocationError)?; + &value_tag == expected_tag + } + None => true, + }; + + if matches { + let val = if variant.tag_name.is_some() { + value + .getattr("value") + .map_err(|_| asn1::WriteError::AllocationError)? + } else { + value + }; + let object = AnnotatedTypeObject { + annotated_type: variant.ann_type.get(), + value: val, + }; + match encoding { + Some(e) => match e.get() { + // CHOICEs cannot be IMPLICIT. See X.680 section 31.2.9. + Encoding::Implicit(_) => { + return Err(asn1::WriteError::AllocationError) + } + Encoding::Explicit(n) => { + return writer.write_explicit_element(&object, *n) + } + }, + None => return object.write(writer), + } + } + } + // No matching variant found + Err(asn1::WriteError::AllocationError) + } Type::PyBool() => { let val: bool = value .extract() @@ -212,3 +265,53 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> { } } } + +#[cfg(test)] +mod tests { + use crate::declarative_asn1::types::{ + AnnotatedType, AnnotatedTypeObject, Annotation, Encoding, Type, Variant, + }; + use asn1::Asn1Writable; + use pyo3::PyTypeInfo; + #[test] + fn test_encode_implicit_choice() { + pyo3::Python::initialize(); + pyo3::Python::attach(|py| { + let annotation = Annotation { + default: None, + encoding: None, + size: None, + }; + let ann_type_variant = AnnotatedType { + inner: pyo3::Py::new(py, Type::PyInt()).unwrap(), + annotation: pyo3::Py::new(py, annotation).unwrap(), + }; + let variant = Variant { + python_class: pyo3::types::PyInt::type_object(py).unbind(), + ann_type: pyo3::Py::new(py, ann_type_variant).unwrap(), + tag_name: None, + }; + + let variants = vec![variant]; + let choice = Type::Choice(pyo3::types::PyList::new(py, variants).unwrap().unbind()); + let annotation = Annotation { + default: None, + encoding: Some(pyo3::Py::new(py, Encoding::Implicit(0)).unwrap()), + size: None, + }; + let ann_type = AnnotatedType { + inner: pyo3::Py::new(py, choice).unwrap(), + annotation: pyo3::Py::new(py, annotation).unwrap(), + }; + + let value = pyo3::types::PyInt::new(py, 3).into_any(); + let object = AnnotatedTypeObject { + annotated_type: &ann_type, + value, + }; + + let result = asn1::write(|writer| object.write(writer)); + assert!(result.is_err()); + }); + } +} diff --git a/src/rust/src/declarative_asn1/types.rs b/src/rust/src/declarative_asn1/types.rs index 19374203ea54..2eecf2753efd 100644 --- a/src/rust/src/declarative_asn1/types.rs +++ b/src/rust/src/declarative_asn1/types.rs @@ -25,6 +25,9 @@ pub enum Type { SequenceOf(pyo3::Py), /// OPTIONAL (`T | None`) Option(pyo3::Py), + /// CHOICE (`T | U | ...`) + /// The list contains elements of type Variant + Choice(pyo3::Py), // Python types that we map to canonical ASN.1 types // @@ -55,6 +58,7 @@ pub enum Type { #[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.asn1")] #[derive(Debug)] pub struct AnnotatedType { + #[pyo3(get)] pub inner: pyo3::Py, #[pyo3(get)] pub annotation: pyo3::Py, @@ -135,6 +139,32 @@ impl Size { } } +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.asn1")] +pub struct Variant { + #[pyo3(get)] + pub python_class: pyo3::Py, + #[pyo3(get)] + pub ann_type: pyo3::Py, + #[pyo3(get)] + pub tag_name: Option, +} + +#[pyo3::pymethods] +impl Variant { + #[new] + fn new( + python_class: pyo3::Py, + ann_type: pyo3::Py, + tag_name: Option, + ) -> Self { + Self { + python_class, + ann_type, + tag_name, + } + } +} + // TODO: Once the minimum Python version is >= 3.10, use a `self_cell` // to store the owned PyString along with the dependent Asn1PrintableString // in order to avoid verifying the string twice (once during construction, @@ -436,9 +466,34 @@ fn check_tag_with_encoding( tag_with_encoding == tag } +// Utility function to see if a tag matches an unnanotated variant. +pub(crate) fn is_tag_valid_for_variant( + py: pyo3::Python<'_>, + tag: asn1::Tag, + variant: &Variant, + encoding: &Option>, +) -> bool { + let ann_type = variant.ann_type.get(); + + // There are two encodings at play here: the encoding of the CHOICE itself, + // and the encoding of each of the variants. The encoding of the CHOICE will + // only affect the tag if it's EXPLICIT (where it adds a wrapper). Otherwise, + // we use the encoding of the variant. + let encoding_to_match = match encoding { + Some(e) => match e.get() { + Encoding::Implicit(_) => &ann_type.annotation.get().encoding, + Encoding::Explicit(_) => encoding, + }, + None => &ann_type.annotation.get().encoding, + }; + + is_tag_valid_for_type(py, tag, ann_type.inner.get(), encoding_to_match) +} + // Given `tag` and `encoding`, returns whether that tag with that encoding // matches what one would expect to see when decoding `type_` pub(crate) fn is_tag_valid_for_type( + py: pyo3::Python<'_>, tag: asn1::Tag, type_: &Type, encoding: &Option>, @@ -446,7 +501,10 @@ pub(crate) fn is_tag_valid_for_type( match type_ { Type::Sequence(_, _) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag), Type::SequenceOf(_) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag), - Type::Option(t) => is_tag_valid_for_type(tag, t.get().inner.get(), encoding), + Type::Option(t) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding), + Type::Choice(variants) => variants.bind(py).into_iter().any(|v| { + is_tag_valid_for_variant(py, tag, v.cast::().unwrap().get(), encoding) + }), Type::PyBool() => check_tag_with_encoding(bool::TAG, encoding, tag), Type::PyInt() => check_tag_with_encoding(asn1::BigInt::TAG, encoding, tag), Type::PyBytes() => { @@ -492,9 +550,12 @@ pub(crate) fn check_size_constraint( mod tests { use asn1::SimpleAsn1Readable; - use pyo3::IntoPyObject; + use pyo3::{IntoPyObject, PyTypeInfo}; - use super::{is_tag_valid_for_type, AnnotatedType, Annotation, Type}; + use super::{ + is_tag_valid_for_type, is_tag_valid_for_variant, AnnotatedType, Annotation, Encoding, Type, + Variant, + }; #[test] // Needed for coverage of `is_tag_valid_for_type(Type::Option(..))`, since @@ -534,10 +595,48 @@ mod tests { ) .unwrap(); assert!(is_tag_valid_for_type( + py, asn1::BigInt::TAG, &Type::Option(optional_type), &None )); }) } + #[test] + // Needed for coverage of + // `is_tag_valid_for_variant(..., encoding=Encoding::Implicit)`, since + // `is_tag_valid_for_variant` is never called with an implicit encoding. + fn test_is_tag_valid_for_implicit_variant() { + pyo3::Python::initialize(); + + pyo3::Python::attach(|py| { + let ann_type = pyo3::Py::new( + py, + AnnotatedType { + inner: pyo3::Py::new(py, Type::PyInt()).unwrap(), + annotation: Annotation { + default: None, + encoding: None, + size: None, + } + .into_pyobject(py) + .unwrap() + .unbind(), + }, + ) + .unwrap(); + let variant = Variant { + python_class: pyo3::types::PyInt::type_object(py).unbind(), + ann_type, + tag_name: None, + }; + let encoding = pyo3::Py::new(py, Encoding::Implicit(3)).ok(); + assert!(is_tag_valid_for_variant( + py, + asn1::BigInt::TAG, + &variant, + &encoding + )); + }) + } } diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 32067299f467..d010951de661 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -151,7 +151,7 @@ mod _rust { #[pymodule_export] use crate::declarative_asn1::types::{ non_root_python_to_rust, AnnotatedType, Annotation, BitString, Encoding, - GeneralizedTime, IA5String, PrintableString, Size, Type, UtcTime, + GeneralizedTime, IA5String, PrintableString, Size, Type, UtcTime, Variant, }; } diff --git a/tests/hazmat/asn1/test_api.py b/tests/hazmat/asn1/test_api.py index 6c338c051c9f..1caaa08f5a6b 100644 --- a/tests/hazmat/asn1/test_api.py +++ b/tests/hazmat/asn1/test_api.py @@ -215,17 +215,6 @@ class Invalid: class Example: foo: Invalid - def test_fail_unsupported_union_field(self) -> None: - with pytest.raises( - TypeError, - match="union types other than `X \\| None` are currently not " - "supported", - ): - - @asn1.sequence - class Example: - invalid: typing.Union[int, str] - def test_fail_unsupported_annotation(self) -> None: with pytest.raises( TypeError, match="unsupported annotation: some annotation" @@ -323,6 +312,37 @@ class Example2: Annotated[int, asn1.Default(value=9)], None ] + def test_fail_choice_with_inconsistent_types(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + "When using `asn1.Variant` in a union, all the other " + "types in the union must also be `asn1.Variant`" + ), + ): + + @asn1.sequence + class Example2: + invalid: typing.Union[ + int, asn1.Variant[bool, typing.Literal["myTag"]] + ] + + def test_fail_choice_with_duplicate_tags(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + "When using `asn1.Variant` in a union, the tags used " + "must be unique" + ), + ): + + @asn1.sequence + class Example2: + invalid: typing.Union[ + asn1.Variant[int, typing.Literal["myTag"]], + asn1.Variant[bool, typing.Literal["myTag"]], + ] + def test_fields_of_variant_type(self) -> None: from cryptography.hazmat.bindings._rust import declarative_asn1 @@ -341,6 +361,10 @@ def test_fields_of_variant_type(self) -> None: seq_of = declarative_asn1.Type.SequenceOf(ann_type) assert seq_of._0 is ann_type + my_list: typing.List[int] = list() + choice = declarative_asn1.Type.Choice(my_list) + assert choice._0 is my_list + def test_fields_of_variant_encoding(self) -> None: from cryptography.hazmat.bindings._rust import declarative_asn1 @@ -350,3 +374,39 @@ def test_fields_of_variant_encoding(self) -> None: explicit = declarative_asn1.Encoding.Explicit(0) assert implicit._0 == 0 assert explicit._0 == 0 + + def test_fail_choice_with_implicit_encoding(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + "CHOICE (`X | Y | ...`) types should not have an IMPLICIT " + "annotation" + ), + ): + + @asn1.sequence + class Example: + invalid: Annotated[typing.Union[int, bool], asn1.Implicit(0)] + + def test_fail_choice_with_non_literal_tag(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + "When using `asn1.Variant` in a type annotation, the second " + "type parameter must be a `typing.Literal` type. E.g: " + '`Variant[int, typing.Literal["MyInt"]]`.' + ), + ): + + @asn1.sequence + class Example: + foo: typing.Union[ + Annotated[ + asn1.Variant[int, str], + asn1.Implicit(0), + ], + Annotated[ + asn1.Variant[int, typing.Literal["IntB"]], + asn1.Implicit(1), + ], + ] diff --git a/tests/hazmat/asn1/test_serialization.py b/tests/hazmat/asn1/test_serialization.py index da3d390916a9..718d0bcc4fc2 100644 --- a/tests/hazmat/asn1/test_serialization.py +++ b/tests/hazmat/asn1/test_serialization.py @@ -494,6 +494,76 @@ class Example: ] ) + def test_ok_sequence_all_types_default(self) -> None: + default_time = datetime.datetime( + 2019, + 12, + 16, + 3, + 2, + 10, + tzinfo=datetime.timezone.utc, + ) + default_oid = x509.ObjectIdentifier("1.3.6.1.4.1.343") + + @asn1.sequence + @_comparable_dataclass + class Example: + a: Annotated[int, asn1.Default(3)] + b: Annotated[bytes, asn1.Default(b"\x00")] + c: Annotated[ + asn1.PrintableString, asn1.Default(asn1.PrintableString("a")) + ] + d: Annotated[ + asn1.UtcTime, + asn1.Default(asn1.UtcTime(default_time)), + ] + e: Annotated[ + asn1.GeneralizedTime, + asn1.Default(asn1.GeneralizedTime(default_time)), + ] + f: Annotated[typing.List[int], asn1.Default([1])] + g: Annotated[ + asn1.BitString, + asn1.Default( + asn1.BitString(data=b"", padding_bits=0), + ), + ] + h: Annotated[asn1.IA5String, asn1.Default(asn1.IA5String("a"))] + i: Annotated[ + x509.ObjectIdentifier, + asn1.Default(default_oid), + ] + j: Annotated[ + typing.Union[int, bool], asn1.Default(3), asn1.Explicit(0) + ] + k: Annotated[str, asn1.Default("a"), asn1.Implicit(0)] + only_field_present: Annotated[ + str, asn1.Default("a"), asn1.Implicit(1) + ] + + assert_roundtrips( + [ + ( + Example( + a=3, + b=b"\x00", + c=asn1.PrintableString("a"), + d=asn1.UtcTime(default_time), + e=asn1.GeneralizedTime(default_time), + f=[1], + g=asn1.BitString(data=b"", padding_bits=0), + h=asn1.IA5String("a"), + i=default_oid, + j=3, + k="a", + only_field_present="b", + ), + b"\x30\x03\x81\x01b", + ) + ] + ) + def test_ok_sequence_with_default_annotations(self) -> None: @asn1.sequence @_comparable_dataclass @@ -607,6 +677,246 @@ class Example: ): asn1.decode_der(Example, b"\x30\x05\xa2\x03\x02\x01\x09") + def test_sequence_with_choice(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[int, bool, str] + + assert_roundtrips( + [ + (Example(foo=9), b"\x30\x03\x02\x01\x09"), + (Example(foo=True), b"\x30\x03\x01\x01\xff"), + (Example(foo="a"), b"\x30\x03\x0c\x01a"), + ] + ) + + def test_sequence_with_optional_choice(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[bool, str, None] + bar: int + + assert_roundtrips( + [ + ( + Example(foo=True, bar=1), + b"\x30\x06\x01\x01\xff\x02\x01\x01", + ), + (Example(foo=None, bar=1), b"\x30\x03\x02\x01\x01"), + ] + ) + + def test_fail_sequence_with_choice_decode_nonexistent_variant( + self, + ) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[bool, str] + + with pytest.raises( + ValueError, + match=re.escape( + "could not find matching variant when parsing CHOICE field" + ), + ): + asn1.decode_der(Example, b"\x30\x03\x02\x01\x09") + + def test_fail_sequence_with_choice_encode_nonexistent_variant( + self, + ) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[bool, str] + + with pytest.raises( + ValueError, + ): + asn1.encode_der(Example(foo=3)) # type: ignore[arg-type] + + def test_sequence_with_explicit_choice(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: Annotated[typing.Union[int, bool, str], asn1.Explicit(3)] + + assert_roundtrips( + [ + (Example(foo=9), b"\x30\x05\xa3\x03\x02\x01\x09"), + (Example(foo=True), b"\x30\x05\xa3\x03\x01\x01\xff"), + (Example(foo="a"), b"\x30\x05\xa3\x03\x0c\x01a"), + ] + ) + + def test_sequence_with_choice_implicit_simple_variants(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[ + Annotated[int, asn1.Implicit(0)], + Annotated[bool, asn1.Implicit(1)], + Annotated[str, asn1.Implicit(2)], + ] + + assert_roundtrips( + [ + (Example(foo=9), b"\x30\x03\x80\x01\x09"), + (Example(foo=True), b"\x30\x03\x81\x01\xff"), + (Example(foo="a"), b"\x30\x03\x82\x01a"), + ] + ) + + def test_sequence_with_choice_explicit_simple_variants(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[ + Annotated[int, asn1.Explicit(0)], + Annotated[bool, asn1.Explicit(1)], + Annotated[str, asn1.Explicit(2)], + ] + + assert_roundtrips( + [ + (Example(foo=9), b"\x30\x05\xa0\x03\x02\x01\x09"), + (Example(foo=True), b"\x30\x05\xa1\x03\x01\x01\xff"), + (Example(foo="a"), b"\x30\x05\xa2\x03\x0c\x01a"), + ] + ) + + def test_sequence_with_choice_with_custom_variants(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[ + Annotated[ + asn1.Variant[int, typing.Literal["IntA"]], asn1.Implicit(0) + ], + Annotated[ + asn1.Variant[int, typing.Literal["IntB"]], asn1.Implicit(1) + ], + Annotated[ + asn1.Variant[int, typing.Literal["IntC"]], asn1.Implicit(2) + ], + ] + + assert_roundtrips( + [ + ( + Example(foo=asn1.Variant(9, "IntA")), + b"\x30\x03\x80\x01\x09", + ), + ( + Example(foo=asn1.Variant(9, "IntB")), + b"\x30\x03\x81\x01\x09", + ), + ( + Example(foo=asn1.Variant(9, "IntC")), + b"\x30\x03\x82\x01\x09", + ), + ] + ) + + def test_sequence_with_choice_with_custom_variants_bool(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[ + Annotated[ + asn1.Variant[bool, typing.Literal["BoolA"]], + asn1.Implicit(0), + ], + Annotated[ + asn1.Variant[bool, typing.Literal["BoolB"]], + asn1.Implicit(1), + ], + Annotated[ + asn1.Variant[bool, typing.Literal["BoolC"]], + asn1.Implicit(2), + ], + ] + + assert_roundtrips( + [ + ( + Example(foo=asn1.Variant(True, "BoolA")), + b"\x30\x03\x80\x01\xff", + ), + ( + Example(foo=asn1.Variant(True, "BoolB")), + b"\x30\x03\x81\x01\xff", + ), + ( + Example(foo=asn1.Variant(True, "BoolC")), + b"\x30\x03\x82\x01\xff", + ), + ] + ) + + def test_sequence_with_choice_with_sequence_variants(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: int + + @asn1.sequence + @_comparable_dataclass + class ExampleUnion: + field: typing.Union[ + Annotated[ + asn1.Variant[Example, typing.Literal["ExampleA"]], + asn1.Implicit(0), + ], + Annotated[ + asn1.Variant[Example, typing.Literal["ExampleB"]], + asn1.Implicit(1), + ], + ] + + assert_roundtrips( + [ + ( + ExampleUnion( + field=asn1.Variant(Example(foo=9), "ExampleA") + ), + b"\x30\x05\xa0\x03\x02\x01\x09", + ), + ( + ExampleUnion( + field=asn1.Variant(Example(foo=9), "ExampleB") + ), + b"\x30\x05\xa1\x03\x02\x01\x09", + ), + ] + ) + + def test_sequence_with_choice_with_non_annotated_custom_variants( + self, + ) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[ + asn1.Variant[int, typing.Literal["MyInt"]], + asn1.Variant[bool, typing.Literal["MyBool"]], + ] + + assert_roundtrips( + [ + ( + Example(foo=asn1.Variant(9, "MyInt")), + b"\x30\x03\x02\x01\x09", + ), + ( + Example(foo=asn1.Variant(True, "MyBool")), + b"\x30\x03\x01\x01\xff", + ), + ] + ) + class TestSize: def test_ok_sequenceof_size_restriction(self) -> None: