Skip to content
2 changes: 2 additions & 0 deletions src/cryptography/hazmat/asn1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
PrintableString,
Size,
UtcTime,
Variant,
decode_der,
encode_der,
sequence,
Expand All @@ -27,6 +28,7 @@
"PrintableString",
"Size",
"UtcTime",
"Variant",
"decode_der",
"encode_der",
"sequence",
Expand Down
121 changes: 118 additions & 3 deletions src/cryptography/hazmat/asn1/asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
if sys.version_info < (3, 11):
import typing_extensions

LiteralString = typing_extensions.LiteralString

# We use the `include_extras` parameter of `get_type_hints`, which was
# added in Python 3.9. This can be replaced by the `typing` version
# once the min version is >= 3.9
Expand All @@ -31,6 +33,7 @@
get_type_args = typing.get_args
get_type_origin = typing.get_origin
Annotated = typing.Annotated
LiteralString = typing.LiteralString

if sys.version_info < (3, 10):
NoneType = type(None)
Expand All @@ -41,6 +44,30 @@

T = typing.TypeVar("T", covariant=True)
U = typing.TypeVar("U")
Tag = typing.TypeVar("Tag", bound=LiteralString)


@dataclasses.dataclass(frozen=True)
class Variant(typing.Generic[U, Tag]):
"""
A tagged variant for CHOICE fields with the same underlying type.

Use this when you have multiple CHOICE alternatives with the same type
and need to distinguish between them:

foo: (
Annotated[Variant[int, typing.Literal["IntA"]], Implicit(0)]
| Annotated[Variant[int, typing.Literal["IntB"]], Implicit(1)]
)

Usage:
example = Example(foo=Variant(5, "IntA"))
decoded.foo.value # The int value
decoded.foo.tag # "IntA" or "IntB"
"""

value: U
tag: str


decode_der = declarative_asn1.decode_der
Expand Down Expand Up @@ -101,7 +128,7 @@ def _normalize_field_type(
# from it if it exists.
if get_type_origin(field_type) is Annotated:
annotation = _extract_annotation(field_type.__metadata__, field_name)
field_type, _ = get_type_args(field_type)
field_type, *_ = get_type_args(field_type)
else:
annotation = declarative_asn1.Annotation()

Expand Down Expand Up @@ -150,10 +177,53 @@ def _normalize_field_type(
)

rust_field_type = declarative_asn1.Type.Option(annotated_type)

else:
raise TypeError(
"union types other than `X | None` are currently not supported"
# Otherwise, the Union is a CHOICE
if isinstance(annotation.encoding, Implicit):
# CHOICEs cannot be IMPLICIT. See X.680 section 31.2.9.
raise TypeError(
"CHOICE (`X | Y | ...`) types should not have an IMPLICIT "
"annotation"
)
variants = [
_type_to_variant(arg, field_name)
for arg in union_args
if arg is not type(None)
]

# Union types should either be all Variants
# (`Variant[..] | Variant[..] | etc`) or all non Variants
are_union_types_tagged = variants[0].tag_name is not None
if any(
(v.tag_name is not None) != are_union_types_tagged
for v in variants
):
raise TypeError(
"When using `asn1.Variant` in a union, all the other "
"types in the union must also be `asn1.Variant`"
)

if are_union_types_tagged:
tags = {v.tag_name for v in variants}
if len(variants) != len(tags):
raise TypeError(
"When using `asn1.Variant` in a union, the tags used "
"must be unique"
)

rust_choice_type = declarative_asn1.Type.Choice(variants)
# If None is part of the union types, this is an OPTIONAL CHOICE
rust_field_type = (
declarative_asn1.Type.Option(
declarative_asn1.AnnotatedType(
rust_choice_type, declarative_asn1.Annotation()
)
)
if NoneType in union_args
else rust_choice_type
)

elif get_type_origin(field_type) is builtins.list:
inner_type = _normalize_field_type(
get_type_args(field_type)[0], field_name
Expand All @@ -165,6 +235,51 @@ def _normalize_field_type(
return declarative_asn1.AnnotatedType(rust_field_type, annotation)


# Convert a type to a Variant. Used with types inside Union
# annotations (T1, T2, etc in `Union[T1, T2, ...]`).
def _type_to_variant(
t: typing.Any, field_name: str
) -> declarative_asn1.Variant:
is_annotated = get_type_origin(t) is Annotated
inner_type = get_type_args(t)[0] if is_annotated else t

# Check if this is a Variant[T, Tag] type
if get_type_origin(inner_type) is Variant:
value_type, tag_literal = get_type_args(inner_type)
if get_type_origin(tag_literal) is not typing.Literal:
raise TypeError(
"When using `asn1.Variant` in a type annotation, the second "
"type parameter must be a `typing.Literal` type. E.g: "
'`Variant[int, typing.Literal["MyInt"]]`.'
)
tag_name = get_type_args(tag_literal)[0]

if hasattr(value_type, "__asn1_root__"):
rust_type = value_type.__asn1_root__.inner
else:
rust_type = declarative_asn1.non_root_python_to_rust(value_type)

if is_annotated:
ann_type = declarative_asn1.AnnotatedType(
rust_type,
_extract_annotation(t.__metadata__, field_name),
)
else:
ann_type = declarative_asn1.AnnotatedType(
rust_type,
declarative_asn1.Annotation(),
)

return declarative_asn1.Variant(Variant, ann_type, tag_name)
else:
# Plain type (not a tagged Variant)
return declarative_asn1.Variant(
inner_type,
_normalize_field_type(t, field_name),
None,
)


def _annotate_fields(
raw_fields: dict[str, type],
) -> dict[str, declarative_asn1.AnnotatedType]:
Expand Down
13 changes: 13 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Type:
Sequence: typing.ClassVar[type]
SequenceOf: typing.ClassVar[type]
Option: typing.ClassVar[type]
Choice: typing.ClassVar[type]
PyBool: typing.ClassVar[type]
PyInt: typing.ClassVar[type]
PyBytes: typing.ClassVar[type]
Expand Down Expand Up @@ -60,6 +61,18 @@ class AnnotatedTypeObject:
cls, annotated_type: AnnotatedType, value: typing.Any
) -> AnnotatedTypeObject: ...

class Variant:
python_class: type
ann_type: AnnotatedType
tag_name: str | None

def __new__(
cls,
python_class: type,
ann_type: AnnotatedType,
tag_name: str | None,
) -> Variant: ...

class PrintableString:
def __new__(cls, inner: str) -> PrintableString: ...
def __repr__(self) -> str: ...
Expand Down
107 changes: 103 additions & 4 deletions src/rust/src/declarative_asn1/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ use pyo3::types::{PyAnyMethods, PyListMethods};

use crate::asn1::big_byte_slice_to_py_int;
use crate::declarative_asn1::types::{
check_size_constraint, is_tag_valid_for_type, AnnotatedType, Annotation, BitString, Encoding,
GeneralizedTime, IA5String, PrintableString, Type, UtcTime,
check_size_constraint, is_tag_valid_for_type, is_tag_valid_for_variant, AnnotatedType,
Annotation, BitString, Encoding, GeneralizedTime, IA5String, PrintableString, Type, UtcTime,
Variant,
};
use crate::error::CryptographyError;

Expand Down Expand Up @@ -160,6 +161,47 @@ fn decode_bitstring<'a>(
)?)
}

// Utility function to handle explicit encoding when parsing
// CHOICE fields.
fn decode_choice_with_encoding<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
ann_type: &AnnotatedType,
encoding: &Encoding,
) -> ParseResult<pyo3::Bound<'a, pyo3::PyAny>> {
match encoding {
Encoding::Implicit(_) => Err(CryptographyError::Py(
// CHOICEs cannot be IMPLICIT. See X.680 section 31.2.9.
pyo3::exceptions::PyValueError::new_err(
"invalid type definition: CHOICE fields cannot be implicitly encoded".to_string(),
),
))?,
Encoding::Explicit(n) => {
// Since we don't know which of the variants is present for this
// CHOICE field, we'll parse this as a generic TLV encoded with
// EXPLICIT, so `read_explicit_element` will consume the EXPLICIT
// wrapper tag, and the TLV data will contain the variant.
let tlv = parser.read_explicit_element::<asn1::Tlv<'_>>(*n)?;
let type_without_explicit = AnnotatedType {
inner: ann_type.inner.clone_ref(py),
annotation: pyo3::Py::new(
py,
Annotation {
default: None,
encoding: None,
size: None,
},
)?,
};
// Parse the TLV data (which contains the field without the EXPLICIT
// wrapper)
asn1::parse(tlv.full_data(), |d| {
decode_annotated_type(py, d, &type_without_explicit)
})
}
}
}

pub(crate) fn decode_annotated_type<'a>(
py: pyo3::Python<'a>,
parser: &mut Parser<'a>,
Expand All @@ -173,7 +215,7 @@ pub(crate) fn decode_annotated_type<'a>(
// returning the default value)
if let Some(default) = &ann_type.annotation.get().default {
match parser.peek_tag() {
Some(next_tag) if is_tag_valid_for_type(next_tag, inner, encoding) => (),
Some(next_tag) if is_tag_valid_for_type(py, next_tag, inner, encoding) => (),
_ => return Ok(default.clone_ref(py).into_bound(py)),
}
}
Expand Down Expand Up @@ -210,7 +252,7 @@ pub(crate) fn decode_annotated_type<'a>(
}
Type::Option(cls) => {
match parser.peek_tag() {
Some(t) if is_tag_valid_for_type(t, cls.get().inner.get(), encoding) => {
Some(t) if is_tag_valid_for_type(py, t, cls.get().inner.get(), encoding) => {
// For optional types, annotations will always be associated to the `Optional` type
// i.e: `Annotated[Optional[T], annotation]`, as opposed to the inner `T` type.
// Therefore, when decoding the inner type `T` we must pass the annotation of the `Optional`
Expand All @@ -223,6 +265,33 @@ pub(crate) fn decode_annotated_type<'a>(
_ => pyo3::types::PyNone::get(py).to_owned().into_any(),
}
}
Type::Choice(ts) => match encoding {
Some(e) => decode_choice_with_encoding(py, parser, ann_type, e.get())?,
None => {
for t in ts.bind(py) {
let variant = t.cast::<Variant>()?.get();
match parser.peek_tag() {
Some(tag) if is_tag_valid_for_variant(py, tag, variant, encoding) => {
let decoded_value =
decode_annotated_type(py, parser, variant.ann_type.get())?;
return match &variant.tag_name {
Some(tag_name) => Ok(variant
.python_class
.call1(py, (decoded_value, tag_name))?
.into_bound(py)),
None => Ok(decoded_value),
};
}
_ => continue,
}
}
Err(CryptographyError::Py(
pyo3::exceptions::PyValueError::new_err(
"could not find matching variant when parsing CHOICE field".to_string(),
),
))?
}
},
Type::PyBool() => decode_pybool(py, parser, encoding)?.into_any(),
Type::PyInt() => decode_pyint(py, parser, encoding)?.into_any(),
Type::PyBytes() => decode_pybytes(py, parser, annotation)?.into_any(),
Expand All @@ -244,3 +313,33 @@ pub(crate) fn decode_annotated_type<'a>(
_ => Ok(decoded),
}
}

#[cfg(test)]
mod tests {
use crate::declarative_asn1::types::{AnnotatedType, Annotation, Encoding, Type, Variant};
#[test]
fn test_decode_implicit_choice() {
pyo3::Python::initialize();
pyo3::Python::attach(|py| {
let result = asn1::parse(&[], |parser| {
let variants: Vec<Variant> = vec![];
let choice = Type::Choice(pyo3::types::PyList::new(py, variants)?.unbind());
let annotation = Annotation {
default: None,
encoding: None,
size: None,
};
let ann_type = AnnotatedType {
inner: pyo3::Py::new(py, choice)?,
annotation: pyo3::Py::new(py, annotation)?,
};
let encoding = Encoding::Implicit(0);
super::decode_choice_with_encoding(py, parser, &ann_type, &encoding)
});
assert!(result.is_err());
let error = result.unwrap_err();
assert!(format!("{error}")
.contains("invalid type definition: CHOICE fields cannot be implicitly encoded"));
});
}
}
Loading