diff --git a/asn1_derive/src/lib.rs b/asn1_derive/src/lib.rs index 89bd2ad3..9b7c219d 100644 --- a/asn1_derive/src/lib.rs +++ b/asn1_derive/src/lib.rs @@ -63,10 +63,25 @@ fn derive_asn1_read_expand(input: syn::DeriveInput) -> syn::Result syn::Result { + for attr in &input.attrs { + if attr.path().is_ident("error_type") { + let err_type: syn::Type = attr.parse_args()?; + return Ok(err_type); + } + } + Err(syn::Error::new_spanned( + input, + "Error type for asn1::Asn1Writable/Asn1DefinedByWritable implementation was not specified", + )) +} + +#[proc_macro_derive( + Asn1Write, + attributes(explicit, implicit, default, defined_by, error_type) +)] pub fn derive_asn1_write(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = syn::parse_macro_input!(input as syn::DeriveInput); - derive_asn1_write_expand(input) .unwrap_or_else(syn::Error::into_compile_error) .into() @@ -75,11 +90,20 @@ pub fn derive_asn1_write(input: proc_macro::TokenStream) -> proc_macro::TokenStr fn derive_asn1_write_expand(mut input: syn::DeriveInput) -> syn::Result { let name = &input.ident; let fields = all_field_types(&input.data, false, &input.generics)?; + let err_type = get_err_attribute(&input)?; add_bounds( + &mut input.generics, + fields.clone(), + syn::parse_quote!(asn1::Asn1Writable), + syn::parse_quote!(asn1::Asn1DefinedByWritable), + true, + ); + add_write_error_bounds( &mut input.generics, fields, syn::parse_quote!(asn1::Asn1Writable), syn::parse_quote!(asn1::Asn1DefinedByWritable), + err_type.clone(), true, ); let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); @@ -89,8 +113,9 @@ fn derive_asn1_write_expand(mut input: syn::DeriveInput) -> syn::Result::TAG; - fn write_data(&self, dest: &mut asn1::WriteBuf) -> asn1::WriteResult { + fn write_data(&self, dest: &mut asn1::WriteBuf) -> Result<(), Self::Error> { #write_block Ok(()) @@ -106,7 +131,8 @@ fn derive_asn1_write_expand(mut input: syn::DeriveInput) -> syn::Result asn1::WriteResult { + type Error = #err_type; + fn write(&self, w: &mut asn1::Writer) -> Result<(), Self::Error> { #write_block } fn encoded_length(&self) -> Option { @@ -115,7 +141,8 @@ fn derive_asn1_write_expand(mut input: syn::DeriveInput) -> syn::Result asn1::WriteResult { + type Error = #err_type; + fn write(&self, w: &mut asn1::Writer) -> Result<(), Self::Error> { (*self).write(w) } @@ -256,7 +283,7 @@ fn derive_asn1_defined_by_read_expand( }) } -#[proc_macro_derive(Asn1DefinedByWrite, attributes(default, defined_by))] +#[proc_macro_derive(Asn1DefinedByWrite, attributes(default, defined_by, error_type))] pub fn derive_asn1_defined_by_write(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = syn::parse_macro_input!(input as syn::DeriveInput); @@ -270,11 +297,20 @@ fn derive_asn1_defined_by_write_expand( ) -> syn::Result { let name = &input.ident; let fields = all_field_types(&input.data, true, &input.generics)?; + let err_type = get_err_attribute(&input)?; add_bounds( + &mut input.generics, + fields.clone(), + syn::parse_quote!(asn1::Asn1Writable), + syn::parse_quote!(asn1::Asn1DefinedByWritable), + true, + ); + add_write_error_bounds( &mut input.generics, fields, syn::parse_quote!(asn1::Asn1Writable), syn::parse_quote!(asn1::Asn1DefinedByWritable), + err_type.clone(), true, ); let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); @@ -329,13 +365,14 @@ fn derive_asn1_defined_by_write_expand( Ok(quote::quote! { impl #impl_generics asn1::Asn1DefinedByWritable for #name #ty_generics #where_clause { + type Error = #err_type; fn item(&self) -> &asn1::ObjectIdentifier { match self { #(#item_blocks)* } } - fn write(&self, w: &mut asn1::Writer) -> asn1::WriteResult { + fn write(&self, w: &mut asn1::Writer) -> Result<(), Self::Error> { match self { #(#write_blocks)* } @@ -604,6 +641,100 @@ fn add_bounds( } } +fn add_write_error_bounds( + generics: &mut syn::Generics, + field_types: Vec<(syn::Type, OpType, bool)>, + bound: syn::TypeParamBound, + defined_by_bound: syn::TypeParamBound, + write_error_type: syn::Type, + add_ref: bool, +) { + let where_clause = if field_types.is_empty() { + return; + } else { + generics + .where_clause + .get_or_insert_with(|| syn::WhereClause { + where_token: Default::default(), + predicates: syn::punctuated::Punctuated::new(), + }) + }; + for (f, op_type, has_default) in &field_types { + let (bounded_ty, lifetimes): (syn::Type, Option) = match ( + op_type, add_ref, + ) { + (OpType::Regular, _) => (syn::parse_quote!(<#f as #bound>::Error), None), + (OpType::DefinedBy(_), _) => { + (syn::parse_quote!(<#f as #defined_by_bound>::Error), None) + } + (OpType::Implicit(OpTypeArgs { value, required }), false) => { + if *required || *has_default { + ( + syn::parse_quote!( as #bound>::Error), + None, + ) + } else { + ( + syn::parse_quote!(::T, #value> as #bound>::Error), + None, + ) + } + } + (OpType::Implicit(OpTypeArgs { value, required }), true) => { + if *required || *has_default { + ( + syn::parse_quote!( as #bound>::Error), + Some(syn::parse_quote!(for<'asn1_internal>)), + ) + } else { + ( + syn::parse_quote!(::T, #value> as #bound>::Error), + Some(syn::parse_quote!(for<'asn1_internal>)), + ) + } + } + (OpType::Explicit(OpTypeArgs { value, required }), false) => { + if *required || *has_default { + ( + syn::parse_quote!( as #bound>::Error), + None, + ) + } else { + ( + syn::parse_quote!(::T, #value> as #bound>::Error), + None, + ) + } + } + (OpType::Explicit(OpTypeArgs { value, required }), true) => { + if *required || *has_default { + ( + syn::parse_quote!( as #bound>::Error), + Some(syn::parse_quote!(for<'asn1_internal>)), + ) + } else { + ( + syn::parse_quote!(::T, #value> as #bound>::Error), + Some(syn::parse_quote!(for<'asn1_internal>)), + ) + } + } + }; + where_clause + .predicates + .push(syn::WherePredicate::Type(syn::PredicateType { + lifetimes, + bounded_ty, + colon_token: Default::default(), + bounds: { + let mut p = syn::punctuated::Punctuated::new(); + p.push(syn::parse_quote!(Into<#write_error_type>)); + p + }, + })); + } +} + #[derive(Clone)] enum OpType { Regular, diff --git a/examples/no_std.rs b/examples/no_std.rs index b8a92678..5d8c4036 100644 --- a/examples/no_std.rs +++ b/examples/no_std.rs @@ -17,11 +17,13 @@ fn main() { }; let computed = asn1::write(|w| { - w.write_element(&asn1::SequenceWriter::new(&|w: &mut asn1::Writer| { - w.write_element(&1i64)?; - w.write_element(&3i64)?; - Ok(()) - })) + w.write_element(&asn1::SequenceWriter::::new( + &|w: &mut asn1::Writer| { + w.write_element(&1i64)?; + w.write_element(&3i64)?; + Ok(()) + }, + )) }) .unwrap(); unsafe { diff --git a/src/types.rs b/src/types.rs index b874b130..c42a0cc8 100644 --- a/src/types.rs +++ b/src/types.rs @@ -12,7 +12,7 @@ use core::mem; use crate::writer::Writer; use crate::{ parse, parse_single, BitString, ObjectIdentifier, OwnedBitString, ParseError, ParseErrorKind, - ParseLocation, ParseResult, Parser, Tag, WriteBuf, WriteResult, + ParseLocation, ParseResult, Parser, Tag, WriteBuf, WriteError, WriteResult, }; /// Any type that can be parsed as DER ASN.1. @@ -67,11 +67,14 @@ impl<'a, T: SimpleAsn1Readable<'a>> SimpleAsn1Readable<'a> for Box { /// Any type that can be written as DER ASN.1. pub trait Asn1Writable: Sized { + /// Error type for `Self::write()` + type Error: From; + /// Write this value to the given writer. /// /// This method should write the complete ASN.1 encoding of this value, /// including the tag, length, and content bytes. - fn write(&self, dest: &mut Writer<'_>) -> WriteResult; + fn write(&self, dest: &mut Writer<'_>) -> Result<(), Self::Error>; /// Get the complete encoded length (tag + length + content), if it can be /// calculated efficiently. @@ -84,6 +87,9 @@ pub trait Asn1Writable: Sized { /// Types with a fixed-tag that can be written as DER ASN.1. pub trait SimpleAsn1Writable: Sized { + /// Error type for `Self::write_data()` + type Error: From; + /// The ASN.1 tag that this type uses when writing. const TAG: Tag; @@ -91,7 +97,7 @@ pub trait SimpleAsn1Writable: Sized { /// /// This method should write only the value bytes (without the tag and /// length) to the buffer. - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult; + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error>; /// Get the length of the data content (without tag and length bytes) if it /// can be calculated efficiently. @@ -117,11 +123,14 @@ pub trait Asn1DefinedByReadable<'a, T: Asn1Readable<'a>>: Sized { /// /// `T` is the type of the `DEFINED BY` field (nearly always `ObjectIdentifier`). pub trait Asn1DefinedByWritable: Sized { + /// Error type for `Self::write()` + type Error: From; + /// Get a reference to the `DEFINED BY` value. fn item(&self) -> &T; /// Write this value to the given writer. - fn write(&self, dest: &mut Writer<'_>) -> WriteResult; + fn write(&self, dest: &mut Writer<'_>) -> Result<(), Self::Error>; /// Get the complete encoded length (tag + length + content), if it can be /// calculated efficiently. @@ -133,8 +142,10 @@ pub trait Asn1DefinedByWritable: Sized { } impl Asn1Writable for T { + type Error = T::Error; + #[inline] - fn write(&self, w: &mut Writer<'_>) -> WriteResult { + fn write(&self, w: &mut Writer<'_>) -> Result<(), Self::Error> { w.write_tlv(Self::TAG, self.data_length(), move |dest| { self.write_data(dest) }) @@ -146,8 +157,10 @@ impl Asn1Writable for T { } impl SimpleAsn1Writable for &T { + type Error = T::Error; const TAG: Tag = T::TAG; - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { T::write_data(self, dest) } @@ -157,8 +170,10 @@ impl SimpleAsn1Writable for &T { } impl SimpleAsn1Writable for Box { + type Error = T::Error; const TAG: Tag = T::TAG; - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { T::write_data(self, dest) } @@ -214,6 +229,8 @@ impl<'a> Asn1Readable<'a> for Tlv<'a> { } } impl Asn1Writable for Tlv<'_> { + type Error = WriteError; + #[inline] fn write(&self, w: &mut Writer<'_>) -> WriteResult { w.write_tlv(self.tag, Some(self.data.len()), move |dest| { @@ -227,6 +244,8 @@ impl Asn1Writable for Tlv<'_> { } impl Asn1Writable for &Tlv<'_> { + type Error = WriteError; + #[inline] fn write(&self, w: &mut Writer<'_>) -> WriteResult { Tlv::write(self, w) @@ -254,7 +273,9 @@ impl SimpleAsn1Readable<'_> for Null { } impl SimpleAsn1Writable for Null { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x05); + #[inline] fn write_data(&self, _dest: &mut WriteBuf) -> WriteResult { Ok(()) @@ -277,7 +298,9 @@ impl SimpleAsn1Readable<'_> for bool { } impl SimpleAsn1Writable for bool { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x1); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { if *self { dest.push_byte(0xff) @@ -299,7 +322,9 @@ impl<'a> SimpleAsn1Readable<'a> for &'a [u8] { } impl SimpleAsn1Writable for &[u8] { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x04); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self) } @@ -318,7 +343,9 @@ impl SimpleAsn1Readable<'_> for [u8; N] { } impl SimpleAsn1Writable for [u8; N] { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x04); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self) } @@ -355,8 +382,9 @@ impl<'a, T: Asn1Readable<'a>> SimpleAsn1Readable<'a> for OctetStringEncoded { } impl SimpleAsn1Writable for OctetStringEncoded { + type Error = T::Error; const TAG: Tag = Tag::primitive(0x04); - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { self.0.write(&mut Writer::new(dest)) } @@ -430,7 +458,9 @@ impl<'a> SimpleAsn1Readable<'a> for PrintableString<'a> { } impl SimpleAsn1Writable for PrintableString<'_> { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x13); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.0.as_bytes()) } @@ -483,7 +513,9 @@ impl<'a> SimpleAsn1Readable<'a> for IA5String<'a> { } } impl SimpleAsn1Writable for IA5String<'_> { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x16); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.0.as_bytes()) } @@ -520,7 +552,9 @@ impl<'a> SimpleAsn1Readable<'a> for Utf8String<'a> { } } impl SimpleAsn1Writable for Utf8String<'_> { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x0c); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.0.as_bytes()) } @@ -579,7 +613,9 @@ impl<'a> SimpleAsn1Readable<'a> for VisibleString<'a> { } } impl SimpleAsn1Writable for VisibleString<'_> { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x1a); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.0.as_bytes()) } @@ -633,7 +669,9 @@ impl<'a> SimpleAsn1Readable<'a> for BMPString<'a> { } } impl SimpleAsn1Writable for BMPString<'_> { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x1e); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.as_utf16_be_bytes()) } @@ -687,7 +725,9 @@ impl<'a> SimpleAsn1Readable<'a> for UniversalString<'a> { } } impl SimpleAsn1Writable for UniversalString<'_> { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x1c); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.as_utf32_be_bytes()) } @@ -743,8 +783,10 @@ macro_rules! impl_asn1_element_for_int { } } impl SimpleAsn1Writable for $t { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x02); - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { let num_bytes = self.data_length().unwrap() as u32; for i in (1..=num_bytes).rev() { @@ -810,7 +852,9 @@ impl<'a> SimpleAsn1Readable<'a> for BigUint<'a> { } } impl SimpleAsn1Writable for BigUint<'_> { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x02); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.as_bytes()) } @@ -852,7 +896,9 @@ impl SimpleAsn1Readable<'_> for OwnedBigUint { } } impl SimpleAsn1Writable for OwnedBigUint { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x02); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.as_bytes()) } @@ -900,7 +946,9 @@ impl<'a> SimpleAsn1Readable<'a> for BigInt<'a> { } } impl SimpleAsn1Writable for BigInt<'_> { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x02); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.as_bytes()) } @@ -946,7 +994,9 @@ impl SimpleAsn1Readable<'_> for OwnedBigInt { } } impl SimpleAsn1Writable for OwnedBigInt { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x02); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.as_bytes()) } @@ -963,7 +1013,9 @@ impl<'a> SimpleAsn1Readable<'a> for ObjectIdentifier { } } impl SimpleAsn1Writable for ObjectIdentifier { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x06); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_slice(self.as_der()) } @@ -983,7 +1035,9 @@ impl<'a> SimpleAsn1Readable<'a> for BitString<'a> { } } impl SimpleAsn1Writable for BitString<'_> { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x03); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { dest.push_byte(self.padding_bits())?; dest.push_slice(self.as_bytes()) @@ -1001,7 +1055,9 @@ impl<'a> SimpleAsn1Readable<'a> for OwnedBitString { } } impl SimpleAsn1Writable for OwnedBitString { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x03); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { self.as_bitstring().write_data(dest) } @@ -1193,7 +1249,9 @@ impl SimpleAsn1Readable<'_> for UtcTime { } impl SimpleAsn1Writable for UtcTime { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x17); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { let dt = self.as_datetime(); let year = if 1950 <= dt.year() && dt.year() < 2000 { @@ -1252,7 +1310,9 @@ impl SimpleAsn1Readable<'_> for X509GeneralizedTime { } impl SimpleAsn1Writable for X509GeneralizedTime { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x18); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { let dt = self.as_datetime(); push_four_digits(dest, dt.year())?; @@ -1360,7 +1420,9 @@ impl SimpleAsn1Readable<'_> for GeneralizedTime { } impl SimpleAsn1Writable for GeneralizedTime { + type Error = WriteError; const TAG: Tag = Tag::primitive(0x18); + fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { let dt = self.as_datetime(); push_four_digits(dest, dt.year())?; @@ -1426,6 +1488,7 @@ impl<'a> SimpleAsn1Readable<'a> for Enumerated { } impl SimpleAsn1Writable for Enumerated { + type Error = WriteError; const TAG: Tag = Tag::primitive(0xa); fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { @@ -1452,8 +1515,10 @@ impl<'a, T: Asn1Readable<'a>> Asn1Readable<'a> for Option { } impl Asn1Writable for Option { + type Error = T::Error; + #[inline] - fn write(&self, w: &mut Writer<'_>) -> WriteResult { + fn write(&self, w: &mut Writer<'_>) -> Result<(), Self::Error> { if let Some(v) = self { w.write_element(v) } else { @@ -1515,9 +1580,11 @@ macro_rules! declare_choice { impl< $( - $number: Asn1Writable, + $number: Asn1Writable, )+ > Asn1Writable for $count<$($number,)+> { + type Error = WriteError; + fn write(&self, w: &mut Writer<'_>) -> WriteResult { match self { $( @@ -1576,7 +1643,9 @@ impl<'a> SimpleAsn1Readable<'a> for Sequence<'a> { } } impl SimpleAsn1Writable for Sequence<'_> { + type Error = WriteError; const TAG: Tag = Tag::constructed(0x10); + #[inline] fn write_data(&self, data: &mut WriteBuf) -> WriteResult { data.push_slice(self.data) @@ -1589,21 +1658,23 @@ impl SimpleAsn1Writable for Sequence<'_> { /// Writes an ASN.1 `SEQUENCE` using a callback that writes the inner /// elements. -pub struct SequenceWriter<'a> { - f: &'a dyn Fn(&mut Writer<'_>) -> WriteResult, +pub struct SequenceWriter<'a, E: From = WriteError> { + f: &'a dyn Fn(&mut Writer<'_>) -> Result<(), E>, } -impl<'a> SequenceWriter<'a> { +impl<'a, E: From> SequenceWriter<'a, E> { #[inline] - pub fn new(f: &'a dyn Fn(&mut Writer<'_>) -> WriteResult) -> Self { + pub fn new(f: &'a dyn Fn(&mut Writer<'_>) -> Result<(), E>) -> SequenceWriter<'a, E> { SequenceWriter { f } } } -impl SimpleAsn1Writable for SequenceWriter<'_> { +impl> SimpleAsn1Writable for SequenceWriter<'_, E> { + type Error = E; const TAG: Tag = Tag::constructed(0x10); + #[inline] - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { (self.f)(&mut Writer::new(dest)) } @@ -1746,8 +1817,10 @@ impl< const MAXIMUM_LEN: usize, > SimpleAsn1Writable for SequenceOf<'a, T, MINIMUM_LEN, MAXIMUM_LEN> { + type Error = T::Error; const TAG: Tag = Tag::constructed(0x10); - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { let mut w = Writer::new(dest); for el in self.clone() { w.write_element(&el)?; @@ -1779,8 +1852,10 @@ impl> SequenceOfWriter<'_, T, V> { } impl> SimpleAsn1Writable for SequenceOfWriter<'_, T, V> { + type Error = T::Error; const TAG: Tag = Tag::constructed(0x10); - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { let mut w = Writer::new(dest); for el in self.vals.borrow() { w.write_element(el)?; @@ -1894,8 +1969,10 @@ impl<'a, T: Asn1Readable<'a>> Iterator for SetOf<'a, T> { } impl<'a, T: Asn1Readable<'a> + Asn1Writable> SimpleAsn1Writable for SetOf<'a, T> { + type Error = T::Error; const TAG: Tag = Tag::constructed(0x11); - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { let mut w = Writer::new(dest); // We are known to be ordered correctly because that's an invariant for // `self`, so we don't need to sort here. @@ -1929,8 +2006,10 @@ impl> SetOfWriter<'_, T, V> { } impl> SimpleAsn1Writable for SetOfWriter<'_, T, V> { + type Error = T::Error; const TAG: Tag = Tag::constructed(0x11); - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { let vals = self.vals.borrow(); if vals.is_empty() { return Ok(()); @@ -2005,9 +2084,10 @@ impl<'a, T: SimpleAsn1Readable<'a>, const TAG: u32> SimpleAsn1Readable<'a> } impl SimpleAsn1Writable for Implicit { + type Error = T::Error; const TAG: Tag = crate::implicit_tag(TAG, T::TAG); - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { self.inner.write_data(dest) } @@ -2051,8 +2131,10 @@ impl<'a, T: Asn1Readable<'a>, const TAG: u32> SimpleAsn1Readable<'a> for Explici } impl SimpleAsn1Writable for Explicit { + type Error = T::Error; const TAG: Tag = crate::explicit_tag(TAG); - fn write_data(&self, dest: &mut WriteBuf) -> WriteResult { + + fn write_data(&self, dest: &mut WriteBuf) -> Result<(), Self::Error> { Writer::new(dest).write_element(&self.inner) } fn data_length(&self) -> Option { @@ -2074,10 +2156,12 @@ impl<'a, T: Asn1Readable<'a>, U: Asn1DefinedByReadable<'a, T>, const TAG: u32> impl, const TAG: u32> Asn1DefinedByWritable for Explicit { + type Error = U::Error; + fn item(&self) -> &T { self.as_inner().item() } - fn write(&self, dest: &mut Writer<'_>) -> WriteResult { + fn write(&self, dest: &mut Writer<'_>) -> Result<(), Self::Error> { dest.write_tlv( crate::explicit_tag(TAG), self.as_inner().encoded_length(), @@ -2109,6 +2193,8 @@ impl<'a, T: Asn1Readable<'a>> Asn1Readable<'a> for DefinedByMarker { } impl Asn1Writable for DefinedByMarker { + type Error = WriteError; + fn write(&self, _: &mut Writer<'_>) -> WriteResult { panic!("write() should never be called on a DefinedByMarker") } diff --git a/src/writer.rs b/src/writer.rs index e9d8ca17..e927d52a 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -112,7 +112,7 @@ impl Writer<'_> { /// Writes a single element to the output. #[inline] - pub fn write_element(&mut self, val: &T) -> WriteResult { + pub fn write_element(&mut self, val: &T) -> Result<(), T::Error> { if let Some(len) = val.encoded_length() { self.buf.reserve_additional(len)?; } @@ -126,12 +126,12 @@ impl Writer<'_> { /// If `content_length` is provided, it reduces the number of /// re-allocations required. #[inline] - pub fn write_tlv WriteResult>( + pub fn write_tlv, F: FnOnce(&mut WriteBuf) -> Result<(), E>>( &mut self, tag: Tag, content_length: Option, body: F, - ) -> WriteResult { + ) -> Result<(), E> { tag.write_bytes(self.buf)?; match content_length { @@ -157,7 +157,7 @@ impl Writer<'_> { self.buf.push_byte(0)?; let start_len = self.buf.len(); body(self.buf)?; - self.insert_length(start_len) + Ok(self.insert_length(start_len)?) } } } @@ -181,7 +181,11 @@ impl Writer<'_> { } /// This is an alias for `write_element::>`. - pub fn write_explicit_element(&mut self, val: &T, tag: u32) -> WriteResult { + pub fn write_explicit_element( + &mut self, + val: &T, + tag: u32, + ) -> Result<(), T::Error> { let tag = crate::explicit_tag(tag); self.write_tlv(tag, val.encoded_length(), |dest| { Writer::new(dest).write_element(val) @@ -193,7 +197,7 @@ impl Writer<'_> { &mut self, val: &T, tag: u32, - ) -> WriteResult { + ) -> Result<(), T::Error> { let tag = crate::implicit_tag(tag, T::TAG); self.write_tlv(tag, val.data_length(), |dest| val.write_data(dest)) } @@ -202,7 +206,9 @@ impl Writer<'_> { /// Constructs a writer and invokes a callback which writes ASN.1 elements into /// the writer, then returns the generated DER bytes. #[inline] -pub fn write) -> WriteResult>(f: F) -> WriteResult> { +pub fn write, F: Fn(&mut Writer<'_>) -> Result<(), E>>( + f: F, +) -> Result, E> { let mut v = WriteBuf::new(vec![]); let mut w = Writer::new(&mut v); f(&mut w)?; @@ -212,7 +218,7 @@ pub fn write) -> WriteResult>(f: F) -> WriteResult /// Writes a single top-level ASN.1 element, returning the generated DER bytes. /// Most often this will be used where `T` is a type with /// `#[derive(asn1::Asn1Write)]`. -pub fn write_single(v: &T) -> WriteResult> { +pub fn write_single(v: &T) -> Result, T::Error> { write(|w| w.write_element(v)) } @@ -238,6 +244,7 @@ mod tests { fn assert_writes(data: &[(T, &[u8])]) where T: Asn1Writable, + ::Error: std::fmt::Debug, { for (val, expected) in data { let result = write_single(val).unwrap(); @@ -742,8 +749,10 @@ mod tests { ); assert_eq!( - write(|w| { w.write_implicit_element(&SequenceWriter::new(&|_w| { Ok(()) }), 2) }) - .unwrap(), + write(|w| { + w.write_implicit_element(&SequenceWriter::::new(&|_w| Ok(())), 2) + }) + .unwrap(), b"\xa2\x00" ); } diff --git a/tests/derive_test.rs b/tests/derive_test.rs index 95dddcca..595d1cb5 100644 --- a/tests/derive_test.rs +++ b/tests/derive_test.rs @@ -1,11 +1,10 @@ use std::fmt; -fn assert_roundtrips< - 'a, - T: asn1::Asn1Readable<'a> + asn1::Asn1Writable + PartialEq + fmt::Debug, ->( +fn assert_roundtrips<'a, T: asn1::Asn1Readable<'a> + asn1::Asn1Writable + PartialEq + fmt::Debug>( data: &[(asn1::ParseResult, &'a [u8])], -) { +) where + ::Error: fmt::Debug, +{ for (value, der_bytes) in data { let parsed = asn1::parse_single::(der_bytes); assert_eq!(value, &parsed); @@ -19,6 +18,7 @@ fn assert_roundtrips< #[test] fn test_struct_no_fields() { #[derive(asn1::Asn1Read, asn1::Asn1Write, Debug, PartialEq, Eq)] + #[error_type(asn1::WriteError)] struct NoFields; assert_roundtrips(&[ @@ -33,6 +33,7 @@ fn test_struct_no_fields() { #[test] fn test_struct_simple_fields() { #[derive(asn1::Asn1Read, asn1::Asn1Write, Debug, PartialEq, Eq)] + #[error_type(asn1::WriteError)] struct SimpleFields { a: u64, b: u64, @@ -46,6 +47,7 @@ fn test_struct_simple_fields() { #[test] fn test_tuple_struct_simple_fields() { #[derive(asn1::Asn1Read, asn1::Asn1Write, Debug, PartialEq, Eq)] + #[error_type(asn1::WriteError)] struct SimpleFields(u8, u8); assert_roundtrips(&[(Ok(SimpleFields(2, 3)), b"\x30\x06\x02\x01\x02\x02\x01\x03")]); @@ -54,6 +56,7 @@ fn test_tuple_struct_simple_fields() { #[test] fn test_struct_lifetime() { #[derive(asn1::Asn1Read, asn1::Asn1Write, Debug, PartialEq, Eq)] + #[error_type(asn1::WriteError)] struct Lifetimes<'a> { a: &'a [u8], } @@ -64,6 +67,7 @@ fn test_struct_lifetime() { #[test] fn test_optional() { #[derive(asn1::Asn1Read, asn1::Asn1Write, Debug, PartialEq, Eq)] + #[error_type(asn1::WriteError)] struct OptionalFields { zzz: Option, } @@ -81,9 +85,11 @@ fn test_optional() { #[test] fn test_explicit() { #[derive(asn1::Asn1Read, asn1::Asn1Write, Debug, PartialEq, Eq)] + #[error_type(asn1::WriteError)] struct EmptySequence; #[derive(asn1::Asn1Read, asn1::Asn1Write, Debug, PartialEq, Eq)] + #[error_type(asn1::WriteError)] struct ExplicitFields { #[explicit(5)] a: Option, @@ -120,6 +126,7 @@ fn test_explicit() { #[test] fn test_explicit_tlv() { #[derive(asn1::Asn1Read, asn1::Asn1Write, Debug, PartialEq, Eq)] + #[error_type(asn1::WriteError)] struct ExplicitTlv<'a> { #[explicit(5)] a: Option>, @@ -139,9 +146,11 @@ fn test_explicit_tlv() { #[test] fn test_implicit() { #[derive(asn1::Asn1Read, asn1::Asn1Write, Debug, PartialEq, Eq)] + #[error_type(asn1::WriteError)] struct EmptySequence; #[derive(asn1::Asn1Read, asn1::Asn1Write, Debug, PartialEq, Eq)] + #[error_type(asn1::WriteError)] struct ImplicitFields { #[implicit(5)] a: Option, @@ -178,6 +187,7 @@ fn test_implicit() { #[test] fn test_default() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct DefaultFields { #[default(13)] a: u8, @@ -238,6 +248,7 @@ fn test_default_not_literal() { const OID2: asn1::ObjectIdentifier = asn1::oid!(1, 2, 3, 4); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct DefaultFields { #[default(OID1)] a: asn1::ObjectIdentifier, @@ -260,6 +271,7 @@ fn test_default_not_literal() { #[test] fn test_default_const_generics() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug)] + #[error_type(asn1::WriteError)] struct DefaultFields { #[default(15)] a: asn1::Explicit, @@ -312,6 +324,7 @@ fn test_default_const_generics() { #[test] fn test_default_bool() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct DefaultField { #[default(false)] a: bool, @@ -334,6 +347,7 @@ fn test_struct_field_types() { // cover their encoded_length implementations. #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct TlvField<'a> { t: asn1::Tlv<'a>, } @@ -353,6 +367,7 @@ fn test_struct_field_types() { ]); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct ChoiceFields<'a> { c1: asn1::Choice1<&'a [u8]>, c2: asn1::Choice2, @@ -375,6 +390,7 @@ fn test_struct_field_types() { ]); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct LongField<'a> { f: &'a [u8], } @@ -389,6 +405,7 @@ fn test_struct_field_types() { #[test] fn test_enum() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum BasicChoice { A(u64), B(()), @@ -415,6 +432,7 @@ fn test_enum() { #[test] fn test_enum_lifetimes() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum LifetimesChoice<'a> { A(u64), B(&'a [u8]), @@ -441,6 +459,7 @@ fn test_enum_lifetimes() { #[test] fn test_enum_explicit() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum ExplicitChoice<'a> { #[explicit(5)] A(u64), @@ -468,9 +487,11 @@ fn test_enum_explicit() { #[test] fn test_enum_implicit() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct EmptySequence; #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum ImplicitChoice<'a> { #[implicit(5)] A(u64), @@ -502,11 +523,13 @@ fn test_enum_implicit() { #[test] fn test_enum_in_explicit() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum BasicChoice { A(u64), } #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct StructWithExplicitChoice { #[explicit(0)] c: Option, @@ -526,14 +549,17 @@ fn test_enum_in_explicit() { #[test] fn test_error_parse_location() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct InnerSeq(u64); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum InnerEnum { Int(u64), } #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct OuterSeq { inner: InnerSeq, inner_enum: Option, @@ -558,6 +584,7 @@ fn test_error_parse_location() { #[test] fn test_required_implicit() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct RequiredImplicit { #[implicit(0, required)] value: u8, @@ -585,6 +612,7 @@ fn test_required_implicit() { #[test] fn test_required_explicit() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct RequiredExplicit { #[explicit(0, required)] value: u8, @@ -618,6 +646,7 @@ fn test_defined_by() { const OID2: asn1::ObjectIdentifier = asn1::oid!(1, 2, 5); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct S<'a> { oid: asn1::DefinedByMarker, #[defined_by(oid)] @@ -625,6 +654,7 @@ fn test_defined_by() { } #[derive(asn1::Asn1DefinedByRead, asn1::Asn1DefinedByWrite, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum Value<'a> { #[defined_by(OID1)] OctetString(&'a [u8]), @@ -663,6 +693,7 @@ fn test_defined_by_default() { const OID2: asn1::ObjectIdentifier = asn1::oid!(1, 2, 5); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct S<'a> { oid: asn1::DefinedByMarker, #[defined_by(oid)] @@ -670,6 +701,7 @@ fn test_defined_by_default() { } #[derive(asn1::Asn1DefinedByRead, asn1::Asn1DefinedByWrite, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum Value<'a> { #[defined_by(OID1)] Integer(u32), @@ -701,6 +733,7 @@ fn test_defined_by_optional() { const OID2: asn1::ObjectIdentifier = asn1::oid!(1, 2, 5); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct S<'a> { oid: asn1::DefinedByMarker, #[defined_by(oid)] @@ -708,6 +741,7 @@ fn test_defined_by_optional() { } #[derive(asn1::Asn1DefinedByRead, asn1::Asn1DefinedByWrite, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum Value<'a> { #[defined_by(OID1)] OctetString(&'a [u8]), @@ -747,6 +781,7 @@ fn test_defined_by_mod() { } #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct S<'a> { oid: asn1::DefinedByMarker, #[defined_by(oid)] @@ -754,6 +789,7 @@ fn test_defined_by_mod() { } #[derive(asn1::Asn1DefinedByRead, asn1::Asn1DefinedByWrite, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum Value<'a> { #[defined_by(oids::OID1)] OctetString(&'a [u8]), @@ -782,6 +818,7 @@ fn test_defined_by_explicit() { pub const OID1: asn1::ObjectIdentifier = asn1::oid!(1, 2, 3); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct S<'a> { oid: asn1::DefinedByMarker, #[defined_by(oid)] @@ -789,6 +826,7 @@ fn test_defined_by_explicit() { } #[derive(asn1::Asn1DefinedByRead, asn1::Asn1DefinedByWrite, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum Value<'a> { #[defined_by(OID1)] OctetString(&'a [u8]), @@ -806,6 +844,7 @@ fn test_defined_by_explicit() { #[test] fn test_generics() { #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct S { value: T, } @@ -835,6 +874,7 @@ fn test_perfect_derive() { } #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct S { value: T::Type, } @@ -842,6 +882,7 @@ fn test_perfect_derive() { assert_roundtrips::>(&[(Ok(S { value: 12 }), b"\x30\x03\x02\x01\x0c")]); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct TaggedRequiredFields { #[implicit(1, required)] a: T::Type, @@ -855,6 +896,7 @@ fn test_perfect_derive() { )]); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] struct TaggedOptionalFields { #[implicit(1)] a: Option, @@ -874,6 +916,7 @@ fn test_perfect_derive() { ]); #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + #[error_type(asn1::WriteError)] enum TaggedEnum { #[implicit(0)] Implicit(T::Type), @@ -900,6 +943,7 @@ fn test_defined_by_perfect_derive() { } #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug)] + #[error_type(asn1::WriteError)] struct S { oid: asn1::DefinedByMarker, #[defined_by(oid)] @@ -910,6 +954,7 @@ fn test_defined_by_perfect_derive() { pub const OID2: asn1::ObjectIdentifier = asn1::oid!(1, 2, 4); #[derive(asn1::Asn1DefinedByRead, asn1::Asn1DefinedByWrite, PartialEq, Debug)] + #[error_type(asn1::WriteError)] enum Value { #[defined_by(OID1)] A(T::Type), diff --git a/tests/roundtrip_tests.rs b/tests/roundtrip_tests.rs index 79980565..3ef9a46e 100644 --- a/tests/roundtrip_tests.rs +++ b/tests/roundtrip_tests.rs @@ -1,6 +1,9 @@ fn assert_roundtrips(i: T) where - for<'a> T: asn1::Asn1Writable + asn1::Asn1Readable<'a> + std::fmt::Debug + PartialEq, + for<'a> T: asn1::Asn1Writable + + asn1::Asn1Readable<'a> + + std::fmt::Debug + + PartialEq, { let result = asn1::write_single::(&i).unwrap(); let parsed = asn1::parse_single::(&result).unwrap();