Skip to content

Commit a4d382e

Browse files
committed
asn1: Add support for CHOICE fields
Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com>
1 parent d59470e commit a4d382e

File tree

9 files changed

+680
-64
lines changed

9 files changed

+680
-64
lines changed

src/cryptography/hazmat/asn1/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
PrintableString,
1313
Size,
1414
UtcTime,
15+
Variant,
1516
decode_der,
1617
encode_der,
1718
sequence,
@@ -27,6 +28,7 @@
2728
"PrintableString",
2829
"Size",
2930
"UtcTime",
31+
"Variant",
3032
"decode_der",
3133
"encode_der",
3234
"sequence",

src/cryptography/hazmat/asn1/asn1.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,30 @@
4141

4242
T = typing.TypeVar("T", covariant=True)
4343
U = typing.TypeVar("U")
44+
Tag = typing.TypeVar("Tag")
45+
46+
47+
@dataclasses.dataclass(frozen=True)
48+
class Variant(typing.Generic[U, Tag]):
49+
"""
50+
A tagged variant for CHOICE fields with the same underlying type.
51+
52+
Use this when you have multiple CHOICE alternatives with the same type
53+
and need to distinguish between them:
54+
55+
foo: (
56+
Annotated[Variant[int, "IntA"], Implicit(0)]
57+
| Annotated[Variant[int, "IntB"], Implicit(1)]
58+
)
59+
60+
Usage:
61+
example = Example(foo=Variant(5, "IntA"))
62+
decoded.foo.value # The int value
63+
decoded.foo.tag # "IntA" or "IntB"
64+
"""
65+
66+
value: U
67+
tag: str
4468

4569

4670
decode_der = declarative_asn1.decode_der
@@ -150,10 +174,31 @@ def _normalize_field_type(
150174
)
151175

152176
rust_field_type = declarative_asn1.Type.Option(annotated_type)
177+
153178
else:
154-
raise TypeError(
155-
"union types other than `X | None` are currently not supported"
179+
# Otherwise, the Union is a CHOICE
180+
if isinstance(annotation.encoding, Implicit):
181+
raise TypeError(
182+
"CHOICE (`X | Y | ...`) types should not have an IMPLICIT "
183+
"annotation"
184+
)
185+
variants = [
186+
_type_to_variant(arg, field_name)
187+
for arg in union_args
188+
if arg is not type(None)
189+
]
190+
rust_choice_type = declarative_asn1.Type.Choice(variants)
191+
# If None is part of the union types, this is an OPTIONAL CHOICE
192+
rust_field_type = (
193+
declarative_asn1.Type.Option(
194+
declarative_asn1.AnnotatedType(
195+
rust_choice_type, declarative_asn1.Annotation()
196+
)
197+
)
198+
if NoneType in union_args
199+
else rust_choice_type
156200
)
201+
157202
elif get_type_origin(field_type) is builtins.list:
158203
inner_type = _normalize_field_type(
159204
get_type_args(field_type)[0], field_name
@@ -165,6 +210,45 @@ def _normalize_field_type(
165210
return declarative_asn1.AnnotatedType(rust_field_type, annotation)
166211

167212

213+
# Convert a type to a Variant. Used with types inside Union
214+
# annotations (T1, T2, etc in `Union[T1, T2, ...]`).
215+
def _type_to_variant(
216+
t: typing.Any, field_name: str
217+
) -> declarative_asn1.Variant:
218+
is_annotated = get_type_origin(t) is Annotated
219+
inner_type = get_type_args(t)[0] if is_annotated else t
220+
221+
# Check if this is a Variant[T, Tag] type
222+
if get_type_origin(inner_type) is Variant:
223+
value_type, tag_literal = get_type_args(inner_type)
224+
tag_name = get_type_args(tag_literal)[0]
225+
226+
if hasattr(value_type, "__asn1_root__"):
227+
rust_type = value_type.__asn1_root__.inner
228+
else:
229+
rust_type = declarative_asn1.non_root_python_to_rust(value_type)
230+
231+
if is_annotated:
232+
ann_type = declarative_asn1.AnnotatedType(
233+
rust_type,
234+
_extract_annotation(t.__metadata__, field_name),
235+
)
236+
else:
237+
ann_type = declarative_asn1.AnnotatedType(
238+
rust_type,
239+
declarative_asn1.Annotation(),
240+
)
241+
242+
return declarative_asn1.Variant(Variant, ann_type, tag_name)
243+
else:
244+
# Plain type (not a tagged Variant)
245+
return declarative_asn1.Variant(
246+
inner_type,
247+
_normalize_field_type(t, field_name),
248+
None,
249+
)
250+
251+
168252
def _annotate_fields(
169253
raw_fields: dict[str, type],
170254
) -> dict[str, declarative_asn1.AnnotatedType]:

src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Type:
1515
Sequence: typing.ClassVar[type]
1616
SequenceOf: typing.ClassVar[type]
1717
Option: typing.ClassVar[type]
18+
Choice: typing.ClassVar[type]
1819
PyBool: typing.ClassVar[type]
1920
PyInt: typing.ClassVar[type]
2021
PyBytes: typing.ClassVar[type]
@@ -60,6 +61,18 @@ class AnnotatedTypeObject:
6061
cls, annotated_type: AnnotatedType, value: typing.Any
6162
) -> AnnotatedTypeObject: ...
6263

64+
class Variant:
65+
python_class: type
66+
ann_type: AnnotatedType
67+
tag_name: str | None
68+
69+
def __new__(
70+
cls,
71+
python_class: type,
72+
ann_type: AnnotatedType,
73+
tag_name: str | None,
74+
) -> Variant: ...
75+
6376
class PrintableString:
6477
def __new__(cls, inner: str) -> PrintableString: ...
6578
def __repr__(self) -> str: ...

src/rust/src/declarative_asn1/decode.rs

Lines changed: 107 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ use pyo3::types::{PyAnyMethods, PyListMethods};
77

88
use crate::asn1::big_byte_slice_to_py_int;
99
use crate::declarative_asn1::types::{
10-
check_size_constraint, type_to_tag, AnnotatedType, Annotation, BitString, Encoding,
11-
GeneralizedTime, IA5String, PrintableString, Type, UtcTime,
10+
check_size_constraint, expected_tags_for_type, expected_tags_for_variant, AnnotatedType,
11+
Annotation, BitString, Encoding, GeneralizedTime, IA5String, PrintableString, Type, UtcTime,
12+
Variant,
1213
};
1314
use crate::error::CryptographyError;
1415

@@ -160,6 +161,46 @@ fn decode_bitstring<'a>(
160161
)?)
161162
}
162163

164+
// Utility function to handle explicit encoding when parsing
165+
// CHOICE fields.
166+
fn decode_choice_with_encoding<'a>(
167+
py: pyo3::Python<'a>,
168+
parser: &mut Parser<'a>,
169+
ann_type: &AnnotatedType,
170+
encoding: &Encoding,
171+
) -> ParseResult<pyo3::Bound<'a, pyo3::PyAny>> {
172+
match encoding {
173+
Encoding::Implicit(_) => Err(CryptographyError::Py(
174+
pyo3::exceptions::PyValueError::new_err(
175+
"invalid type definition: CHOICE fields cannot be implicitly encoded".to_string(),
176+
),
177+
))?,
178+
Encoding::Explicit(n) => {
179+
// Since we don't know which of the variants is present for this
180+
// CHOICE field, we'll parse this as a generic TLV encoded with
181+
// EXPLICIT, so `read_explicit_element` will consume the EXPLICIT
182+
// wrapper tag, and the TLV data will contain the variant.
183+
let tlv = parser.read_explicit_element::<asn1::Tlv<'_>>(*n)?;
184+
let type_without_explicit = AnnotatedType {
185+
inner: ann_type.inner.clone_ref(py),
186+
annotation: pyo3::Py::new(
187+
py,
188+
Annotation {
189+
default: None,
190+
encoding: None,
191+
size: None,
192+
},
193+
)?,
194+
};
195+
// Parse the TLV data (which contains the field without the EXPLICIT
196+
// wrapper)
197+
asn1::parse(tlv.full_data(), |d| {
198+
decode_annotated_type(py, d, &type_without_explicit)
199+
})
200+
}
201+
}
202+
}
203+
163204
pub(crate) fn decode_annotated_type<'a>(
164205
py: pyo3::Python<'a>,
165206
parser: &mut Parser<'a>,
@@ -172,10 +213,10 @@ pub(crate) fn decode_annotated_type<'a>(
172213
// Handle DEFAULT annotation if field is not present (by
173214
// returning the default value)
174215
if let Some(default) = &ann_type.annotation.get().default {
175-
let expected_tag = type_to_tag(inner, encoding);
176-
let next_tag = parser.peek_tag();
177-
if next_tag != Some(expected_tag) {
178-
return Ok(default.clone_ref(py).into_bound(py));
216+
let expected_tags = expected_tags_for_type(py, inner, encoding);
217+
match parser.peek_tag() {
218+
Some(next_tag) if expected_tags.contains(&next_tag) => (),
219+
_ => return Ok(default.clone_ref(py).into_bound(py)),
179220
}
180221
}
181222

@@ -210,9 +251,9 @@ pub(crate) fn decode_annotated_type<'a>(
210251
})?
211252
}
212253
Type::Option(cls) => {
213-
let inner_tag = type_to_tag(cls.get().inner.get(), encoding);
254+
let expected_tags = expected_tags_for_type(py, cls.get().inner.get(), encoding);
214255
match parser.peek_tag() {
215-
Some(t) if t == inner_tag => {
256+
Some(t) if expected_tags.contains(&t) => {
216257
// For optional types, annotations will always be associated to the `Optional` type
217258
// i.e: `Annotated[Optional[T], annotation]`, as opposed to the inner `T` type.
218259
// Therefore, when decoding the inner type `T` we must pass the annotation of the `Optional`
@@ -225,6 +266,34 @@ pub(crate) fn decode_annotated_type<'a>(
225266
_ => pyo3::types::PyNone::get(py).to_owned().into_any(),
226267
}
227268
}
269+
Type::Choice(ts) => match encoding {
270+
Some(e) => decode_choice_with_encoding(py, parser, ann_type, e.get())?,
271+
None => {
272+
for t in ts.bind(py) {
273+
let variant = t.cast::<Variant>()?.get();
274+
let expected_tags = expected_tags_for_variant(py, variant);
275+
match parser.peek_tag() {
276+
Some(tag) if expected_tags.contains(&tag) => {
277+
let decoded_value =
278+
decode_annotated_type(py, parser, variant.ann_type.get())?;
279+
return match &variant.tag_name {
280+
Some(tag_name) => Ok(variant
281+
.python_class
282+
.call1(py, (decoded_value, tag_name))?
283+
.into_bound(py)),
284+
None => Ok(decoded_value),
285+
};
286+
}
287+
_ => continue,
288+
}
289+
}
290+
Err(CryptographyError::Py(
291+
pyo3::exceptions::PyValueError::new_err(
292+
"could not find matching variant when parsing CHOICE field".to_string(),
293+
),
294+
))?
295+
}
296+
},
228297
Type::PyBool() => decode_pybool(py, parser, encoding)?.into_any(),
229298
Type::PyInt() => decode_pyint(py, parser, encoding)?.into_any(),
230299
Type::PyBytes() => decode_pybytes(py, parser, annotation)?.into_any(),
@@ -246,3 +315,33 @@ pub(crate) fn decode_annotated_type<'a>(
246315
_ => Ok(decoded),
247316
}
248317
}
318+
319+
#[cfg(test)]
320+
mod tests {
321+
use crate::declarative_asn1::types::{AnnotatedType, Annotation, Encoding, Type, Variant};
322+
#[test]
323+
fn test_decode_implicit_choice() {
324+
pyo3::Python::initialize();
325+
pyo3::Python::attach(|py| {
326+
let result = asn1::parse(&[], |parser| {
327+
let variants: Vec<Variant> = vec![];
328+
let choice = Type::Choice(pyo3::types::PyList::new(py, variants)?.unbind());
329+
let annotation = Annotation {
330+
default: None,
331+
encoding: None,
332+
size: None,
333+
};
334+
let ann_type = AnnotatedType {
335+
inner: pyo3::Py::new(py, choice)?,
336+
annotation: pyo3::Py::new(py, annotation)?,
337+
};
338+
let encoding = Encoding::Implicit(0);
339+
super::decode_choice_with_encoding(py, parser, &ann_type, &encoding)
340+
});
341+
assert!(result.is_err());
342+
let error = result.unwrap_err();
343+
assert!(format!("{error}")
344+
.contains("invalid type definition: CHOICE fields cannot be implicitly encoded"));
345+
});
346+
}
347+
}

0 commit comments

Comments
 (0)