From db9795f7ef692dcc8eb36618532cbe7955a51092 Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Thu, 22 Jan 2026 01:34:03 +0100 Subject: [PATCH 01/11] asn1: Add support for CHOICE fields Signed-off-by: Facundo Tuesca --- src/cryptography/hazmat/asn1/__init__.py | 2 + src/cryptography/hazmat/asn1/asn1.py | 88 ++++++- .../bindings/_rust/declarative_asn1.pyi | 13 + src/rust/src/declarative_asn1/decode.rs | 115 ++++++++- src/rust/src/declarative_asn1/encode.rs | 105 +++++++- src/rust/src/declarative_asn1/types.rs | 106 ++++++-- src/rust/src/lib.rs | 2 +- tests/hazmat/asn1/test_api.py | 28 ++- tests/hazmat/asn1/test_serialization.py | 230 ++++++++++++++++++ 9 files changed, 642 insertions(+), 47 deletions(-) diff --git a/src/cryptography/hazmat/asn1/__init__.py b/src/cryptography/hazmat/asn1/__init__.py index 0ba1cd73e8fe..147d1d088b55 100644 --- a/src/cryptography/hazmat/asn1/__init__.py +++ b/src/cryptography/hazmat/asn1/__init__.py @@ -12,6 +12,7 @@ PrintableString, Size, UtcTime, + Variant, decode_der, encode_der, sequence, @@ -27,6 +28,7 @@ "PrintableString", "Size", "UtcTime", + "Variant", "decode_der", "encode_der", "sequence", diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index 5d8b245ac56a..07e29e3ea516 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -41,6 +41,30 @@ T = typing.TypeVar("T", covariant=True) U = typing.TypeVar("U") +Tag = typing.TypeVar("Tag") + + +@dataclasses.dataclass(frozen=True) +class Variant(typing.Generic[U, Tag]): + """ + A tagged variant for CHOICE fields with the same underlying type. + + Use this when you have multiple CHOICE alternatives with the same type + and need to distinguish between them: + + foo: ( + Annotated[Variant[int, "IntA"], Implicit(0)] + | Annotated[Variant[int, "IntB"], Implicit(1)] + ) + + Usage: + example = Example(foo=Variant(5, "IntA")) + decoded.foo.value # The int value + decoded.foo.tag # "IntA" or "IntB" + """ + + value: U + tag: str decode_der = declarative_asn1.decode_der @@ -150,10 +174,31 @@ def _normalize_field_type( ) rust_field_type = declarative_asn1.Type.Option(annotated_type) + else: - raise TypeError( - "union types other than `X | None` are currently not supported" + # Otherwise, the Union is a CHOICE + if isinstance(annotation.encoding, Implicit): + raise TypeError( + "CHOICE (`X | Y | ...`) types should not have an IMPLICIT " + "annotation" + ) + variants = [ + _type_to_variant(arg, field_name) + for arg in union_args + if arg is not type(None) + ] + rust_choice_type = declarative_asn1.Type.Choice(variants) + # If None is part of the union types, this is an OPTIONAL CHOICE + rust_field_type = ( + declarative_asn1.Type.Option( + declarative_asn1.AnnotatedType( + rust_choice_type, declarative_asn1.Annotation() + ) + ) + if NoneType in union_args + else rust_choice_type ) + elif get_type_origin(field_type) is builtins.list: inner_type = _normalize_field_type( get_type_args(field_type)[0], field_name @@ -165,6 +210,45 @@ def _normalize_field_type( return declarative_asn1.AnnotatedType(rust_field_type, annotation) +# Convert a type to a Variant. Used with types inside Union +# annotations (T1, T2, etc in `Union[T1, T2, ...]`). +def _type_to_variant( + t: typing.Any, field_name: str +) -> declarative_asn1.Variant: + is_annotated = get_type_origin(t) is Annotated + inner_type = get_type_args(t)[0] if is_annotated else t + + # Check if this is a Variant[T, Tag] type + if get_type_origin(inner_type) is Variant: + value_type, tag_literal = get_type_args(inner_type) + tag_name = get_type_args(tag_literal)[0] + + if hasattr(value_type, "__asn1_root__"): + rust_type = value_type.__asn1_root__.inner + else: + rust_type = declarative_asn1.non_root_python_to_rust(value_type) + + if is_annotated: + ann_type = declarative_asn1.AnnotatedType( + rust_type, + _extract_annotation(t.__metadata__, field_name), + ) + else: + ann_type = declarative_asn1.AnnotatedType( + rust_type, + declarative_asn1.Annotation(), + ) + + return declarative_asn1.Variant(Variant, ann_type, tag_name) + else: + # Plain type (not a tagged Variant) + return declarative_asn1.Variant( + inner_type, + _normalize_field_type(t, field_name), + None, + ) + + def _annotate_fields( raw_fields: dict[str, type], ) -> dict[str, declarative_asn1.AnnotatedType]: diff --git a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi index 4281d0c617bb..6777ef71241c 100644 --- a/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi +++ b/src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi @@ -15,6 +15,7 @@ class Type: Sequence: typing.ClassVar[type] SequenceOf: typing.ClassVar[type] Option: typing.ClassVar[type] + Choice: typing.ClassVar[type] PyBool: typing.ClassVar[type] PyInt: typing.ClassVar[type] PyBytes: typing.ClassVar[type] @@ -60,6 +61,18 @@ class AnnotatedTypeObject: cls, annotated_type: AnnotatedType, value: typing.Any ) -> AnnotatedTypeObject: ... +class Variant: + python_class: type + ann_type: AnnotatedType + tag_name: str | None + + def __new__( + cls, + python_class: type, + ann_type: AnnotatedType, + tag_name: str | None, + ) -> Variant: ... + class PrintableString: def __new__(cls, inner: str) -> PrintableString: ... def __repr__(self) -> str: ... diff --git a/src/rust/src/declarative_asn1/decode.rs b/src/rust/src/declarative_asn1/decode.rs index 1d08b503e4af..b86360af86d4 100644 --- a/src/rust/src/declarative_asn1/decode.rs +++ b/src/rust/src/declarative_asn1/decode.rs @@ -7,8 +7,9 @@ use pyo3::types::{PyAnyMethods, PyListMethods}; use crate::asn1::big_byte_slice_to_py_int; use crate::declarative_asn1::types::{ - check_size_constraint, type_to_tag, AnnotatedType, Annotation, BitString, Encoding, - GeneralizedTime, IA5String, PrintableString, Type, UtcTime, + check_size_constraint, expected_tags_for_type, expected_tags_for_variant, AnnotatedType, + Annotation, BitString, Encoding, GeneralizedTime, IA5String, PrintableString, Type, UtcTime, + Variant, }; use crate::error::CryptographyError; @@ -160,6 +161,46 @@ fn decode_bitstring<'a>( )?) } +// Utility function to handle explicit encoding when parsing +// CHOICE fields. +fn decode_choice_with_encoding<'a>( + py: pyo3::Python<'a>, + parser: &mut Parser<'a>, + ann_type: &AnnotatedType, + encoding: &Encoding, +) -> ParseResult> { + match encoding { + Encoding::Implicit(_) => Err(CryptographyError::Py( + pyo3::exceptions::PyValueError::new_err( + "invalid type definition: CHOICE fields cannot be implicitly encoded".to_string(), + ), + ))?, + Encoding::Explicit(n) => { + // Since we don't know which of the variants is present for this + // CHOICE field, we'll parse this as a generic TLV encoded with + // EXPLICIT, so `read_explicit_element` will consume the EXPLICIT + // wrapper tag, and the TLV data will contain the variant. + let tlv = parser.read_explicit_element::>(*n)?; + let type_without_explicit = AnnotatedType { + inner: ann_type.inner.clone_ref(py), + annotation: pyo3::Py::new( + py, + Annotation { + default: None, + encoding: None, + size: None, + }, + )?, + }; + // Parse the TLV data (which contains the field without the EXPLICIT + // wrapper) + asn1::parse(tlv.full_data(), |d| { + decode_annotated_type(py, d, &type_without_explicit) + }) + } + } +} + pub(crate) fn decode_annotated_type<'a>( py: pyo3::Python<'a>, parser: &mut Parser<'a>, @@ -172,10 +213,10 @@ pub(crate) fn decode_annotated_type<'a>( // Handle DEFAULT annotation if field is not present (by // returning the default value) if let Some(default) = &ann_type.annotation.get().default { - let expected_tag = type_to_tag(inner, encoding); - let next_tag = parser.peek_tag(); - if next_tag != Some(expected_tag) { - return Ok(default.clone_ref(py).into_bound(py)); + let expected_tags = expected_tags_for_type(py, inner, encoding); + match parser.peek_tag() { + Some(next_tag) if expected_tags.contains(&next_tag) => (), + _ => return Ok(default.clone_ref(py).into_bound(py)), } } @@ -210,9 +251,9 @@ pub(crate) fn decode_annotated_type<'a>( })? } Type::Option(cls) => { - let inner_tag = type_to_tag(cls.get().inner.get(), encoding); + let expected_tags = expected_tags_for_type(py, cls.get().inner.get(), encoding); match parser.peek_tag() { - Some(t) if t == inner_tag => { + Some(t) if expected_tags.contains(&t) => { // For optional types, annotations will always be associated to the `Optional` type // i.e: `Annotated[Optional[T], annotation]`, as opposed to the inner `T` type. // Therefore, when decoding the inner type `T` we must pass the annotation of the `Optional` @@ -225,6 +266,34 @@ pub(crate) fn decode_annotated_type<'a>( _ => pyo3::types::PyNone::get(py).to_owned().into_any(), } } + Type::Choice(ts) => match encoding { + Some(e) => decode_choice_with_encoding(py, parser, ann_type, e.get())?, + None => { + for t in ts.bind(py) { + let variant = t.cast::()?.get(); + let expected_tags = expected_tags_for_variant(py, variant); + match parser.peek_tag() { + Some(tag) if expected_tags.contains(&tag) => { + let decoded_value = + decode_annotated_type(py, parser, variant.ann_type.get())?; + return match &variant.tag_name { + Some(tag_name) => Ok(variant + .python_class + .call1(py, (decoded_value, tag_name))? + .into_bound(py)), + None => Ok(decoded_value), + }; + } + _ => continue, + } + } + Err(CryptographyError::Py( + pyo3::exceptions::PyValueError::new_err( + "could not find matching variant when parsing CHOICE field".to_string(), + ), + ))? + } + }, Type::PyBool() => decode_pybool(py, parser, encoding)?.into_any(), Type::PyInt() => decode_pyint(py, parser, encoding)?.into_any(), Type::PyBytes() => decode_pybytes(py, parser, annotation)?.into_any(), @@ -246,3 +315,33 @@ pub(crate) fn decode_annotated_type<'a>( _ => Ok(decoded), } } + +#[cfg(test)] +mod tests { + use crate::declarative_asn1::types::{AnnotatedType, Annotation, Encoding, Type, Variant}; + #[test] + fn test_decode_implicit_choice() { + pyo3::Python::initialize(); + pyo3::Python::attach(|py| { + let result = asn1::parse(&[], |parser| { + let variants: Vec = vec![]; + let choice = Type::Choice(pyo3::types::PyList::new(py, variants)?.unbind()); + let annotation = Annotation { + default: None, + encoding: None, + size: None, + }; + let ann_type = AnnotatedType { + inner: pyo3::Py::new(py, choice)?, + annotation: pyo3::Py::new(py, annotation)?, + }; + let encoding = Encoding::Implicit(0); + super::decode_choice_with_encoding(py, parser, &ann_type, &encoding) + }); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(format!("{error}") + .contains("invalid type definition: CHOICE fields cannot be implicitly encoded")); + }); + } +} diff --git a/src/rust/src/declarative_asn1/encode.rs b/src/rust/src/declarative_asn1/encode.rs index 3d666e7a5c38..6a6346822b56 100644 --- a/src/rust/src/declarative_asn1/encode.rs +++ b/src/rust/src/declarative_asn1/encode.rs @@ -7,7 +7,7 @@ use pyo3::types::{PyAnyMethods, PyListMethods}; use crate::declarative_asn1::types::{ check_size_constraint, AnnotatedType, AnnotatedTypeObject, BitString, Encoding, - GeneralizedTime, IA5String, PrintableString, Type, UtcTime, + GeneralizedTime, IA5String, PrintableString, Type, UtcTime, Variant, }; fn write_value( @@ -105,6 +105,59 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> { Ok(()) } } + Type::Choice(ts) => { + for t in ts.bind(py) { + let variant = t + .cast::() + .map_err(|_| asn1::WriteError::AllocationError)? + .get(); + + if !value.is_exact_instance(variant.python_class.bind(py)) { + continue; + } + + // Check if this variant matches the value + let matches = match &variant.tag_name { + Some(expected_tag) => { + let value_tag: String = value + .getattr("tag") + .map_err(|_| asn1::WriteError::AllocationError)? + .extract() + .map_err(|_| asn1::WriteError::AllocationError)?; + &value_tag == expected_tag + } + None => true, + }; + + if matches { + let val = if variant.tag_name.is_some() { + value + .getattr("value") + .map_err(|_| asn1::WriteError::AllocationError)? + } else { + value + }; + let object = AnnotatedTypeObject { + annotated_type: variant.ann_type.get(), + value: val, + }; + match encoding { + Some(e) => match e.get() { + // CHOICE cannot be encoded as IMPLICIT + Encoding::Implicit(_) => { + return Err(asn1::WriteError::AllocationError) + } + Encoding::Explicit(n) => { + return writer.write_explicit_element(&object, *n) + } + }, + None => return object.write(writer), + } + } + } + // No matching variant found + Err(asn1::WriteError::AllocationError) + } Type::PyBool() => { let val: bool = value .extract() @@ -212,3 +265,53 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> { } } } + +#[cfg(test)] +mod tests { + use crate::declarative_asn1::types::{ + AnnotatedType, AnnotatedTypeObject, Annotation, Encoding, Type, Variant, + }; + use asn1::Asn1Writable; + use pyo3::PyTypeInfo; + #[test] + fn test_encode_implicit_choice() { + pyo3::Python::initialize(); + pyo3::Python::attach(|py| { + let annotation = Annotation { + default: None, + encoding: None, + size: None, + }; + let ann_type_variant = AnnotatedType { + inner: pyo3::Py::new(py, Type::PyInt()).unwrap(), + annotation: pyo3::Py::new(py, annotation).unwrap(), + }; + let variant = Variant { + python_class: pyo3::types::PyInt::type_object(py).unbind(), + ann_type: pyo3::Py::new(py, ann_type_variant).unwrap(), + tag_name: None, + }; + + let variants = vec![variant]; + let choice = Type::Choice(pyo3::types::PyList::new(py, variants).unwrap().unbind()); + let annotation = Annotation { + default: None, + encoding: Some(pyo3::Py::new(py, Encoding::Implicit(0)).unwrap()), + size: None, + }; + let ann_type = AnnotatedType { + inner: pyo3::Py::new(py, choice).unwrap(), + annotation: pyo3::Py::new(py, annotation).unwrap(), + }; + + let value = pyo3::types::PyInt::new(py, 3).into_any(); + let object = AnnotatedTypeObject { + annotated_type: &ann_type, + value, + }; + + let result = asn1::write(|writer| object.write(writer)); + assert!(result.is_err()); + }); + } +} diff --git a/src/rust/src/declarative_asn1/types.rs b/src/rust/src/declarative_asn1/types.rs index 63e2f59ec040..072075105a62 100644 --- a/src/rust/src/declarative_asn1/types.rs +++ b/src/rust/src/declarative_asn1/types.rs @@ -25,6 +25,9 @@ pub enum Type { SequenceOf(pyo3::Py), /// OPTIONAL (`T | None`) Option(pyo3::Py), + /// CHOICE (`T | U | ...`) + /// The list contains elements of type Variant + Choice(pyo3::Py), // Python types that we map to canonical ASN.1 types // @@ -55,6 +58,7 @@ pub enum Type { #[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.asn1")] #[derive(Debug)] pub struct AnnotatedType { + #[pyo3(get)] pub inner: pyo3::Py, #[pyo3(get)] pub annotation: pyo3::Py, @@ -135,6 +139,32 @@ impl Size { } } +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.asn1")] +pub struct Variant { + #[pyo3(get)] + pub python_class: pyo3::Py, + #[pyo3(get)] + pub ann_type: pyo3::Py, + #[pyo3(get)] + pub tag_name: Option, +} + +#[pyo3::pymethods] +impl Variant { + #[new] + fn new( + python_class: pyo3::Py, + ann_type: pyo3::Py, + tag_name: Option, + ) -> Self { + Self { + python_class, + ann_type, + tag_name, + } + } +} + // TODO: Once the minimum Python version is >= 3.10, use a `self_cell` // to store the owned PyString along with the dependent Asn1PrintableString // in order to avoid verifying the string twice (once during construction, @@ -419,29 +449,54 @@ pub(crate) fn python_class_to_annotated<'p>( } } -pub(crate) fn type_to_tag(t: &Type, encoding: &Option>) -> asn1::Tag { - let inner_tag = match t { - Type::Sequence(_, _) => asn1::Sequence::TAG, - Type::SequenceOf(_) => asn1::Sequence::TAG, - Type::Option(t) => type_to_tag(t.get().inner.get(), encoding), - Type::PyBool() => bool::TAG, - Type::PyInt() => asn1::BigInt::TAG, - Type::PyBytes() => <&[u8] as SimpleAsn1Readable>::TAG, - Type::PyStr() => asn1::Utf8String::TAG, - Type::PrintableString() => asn1::PrintableString::TAG, - Type::IA5String() => asn1::IA5String::TAG, - Type::ObjectIdentifier() => asn1::ObjectIdentifier::TAG, - Type::UtcTime() => asn1::UtcTime::TAG, - Type::GeneralizedTime() => asn1::GeneralizedTime::TAG, - Type::BitString() => asn1::BitString::TAG, +// Utility function to get the expected tags for an unnanotated variant. +pub(crate) fn expected_tags_for_variant(py: pyo3::Python<'_>, variant: &Variant) -> Vec { + let ann_type = variant.ann_type.get(); + expected_tags_for_type( + py, + ann_type.inner.get(), + &ann_type.annotation.get().encoding, + ) +} + +// Given a type, return the set of possible tags that we would expect +// to see when decoding it. This is usually a single tag per type, except +// when decoding a CHOICE value. +pub(crate) fn expected_tags_for_type( + py: pyo3::Python<'_>, + t: &Type, + encoding: &Option>, +) -> Vec { + let inner_tags = match t { + Type::Sequence(_, _) => vec![asn1::Sequence::TAG], + Type::SequenceOf(_) => vec![asn1::Sequence::TAG], + Type::Option(t) => expected_tags_for_type(py, t.get().inner.get(), encoding), + Type::Choice(variants) => variants + .bind(py) + .into_iter() + .flat_map(|v| expected_tags_for_variant(py, v.cast::().unwrap().get())) + .collect(), + Type::PyBool() => vec![bool::TAG], + Type::PyInt() => vec![asn1::BigInt::TAG], + Type::PyBytes() => vec![<&[u8] as SimpleAsn1Readable>::TAG], + Type::PyStr() => vec![asn1::Utf8String::TAG], + Type::PrintableString() => vec![asn1::PrintableString::TAG], + Type::IA5String() => vec![asn1::IA5String::TAG], + Type::ObjectIdentifier() => vec![asn1::ObjectIdentifier::TAG], + Type::UtcTime() => vec![asn1::UtcTime::TAG], + Type::GeneralizedTime() => vec![asn1::GeneralizedTime::TAG], + Type::BitString() => vec![asn1::BitString::TAG], }; match encoding { Some(e) => match e.get() { - Encoding::Implicit(n) => asn1::implicit_tag(*n, inner_tag), - Encoding::Explicit(n) => asn1::explicit_tag(*n), + Encoding::Implicit(n) => inner_tags + .into_iter() + .map(|x| asn1::implicit_tag(*n, x)) + .collect(), + Encoding::Explicit(n) => vec![asn1::explicit_tag(*n)], }, - None => inner_tag, + None => inner_tags, } } @@ -470,12 +525,12 @@ mod tests { use pyo3::IntoPyObject; - use super::{type_to_tag, AnnotatedType, Annotation, Type}; + use super::{expected_tags_for_type, AnnotatedType, Annotation, Type}; #[test] - // Needed for coverage of `type_to_tag(Type::Option(..))`, since - // `type_to_tag` is never called with an optional value. - fn test_option_type_to_tag() { + // Needed for coverage of `expected_tags_for_type(Type::Option(..))`, since + // `expected_tags_for_type` is never called with an optional value. + fn test_option_expected_tags_for_type() { pyo3::Python::initialize(); pyo3::Python::attach(|py| { @@ -509,8 +564,11 @@ mod tests { }, ) .unwrap(); - let expected_tag = type_to_tag(&Type::Option(optional_type), &None); - assert_eq!(expected_tag, type_to_tag(&Type::PyInt(), &None)) + let expected_tags = expected_tags_for_type(py, &Type::Option(optional_type), &None); + assert_eq!( + expected_tags, + expected_tags_for_type(py, &Type::PyInt(), &None) + ) }) } } diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 32067299f467..d010951de661 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -151,7 +151,7 @@ mod _rust { #[pymodule_export] use crate::declarative_asn1::types::{ non_root_python_to_rust, AnnotatedType, Annotation, BitString, Encoding, - GeneralizedTime, IA5String, PrintableString, Size, Type, UtcTime, + GeneralizedTime, IA5String, PrintableString, Size, Type, UtcTime, Variant, }; } diff --git a/tests/hazmat/asn1/test_api.py b/tests/hazmat/asn1/test_api.py index 6c338c051c9f..4281ded6eb17 100644 --- a/tests/hazmat/asn1/test_api.py +++ b/tests/hazmat/asn1/test_api.py @@ -215,17 +215,6 @@ class Invalid: class Example: foo: Invalid - def test_fail_unsupported_union_field(self) -> None: - with pytest.raises( - TypeError, - match="union types other than `X \\| None` are currently not " - "supported", - ): - - @asn1.sequence - class Example: - invalid: typing.Union[int, str] - def test_fail_unsupported_annotation(self) -> None: with pytest.raises( TypeError, match="unsupported annotation: some annotation" @@ -341,6 +330,10 @@ def test_fields_of_variant_type(self) -> None: seq_of = declarative_asn1.Type.SequenceOf(ann_type) assert seq_of._0 is ann_type + my_list: typing.List[int] = list() + choice = declarative_asn1.Type.Choice(my_list) + assert choice._0 is my_list + def test_fields_of_variant_encoding(self) -> None: from cryptography.hazmat.bindings._rust import declarative_asn1 @@ -350,3 +343,16 @@ def test_fields_of_variant_encoding(self) -> None: explicit = declarative_asn1.Encoding.Explicit(0) assert implicit._0 == 0 assert explicit._0 == 0 + + def test_fail_choice_with_implicit_encoding(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + "CHOICE (`X | Y | ...`) types should not have an IMPLICIT " + "annotation" + ), + ): + + @asn1.sequence + class Example: + invalid: Annotated[typing.Union[int, bool], asn1.Implicit(0)] diff --git a/tests/hazmat/asn1/test_serialization.py b/tests/hazmat/asn1/test_serialization.py index cfef4b6f96b5..ea8ecd675aae 100644 --- a/tests/hazmat/asn1/test_serialization.py +++ b/tests/hazmat/asn1/test_serialization.py @@ -601,6 +601,236 @@ class Example: ): asn1.decode_der(Example, b"\x30\x05\xa2\x03\x02\x01\x09") + def test_sequence_with_choice(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[int, bool, str] + + assert_roundtrips([(Example(foo=9), b"\x30\x03\x02\x01\x09")]) + assert_roundtrips([(Example(foo=True), b"\x30\x03\x01\x01\xff")]) + assert_roundtrips([(Example(foo="a"), b"\x30\x03\x0c\x01a")]) + + def test_sequence_with_optional_choice(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[bool, str, None] + bar: int + + assert_roundtrips( + [(Example(foo=True, bar=1), b"\x30\x06\x01\x01\xff\x02\x01\x01")] + ) + + assert_roundtrips( + [(Example(foo=None, bar=1), b"\x30\x03\x02\x01\x01")] + ) + + def test_fail_sequence_with_choice_decode_nonexistent_variant( + self, + ) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[bool, str] + + with pytest.raises( + ValueError, + match=re.escape( + "could not find matching variant when parsing CHOICE field" + ), + ): + asn1.decode_der(Example, b"\x30\x03\x02\x01\x09") + + def test_fail_sequence_with_choice_encode_nonexistent_variant( + self, + ) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[bool, str] + + with pytest.raises( + ValueError, + ): + asn1.encode_der(Example(foo=3)) # type: ignore[arg-type] + + def test_sequence_with_explicit_choice(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: Annotated[typing.Union[int, bool, str], asn1.Explicit(3)] + + assert_roundtrips([(Example(foo=9), b"\x30\x05\xa3\x03\x02\x01\x09")]) + assert_roundtrips( + [(Example(foo=True), b"\x30\x05\xa3\x03\x01\x01\xff")] + ) + assert_roundtrips([(Example(foo="a"), b"\x30\x05\xa3\x03\x0c\x01a")]) + + def test_sequence_with_choice_implicit_simple_variants(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[ + Annotated[int, asn1.Implicit(0)], + Annotated[bool, asn1.Implicit(1)], + Annotated[str, asn1.Implicit(2)], + ] + + assert_roundtrips([(Example(foo=9), b"\x30\x03\x80\x01\x09")]) + assert_roundtrips([(Example(foo=True), b"\x30\x03\x81\x01\xff")]) + assert_roundtrips([(Example(foo="a"), b"\x30\x03\x82\x01a")]) + + def test_sequence_with_choice_explicit_simple_variants(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[ + Annotated[int, asn1.Explicit(0)], + Annotated[bool, asn1.Explicit(1)], + Annotated[str, asn1.Explicit(2)], + ] + + assert_roundtrips([(Example(foo=9), b"\x30\x05\xa0\x03\x02\x01\x09")]) + assert_roundtrips( + [(Example(foo=True), b"\x30\x05\xa1\x03\x01\x01\xff")] + ) + assert_roundtrips([(Example(foo="a"), b"\x30\x05\xa2\x03\x0c\x01a")]) + + def test_sequence_with_choice_with_custom_variants(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[ + Annotated[ + asn1.Variant[int, typing.Literal["IntA"]], asn1.Implicit(0) + ], + Annotated[ + asn1.Variant[int, typing.Literal["IntB"]], asn1.Implicit(1) + ], + Annotated[ + asn1.Variant[int, typing.Literal["IntC"]], asn1.Implicit(2) + ], + ] + + assert_roundtrips( + [(Example(foo=asn1.Variant(9, "IntA")), b"\x30\x03\x80\x01\x09")] + ) + assert_roundtrips( + [(Example(foo=asn1.Variant(9, "IntB")), b"\x30\x03\x81\x01\x09")] + ) + assert_roundtrips( + [(Example(foo=asn1.Variant(9, "IntC")), b"\x30\x03\x82\x01\x09")] + ) + + def test_sequence_with_choice_with_custom_variants_bool(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[ + Annotated[ + asn1.Variant[bool, typing.Literal["BoolA"]], + asn1.Implicit(0), + ], + Annotated[ + asn1.Variant[bool, typing.Literal["BoolB"]], + asn1.Implicit(1), + ], + Annotated[ + asn1.Variant[bool, typing.Literal["BoolC"]], + asn1.Implicit(2), + ], + ] + + assert_roundtrips( + [ + ( + Example(foo=asn1.Variant(True, "BoolA")), + b"\x30\x03\x80\x01\xff", + ) + ] + ) + assert_roundtrips( + [ + ( + Example(foo=asn1.Variant(True, "BoolB")), + b"\x30\x03\x81\x01\xff", + ) + ] + ) + assert_roundtrips( + [ + ( + Example(foo=asn1.Variant(True, "BoolC")), + b"\x30\x03\x82\x01\xff", + ) + ] + ) + + def test_sequence_with_choice_with_sequence_variants(self) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: int + + @asn1.sequence + @_comparable_dataclass + class ExampleUnion: + field: typing.Union[ + Annotated[ + asn1.Variant[Example, typing.Literal["ExampleA"]], + asn1.Implicit(0), + ], + Annotated[ + asn1.Variant[Example, typing.Literal["ExampleB"]], + asn1.Implicit(1), + ], + ] + + assert_roundtrips( + [ + ( + ExampleUnion( + field=asn1.Variant(Example(foo=9), "ExampleA") + ), + b"\x30\x05\xa0\x03\x02\x01\x09", + ) + ] + ) + assert_roundtrips( + [ + ( + ExampleUnion( + field=asn1.Variant(Example(foo=9), "ExampleB") + ), + b"\x30\x05\xa1\x03\x02\x01\x09", + ) + ] + ) + + def test_sequence_with_choice_with_non_annotated_custom_variants( + self, + ) -> None: + @asn1.sequence + @_comparable_dataclass + class Example: + foo: typing.Union[ + asn1.Variant[int, typing.Literal["MyInt"]], + asn1.Variant[bool, typing.Literal["MyBool"]], + ] + + assert_roundtrips( + [(Example(foo=asn1.Variant(9, "MyInt")), b"\x30\x03\x02\x01\x09")] + ) + assert_roundtrips( + [ + ( + Example(foo=asn1.Variant(True, "MyBool")), + b"\x30\x03\x01\x01\xff", + ) + ] + ) + class TestSize: def test_ok_sequenceof_size_restriction(self) -> None: From 518a166ffdf3445e1d19a0f204535baf241e2df5 Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Sun, 25 Jan 2026 22:31:49 +0100 Subject: [PATCH 02/11] only support all Variant, or all non-Variant types Signed-off-by: Facundo Tuesca --- src/cryptography/hazmat/asn1/asn1.py | 21 +++++++++++++++++++ tests/hazmat/asn1/test_api.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index 07e29e3ea516..feceb889c703 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -187,6 +187,27 @@ def _normalize_field_type( for arg in union_args if arg is not type(None) ] + + # Union types should either be all Variants + # (`Variant[..] | Variant[..] | etc`) or all non Variants + are_union_types_tagged = variants[0].tag_name is not None + if any( + (v.tag_name is not None) != are_union_types_tagged + for v in variants + ): + raise TypeError( + "When using `asn1.Variant` in a union, all the other " + "types in the union must also be `asn1.Variant`" + ) + + if are_union_types_tagged: + tags = [v.tag_name for v in variants] + if len(tags) != len(set(tags)): + raise TypeError( + "When using `asn1.Variant` in a union, the tags used " + "must be unique" + ) + rust_choice_type = declarative_asn1.Type.Choice(variants) # If None is part of the union types, this is an OPTIONAL CHOICE rust_field_type = ( diff --git a/tests/hazmat/asn1/test_api.py b/tests/hazmat/asn1/test_api.py index 4281ded6eb17..2827945b6e2d 100644 --- a/tests/hazmat/asn1/test_api.py +++ b/tests/hazmat/asn1/test_api.py @@ -312,6 +312,37 @@ class Example2: Annotated[int, asn1.Default(value=9)], None ] + def test_fail_choice_with_inconsistent_types(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + "When using `asn1.Variant` in a union, all the other " + "types in the union must also be `asn1.Variant`" + ), + ): + + @asn1.sequence + class Example2: + invalid: typing.Union[ + int, asn1.Variant[bool, typing.Literal["myTag"]] + ] + + def test_fail_choice_with_duplicate_tags(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + "When using `asn1.Variant` in a union, the tags used " + "must be unique" + ), + ): + + @asn1.sequence + class Example2: + invalid: typing.Union[ + asn1.Variant[int, typing.Literal["myTag"]], + asn1.Variant[bool, typing.Literal["myTag"]], + ] + def test_fields_of_variant_type(self) -> None: from cryptography.hazmat.bindings._rust import declarative_asn1 From f6bdd284de131a81d004bb1d370ebd753ef5a980 Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Tue, 27 Jan 2026 19:25:33 +0100 Subject: [PATCH 03/11] Update src/cryptography/hazmat/asn1/asn1.py Co-authored-by: Alex Gaynor --- src/cryptography/hazmat/asn1/asn1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index feceb889c703..151972f52052 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -201,8 +201,8 @@ def _normalize_field_type( ) if are_union_types_tagged: - tags = [v.tag_name for v in variants] - if len(tags) != len(set(tags)): + tags = {v.tag_name for v in variants} + if len(variants) != len(tags): raise TypeError( "When using `asn1.Variant` in a union, the tags used " "must be unique" From 5d4deb7fd8241dc2342f27389a81d83f8edbabca Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Tue, 27 Jan 2026 19:19:06 +0100 Subject: [PATCH 04/11] simplify tests Signed-off-by: Facundo Tuesca --- tests/hazmat/asn1/test_serialization.py | 88 ++++++++++++++----------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/tests/hazmat/asn1/test_serialization.py b/tests/hazmat/asn1/test_serialization.py index ea8ecd675aae..6513211bd5c5 100644 --- a/tests/hazmat/asn1/test_serialization.py +++ b/tests/hazmat/asn1/test_serialization.py @@ -607,9 +607,13 @@ def test_sequence_with_choice(self) -> None: class Example: foo: typing.Union[int, bool, str] - assert_roundtrips([(Example(foo=9), b"\x30\x03\x02\x01\x09")]) - assert_roundtrips([(Example(foo=True), b"\x30\x03\x01\x01\xff")]) - assert_roundtrips([(Example(foo="a"), b"\x30\x03\x0c\x01a")]) + assert_roundtrips( + [ + (Example(foo=9), b"\x30\x03\x02\x01\x09"), + (Example(foo=True), b"\x30\x03\x01\x01\xff"), + (Example(foo="a"), b"\x30\x03\x0c\x01a"), + ] + ) def test_sequence_with_optional_choice(self) -> None: @asn1.sequence @@ -661,11 +665,13 @@ def test_sequence_with_explicit_choice(self) -> None: class Example: foo: Annotated[typing.Union[int, bool, str], asn1.Explicit(3)] - assert_roundtrips([(Example(foo=9), b"\x30\x05\xa3\x03\x02\x01\x09")]) assert_roundtrips( - [(Example(foo=True), b"\x30\x05\xa3\x03\x01\x01\xff")] + [ + (Example(foo=9), b"\x30\x05\xa3\x03\x02\x01\x09"), + (Example(foo=True), b"\x30\x05\xa3\x03\x01\x01\xff"), + (Example(foo="a"), b"\x30\x05\xa3\x03\x0c\x01a"), + ] ) - assert_roundtrips([(Example(foo="a"), b"\x30\x05\xa3\x03\x0c\x01a")]) def test_sequence_with_choice_implicit_simple_variants(self) -> None: @asn1.sequence @@ -677,9 +683,13 @@ class Example: Annotated[str, asn1.Implicit(2)], ] - assert_roundtrips([(Example(foo=9), b"\x30\x03\x80\x01\x09")]) - assert_roundtrips([(Example(foo=True), b"\x30\x03\x81\x01\xff")]) - assert_roundtrips([(Example(foo="a"), b"\x30\x03\x82\x01a")]) + assert_roundtrips( + [ + (Example(foo=9), b"\x30\x03\x80\x01\x09"), + (Example(foo=True), b"\x30\x03\x81\x01\xff"), + (Example(foo="a"), b"\x30\x03\x82\x01a"), + ] + ) def test_sequence_with_choice_explicit_simple_variants(self) -> None: @asn1.sequence @@ -691,11 +701,13 @@ class Example: Annotated[str, asn1.Explicit(2)], ] - assert_roundtrips([(Example(foo=9), b"\x30\x05\xa0\x03\x02\x01\x09")]) assert_roundtrips( - [(Example(foo=True), b"\x30\x05\xa1\x03\x01\x01\xff")] + [ + (Example(foo=9), b"\x30\x05\xa0\x03\x02\x01\x09"), + (Example(foo=True), b"\x30\x05\xa1\x03\x01\x01\xff"), + (Example(foo="a"), b"\x30\x05\xa2\x03\x0c\x01a"), + ] ) - assert_roundtrips([(Example(foo="a"), b"\x30\x05\xa2\x03\x0c\x01a")]) def test_sequence_with_choice_with_custom_variants(self) -> None: @asn1.sequence @@ -714,13 +726,20 @@ class Example: ] assert_roundtrips( - [(Example(foo=asn1.Variant(9, "IntA")), b"\x30\x03\x80\x01\x09")] - ) - assert_roundtrips( - [(Example(foo=asn1.Variant(9, "IntB")), b"\x30\x03\x81\x01\x09")] - ) - assert_roundtrips( - [(Example(foo=asn1.Variant(9, "IntC")), b"\x30\x03\x82\x01\x09")] + [ + ( + Example(foo=asn1.Variant(9, "IntA")), + b"\x30\x03\x80\x01\x09", + ), + ( + Example(foo=asn1.Variant(9, "IntB")), + b"\x30\x03\x81\x01\x09", + ), + ( + Example(foo=asn1.Variant(9, "IntC")), + b"\x30\x03\x82\x01\x09", + ), + ] ) def test_sequence_with_choice_with_custom_variants_bool(self) -> None: @@ -747,23 +766,15 @@ class Example: ( Example(foo=asn1.Variant(True, "BoolA")), b"\x30\x03\x80\x01\xff", - ) - ] - ) - assert_roundtrips( - [ + ), ( Example(foo=asn1.Variant(True, "BoolB")), b"\x30\x03\x81\x01\xff", - ) - ] - ) - assert_roundtrips( - [ + ), ( Example(foo=asn1.Variant(True, "BoolC")), b"\x30\x03\x82\x01\xff", - ) + ), ] ) @@ -794,17 +805,13 @@ class ExampleUnion: field=asn1.Variant(Example(foo=9), "ExampleA") ), b"\x30\x05\xa0\x03\x02\x01\x09", - ) - ] - ) - assert_roundtrips( - [ + ), ( ExampleUnion( field=asn1.Variant(Example(foo=9), "ExampleB") ), b"\x30\x05\xa1\x03\x02\x01\x09", - ) + ), ] ) @@ -819,15 +826,16 @@ class Example: asn1.Variant[bool, typing.Literal["MyBool"]], ] - assert_roundtrips( - [(Example(foo=asn1.Variant(9, "MyInt")), b"\x30\x03\x02\x01\x09")] - ) assert_roundtrips( [ + ( + Example(foo=asn1.Variant(9, "MyInt")), + b"\x30\x03\x02\x01\x09", + ), ( Example(foo=asn1.Variant(True, "MyBool")), b"\x30\x03\x01\x01\xff", - ) + ), ] ) From 7055eeda6c85389ef9731e57a4781d1cf4a99004 Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Tue, 27 Jan 2026 19:23:12 +0100 Subject: [PATCH 05/11] add comment about implicit CHOICEs Signed-off-by: Facundo Tuesca --- src/cryptography/hazmat/asn1/asn1.py | 1 + src/rust/src/declarative_asn1/decode.rs | 1 + src/rust/src/declarative_asn1/encode.rs | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index 151972f52052..3d5909528a4b 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -178,6 +178,7 @@ def _normalize_field_type( else: # Otherwise, the Union is a CHOICE if isinstance(annotation.encoding, Implicit): + # CHOICEs cannot be IMPLICIT. See X.680 section 31.2.9. raise TypeError( "CHOICE (`X | Y | ...`) types should not have an IMPLICIT " "annotation" diff --git a/src/rust/src/declarative_asn1/decode.rs b/src/rust/src/declarative_asn1/decode.rs index b86360af86d4..4c6543be8127 100644 --- a/src/rust/src/declarative_asn1/decode.rs +++ b/src/rust/src/declarative_asn1/decode.rs @@ -171,6 +171,7 @@ fn decode_choice_with_encoding<'a>( ) -> ParseResult> { match encoding { Encoding::Implicit(_) => Err(CryptographyError::Py( + // CHOICEs cannot be IMPLICIT. See X.680 section 31.2.9. pyo3::exceptions::PyValueError::new_err( "invalid type definition: CHOICE fields cannot be implicitly encoded".to_string(), ), diff --git a/src/rust/src/declarative_asn1/encode.rs b/src/rust/src/declarative_asn1/encode.rs index 6a6346822b56..57abc95b5883 100644 --- a/src/rust/src/declarative_asn1/encode.rs +++ b/src/rust/src/declarative_asn1/encode.rs @@ -143,7 +143,7 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> { }; match encoding { Some(e) => match e.get() { - // CHOICE cannot be encoded as IMPLICIT + // CHOICEs cannot be IMPLICIT. See X.680 section 31.2.9. Encoding::Implicit(_) => { return Err(asn1::WriteError::AllocationError) } From 9506b460ade85d7942aa5d48fd23f942d2e8f759 Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Tue, 27 Jan 2026 19:46:27 +0100 Subject: [PATCH 06/11] check that tags are Literal types Signed-off-by: Facundo Tuesca --- src/cryptography/hazmat/asn1/asn1.py | 10 ++++++++-- tests/hazmat/asn1/test_api.py | 20 ++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index 3d5909528a4b..3abce5d86457 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -53,8 +53,8 @@ class Variant(typing.Generic[U, Tag]): and need to distinguish between them: foo: ( - Annotated[Variant[int, "IntA"], Implicit(0)] - | Annotated[Variant[int, "IntB"], Implicit(1)] + Annotated[Variant[int, typing.Literal["IntA"]], Implicit(0)] + | Annotated[Variant[int, typing.Literal["IntB"]], Implicit(1)] ) Usage: @@ -243,6 +243,12 @@ def _type_to_variant( # Check if this is a Variant[T, Tag] type if get_type_origin(inner_type) is Variant: value_type, tag_literal = get_type_args(inner_type) + if get_type_origin(tag_literal) is not typing.Literal: + raise TypeError( + "When using `asn1.Variant` in a type annotation, the second " + "type parameter must be a `typing.Literal` type. E.g: " + '`Variant[int, typing.Literal["MyInt"]]`.' + ) tag_name = get_type_args(tag_literal)[0] if hasattr(value_type, "__asn1_root__"): diff --git a/tests/hazmat/asn1/test_api.py b/tests/hazmat/asn1/test_api.py index 2827945b6e2d..f0bac55bdc92 100644 --- a/tests/hazmat/asn1/test_api.py +++ b/tests/hazmat/asn1/test_api.py @@ -387,3 +387,23 @@ def test_fail_choice_with_implicit_encoding(self) -> None: @asn1.sequence class Example: invalid: Annotated[typing.Union[int, bool], asn1.Implicit(0)] + + def test_fail_choice_with_non_literal_tag(self) -> None: + with pytest.raises( + TypeError, + match=re.escape( + "When using `asn1.Variant` in a type annotation, the second " + "type parameter must be a `typing.Literal` type. E.g: " + '`Variant[int, typing.Literal["MyInt"]]`.' + ), + ): + + @asn1.sequence + class Example: + foo: typing.Union[ + Annotated[asn1.Variant[int, bool], asn1.Implicit(0)], + Annotated[ + asn1.Variant[int, typing.Literal["IntB"]], + asn1.Implicit(1), + ], + ] From 3c6db8cfce74078a64baf0fa790a6199a1ca6914 Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Tue, 27 Jan 2026 20:50:42 +0100 Subject: [PATCH 07/11] refactor tag check Signed-off-by: Facundo Tuesca --- src/rust/src/declarative_asn1/decode.rs | 11 +- src/rust/src/declarative_asn1/types.rs | 128 +++++++++++++++--------- 2 files changed, 82 insertions(+), 57 deletions(-) diff --git a/src/rust/src/declarative_asn1/decode.rs b/src/rust/src/declarative_asn1/decode.rs index 4c6543be8127..f17c3bc75bd9 100644 --- a/src/rust/src/declarative_asn1/decode.rs +++ b/src/rust/src/declarative_asn1/decode.rs @@ -7,7 +7,7 @@ use pyo3::types::{PyAnyMethods, PyListMethods}; use crate::asn1::big_byte_slice_to_py_int; use crate::declarative_asn1::types::{ - check_size_constraint, expected_tags_for_type, expected_tags_for_variant, AnnotatedType, + check_size_constraint, is_tag_valid_for_type, is_tag_valid_for_variant, AnnotatedType, Annotation, BitString, Encoding, GeneralizedTime, IA5String, PrintableString, Type, UtcTime, Variant, }; @@ -214,9 +214,8 @@ pub(crate) fn decode_annotated_type<'a>( // Handle DEFAULT annotation if field is not present (by // returning the default value) if let Some(default) = &ann_type.annotation.get().default { - let expected_tags = expected_tags_for_type(py, inner, encoding); match parser.peek_tag() { - Some(next_tag) if expected_tags.contains(&next_tag) => (), + Some(next_tag) if is_tag_valid_for_type(py, next_tag, inner, encoding) => (), _ => return Ok(default.clone_ref(py).into_bound(py)), } } @@ -252,9 +251,8 @@ pub(crate) fn decode_annotated_type<'a>( })? } Type::Option(cls) => { - let expected_tags = expected_tags_for_type(py, cls.get().inner.get(), encoding); match parser.peek_tag() { - Some(t) if expected_tags.contains(&t) => { + Some(t) if is_tag_valid_for_type(py, t, cls.get().inner.get(), encoding) => { // For optional types, annotations will always be associated to the `Optional` type // i.e: `Annotated[Optional[T], annotation]`, as opposed to the inner `T` type. // Therefore, when decoding the inner type `T` we must pass the annotation of the `Optional` @@ -272,9 +270,8 @@ pub(crate) fn decode_annotated_type<'a>( None => { for t in ts.bind(py) { let variant = t.cast::()?.get(); - let expected_tags = expected_tags_for_variant(py, variant); match parser.peek_tag() { - Some(tag) if expected_tags.contains(&tag) => { + Some(tag) if is_tag_valid_for_variant(py, tag, variant, encoding) => { let decoded_value = decode_annotated_type(py, parser, variant.ann_type.get())?; return match &variant.tag_name { diff --git a/src/rust/src/declarative_asn1/types.rs b/src/rust/src/declarative_asn1/types.rs index 072075105a62..e87c3491738d 100644 --- a/src/rust/src/declarative_asn1/types.rs +++ b/src/rust/src/declarative_asn1/types.rs @@ -449,54 +449,80 @@ pub(crate) fn python_class_to_annotated<'p>( } } -// Utility function to get the expected tags for an unnanotated variant. -pub(crate) fn expected_tags_for_variant(py: pyo3::Python<'_>, variant: &Variant) -> Vec { - let ann_type = variant.ann_type.get(); - expected_tags_for_type( - py, - ann_type.inner.get(), - &ann_type.annotation.get().encoding, - ) +// Checks if encoding `tag_without_encoding` using `encoding` results +// in `tag` +fn check_tag_with_encoding( + tag_without_encoding: asn1::Tag, + encoding: &Option>, + tag: asn1::Tag, +) -> bool { + let tag_with_encoding = match encoding { + Some(e) => match e.get() { + Encoding::Implicit(n) => asn1::implicit_tag(*n, tag_without_encoding), + Encoding::Explicit(n) => asn1::explicit_tag(*n), + }, + None => tag_without_encoding, + }; + tag_with_encoding == tag } -// Given a type, return the set of possible tags that we would expect -// to see when decoding it. This is usually a single tag per type, except -// when decoding a CHOICE value. -pub(crate) fn expected_tags_for_type( +// Utility function to see if a tag matches an unnanotated variant. +pub(crate) fn is_tag_valid_for_variant( py: pyo3::Python<'_>, - t: &Type, + tag: asn1::Tag, + variant: &Variant, encoding: &Option>, -) -> Vec { - let inner_tags = match t { - Type::Sequence(_, _) => vec![asn1::Sequence::TAG], - Type::SequenceOf(_) => vec![asn1::Sequence::TAG], - Type::Option(t) => expected_tags_for_type(py, t.get().inner.get(), encoding), - Type::Choice(variants) => variants - .bind(py) - .into_iter() - .flat_map(|v| expected_tags_for_variant(py, v.cast::().unwrap().get())) - .collect(), - Type::PyBool() => vec![bool::TAG], - Type::PyInt() => vec![asn1::BigInt::TAG], - Type::PyBytes() => vec![<&[u8] as SimpleAsn1Readable>::TAG], - Type::PyStr() => vec![asn1::Utf8String::TAG], - Type::PrintableString() => vec![asn1::PrintableString::TAG], - Type::IA5String() => vec![asn1::IA5String::TAG], - Type::ObjectIdentifier() => vec![asn1::ObjectIdentifier::TAG], - Type::UtcTime() => vec![asn1::UtcTime::TAG], - Type::GeneralizedTime() => vec![asn1::GeneralizedTime::TAG], - Type::BitString() => vec![asn1::BitString::TAG], - }; +) -> bool { + let ann_type = variant.ann_type.get(); - match encoding { + // There are two encodings at play here: the encoding of the CHOICE itself, + // and the encoding of each of the variants. The encoding of the CHOICE will + // only affect the tag if it's EXPLICIT (where it adds a wrapper). Otherwise, + // we use the encoding of the variant. + let encoding_to_match = match encoding { Some(e) => match e.get() { - Encoding::Implicit(n) => inner_tags - .into_iter() - .map(|x| asn1::implicit_tag(*n, x)) - .collect(), - Encoding::Explicit(n) => vec![asn1::explicit_tag(*n)], + Encoding::Implicit(_) => &ann_type.annotation.get().encoding, + Encoding::Explicit(_) => encoding, }, - None => inner_tags, + None => &ann_type.annotation.get().encoding, + }; + + is_tag_valid_for_type(py, tag, ann_type.inner.get(), encoding_to_match) +} + +// Given `tag` and `encoding`, returns whether that tag with that encoding +// matches what one would expect to see when decoding `type_` +pub(crate) fn is_tag_valid_for_type( + py: pyo3::Python<'_>, + tag: asn1::Tag, + type_: &Type, + encoding: &Option>, +) -> bool { + match type_ { + Type::Sequence(_, _) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag), + Type::SequenceOf(_) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag), + Type::Option(t) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding), + Type::Choice(variants) => variants.bind(py).into_iter().any(|v| { + is_tag_valid_for_variant(py, tag, v.cast::().unwrap().get(), encoding) + }), + Type::PyBool() => check_tag_with_encoding(bool::TAG, encoding, tag), + Type::PyInt() => check_tag_with_encoding(asn1::BigInt::TAG, encoding, tag), + Type::PyBytes() => { + check_tag_with_encoding(<&[u8] as SimpleAsn1Readable>::TAG, encoding, tag) + } + Type::PyStr() => check_tag_with_encoding(asn1::Utf8String::TAG, encoding, tag), + Type::PrintableString() => { + check_tag_with_encoding(asn1::PrintableString::TAG, encoding, tag) + } + Type::IA5String() => check_tag_with_encoding(asn1::IA5String::TAG, encoding, tag), + Type::ObjectIdentifier() => { + check_tag_with_encoding(asn1::ObjectIdentifier::TAG, encoding, tag) + } + Type::UtcTime() => check_tag_with_encoding(asn1::UtcTime::TAG, encoding, tag), + Type::GeneralizedTime() => { + check_tag_with_encoding(asn1::GeneralizedTime::TAG, encoding, tag) + } + Type::BitString() => check_tag_with_encoding(asn1::BitString::TAG, encoding, tag), } } @@ -523,14 +549,15 @@ pub(crate) fn check_size_constraint( #[cfg(test)] mod tests { + use asn1::SimpleAsn1Readable; use pyo3::IntoPyObject; - use super::{expected_tags_for_type, AnnotatedType, Annotation, Type}; + use super::{is_tag_valid_for_type, AnnotatedType, Annotation, Type}; #[test] - // Needed for coverage of `expected_tags_for_type(Type::Option(..))`, since - // `expected_tags_for_type` is never called with an optional value. - fn test_option_expected_tags_for_type() { + // Needed for coverage of `is_tag_valid_for_type(Type::Option(..))`, since + // `is_tag_valid_for_type` is never called with an optional value. + fn test_option_is_tag_valid_for_type() { pyo3::Python::initialize(); pyo3::Python::attach(|py| { @@ -564,11 +591,12 @@ mod tests { }, ) .unwrap(); - let expected_tags = expected_tags_for_type(py, &Type::Option(optional_type), &None); - assert_eq!( - expected_tags, - expected_tags_for_type(py, &Type::PyInt(), &None) - ) + assert!(is_tag_valid_for_type( + py, + asn1::BigInt::TAG, + &Type::Option(optional_type), + &None + )); }) } } From 35a56da09e84ca1eefb2f0d1b4ac9bff119a0da8 Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Tue, 27 Jan 2026 22:03:24 +0100 Subject: [PATCH 08/11] add tests for missing coverage Signed-off-by: Facundo Tuesca --- src/cryptography/hazmat/asn1/asn1.py | 2 +- src/rust/src/declarative_asn1/types.rs | 44 +++++++++++++++- tests/hazmat/asn1/test_serialization.py | 70 +++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 3 deletions(-) diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index 3abce5d86457..5ca6f99290d9 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -125,7 +125,7 @@ def _normalize_field_type( # from it if it exists. if get_type_origin(field_type) is Annotated: annotation = _extract_annotation(field_type.__metadata__, field_name) - field_type, _ = get_type_args(field_type) + field_type, *_ = get_type_args(field_type) else: annotation = declarative_asn1.Annotation() diff --git a/src/rust/src/declarative_asn1/types.rs b/src/rust/src/declarative_asn1/types.rs index e87c3491738d..2eecf2753efd 100644 --- a/src/rust/src/declarative_asn1/types.rs +++ b/src/rust/src/declarative_asn1/types.rs @@ -550,9 +550,12 @@ pub(crate) fn check_size_constraint( mod tests { use asn1::SimpleAsn1Readable; - use pyo3::IntoPyObject; + use pyo3::{IntoPyObject, PyTypeInfo}; - use super::{is_tag_valid_for_type, AnnotatedType, Annotation, Type}; + use super::{ + is_tag_valid_for_type, is_tag_valid_for_variant, AnnotatedType, Annotation, Encoding, Type, + Variant, + }; #[test] // Needed for coverage of `is_tag_valid_for_type(Type::Option(..))`, since @@ -599,4 +602,41 @@ mod tests { )); }) } + #[test] + // Needed for coverage of + // `is_tag_valid_for_variant(..., encoding=Encoding::Implicit)`, since + // `is_tag_valid_for_variant` is never called with an implicit encoding. + fn test_is_tag_valid_for_implicit_variant() { + pyo3::Python::initialize(); + + pyo3::Python::attach(|py| { + let ann_type = pyo3::Py::new( + py, + AnnotatedType { + inner: pyo3::Py::new(py, Type::PyInt()).unwrap(), + annotation: Annotation { + default: None, + encoding: None, + size: None, + } + .into_pyobject(py) + .unwrap() + .unbind(), + }, + ) + .unwrap(); + let variant = Variant { + python_class: pyo3::types::PyInt::type_object(py).unbind(), + ann_type, + tag_name: None, + }; + let encoding = pyo3::Py::new(py, Encoding::Implicit(3)).ok(); + assert!(is_tag_valid_for_variant( + py, + asn1::BigInt::TAG, + &variant, + &encoding + )); + }) + } } diff --git a/tests/hazmat/asn1/test_serialization.py b/tests/hazmat/asn1/test_serialization.py index 6513211bd5c5..ed70be2d77ab 100644 --- a/tests/hazmat/asn1/test_serialization.py +++ b/tests/hazmat/asn1/test_serialization.py @@ -488,6 +488,76 @@ class Example: ] ) + def test_ok_sequence_all_types_default(self) -> None: + default_time = datetime.datetime( + 2019, + 12, + 16, + 3, + 2, + 10, + tzinfo=datetime.timezone.utc, + ) + default_oid = x509.ObjectIdentifier("1.3.6.1.4.1.343") + + @asn1.sequence + @_comparable_dataclass + class Example: + a: Annotated[int, asn1.Default(3)] + b: Annotated[bytes, asn1.Default(b"\x00")] + c: Annotated[ + asn1.PrintableString, asn1.Default(asn1.PrintableString("a")) + ] + d: Annotated[ + asn1.UtcTime, + asn1.Default(asn1.UtcTime(default_time)), + ] + e: Annotated[ + asn1.GeneralizedTime, + asn1.Default(asn1.GeneralizedTime(default_time)), + ] + f: Annotated[typing.List[int], asn1.Default([1])] + g: Annotated[ + asn1.BitString, + asn1.Default( + asn1.BitString(data=b"", padding_bits=0), + ), + ] + h: Annotated[asn1.IA5String, asn1.Default(asn1.IA5String("a"))] + i: Annotated[ + x509.ObjectIdentifier, + asn1.Default(default_oid), + ] + j: Annotated[ + typing.Union[int, bool], asn1.Default(3), asn1.Explicit(0) + ] + k: Annotated[str, asn1.Default("a"), asn1.Implicit(0)] + only_field_present: Annotated[ + str, asn1.Default("a"), asn1.Implicit(1) + ] + + assert_roundtrips( + [ + ( + Example( + a=3, + b=b"\x00", + c=asn1.PrintableString("a"), + d=asn1.UtcTime(default_time), + e=asn1.GeneralizedTime(default_time), + f=[1], + g=asn1.BitString(data=b"", padding_bits=0), + h=asn1.IA5String("a"), + i=default_oid, + j=3, + k="a", + only_field_present="b", + ), + b"\x30\x03\x81\x01b", + ) + ] + ) + def test_ok_sequence_with_default_annotations(self) -> None: @asn1.sequence @_comparable_dataclass From a221748dec83bf30ebaff1d887ec7df5a9aa5d01 Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Thu, 29 Jan 2026 21:26:08 +0100 Subject: [PATCH 09/11] add bound to Tag type Signed-off-by: Facundo Tuesca --- src/cryptography/hazmat/asn1/asn1.py | 2 +- tests/hazmat/asn1/test_api.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index 5ca6f99290d9..83e0d5a90cee 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -41,7 +41,7 @@ T = typing.TypeVar("T", covariant=True) U = typing.TypeVar("U") -Tag = typing.TypeVar("Tag") +Tag = typing.TypeVar("Tag", bound=typing.LiteralString) @dataclasses.dataclass(frozen=True) diff --git a/tests/hazmat/asn1/test_api.py b/tests/hazmat/asn1/test_api.py index f0bac55bdc92..1caaa08f5a6b 100644 --- a/tests/hazmat/asn1/test_api.py +++ b/tests/hazmat/asn1/test_api.py @@ -401,7 +401,10 @@ def test_fail_choice_with_non_literal_tag(self) -> None: @asn1.sequence class Example: foo: typing.Union[ - Annotated[asn1.Variant[int, bool], asn1.Implicit(0)], + Annotated[ + asn1.Variant[int, str], + asn1.Implicit(0), + ], Annotated[ asn1.Variant[int, typing.Literal["IntB"]], asn1.Implicit(1), From 66b62b869172e8f9e29cbe7b72c9dcac8f17a16a Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Thu, 29 Jan 2026 21:49:52 +0100 Subject: [PATCH 10/11] fix LiteralString usage in Python < 3.11 Signed-off-by: Facundo Tuesca --- src/cryptography/hazmat/asn1/asn1.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index 83e0d5a90cee..9f846d0c927d 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -13,6 +13,8 @@ if sys.version_info < (3, 11): import typing_extensions + LiteralString = typing_extensions.LiteralString + # We use the `include_extras` parameter of `get_type_hints`, which was # added in Python 3.9. This can be replaced by the `typing` version # once the min version is >= 3.9 @@ -31,6 +33,7 @@ get_type_args = typing.get_args get_type_origin = typing.get_origin Annotated = typing.Annotated + LiteralString = typing.LiteralString if sys.version_info < (3, 10): NoneType = type(None) @@ -41,7 +44,7 @@ T = typing.TypeVar("T", covariant=True) U = typing.TypeVar("U") -Tag = typing.TypeVar("Tag", bound=typing.LiteralString) +Tag = typing.TypeVar("Tag", bound=LiteralString) @dataclasses.dataclass(frozen=True) From 1840dd24be24196542aef0ae914cdf9fd0d7e2b7 Mon Sep 17 00:00:00 2001 From: Facundo Tuesca Date: Sat, 31 Jan 2026 23:04:35 +0100 Subject: [PATCH 11/11] remove redundant assert_roundtrips call Signed-off-by: Facundo Tuesca --- tests/hazmat/asn1/test_serialization.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/hazmat/asn1/test_serialization.py b/tests/hazmat/asn1/test_serialization.py index 63cd492f6aea..718d0bcc4fc2 100644 --- a/tests/hazmat/asn1/test_serialization.py +++ b/tests/hazmat/asn1/test_serialization.py @@ -699,11 +699,13 @@ class Example: bar: int assert_roundtrips( - [(Example(foo=True, bar=1), b"\x30\x06\x01\x01\xff\x02\x01\x01")] - ) - - assert_roundtrips( - [(Example(foo=None, bar=1), b"\x30\x03\x02\x01\x01")] + [ + ( + Example(foo=True, bar=1), + b"\x30\x06\x01\x01\xff\x02\x01\x01", + ), + (Example(foo=None, bar=1), b"\x30\x03\x02\x01\x01"), + ] ) def test_fail_sequence_with_choice_decode_nonexistent_variant(